azalea/pathfinder/astar.rs

1use std::{
2    cmp::{self},
3    collections::BinaryHeap,
4    fmt::{self, Debug},
5    hash::{BuildHasherDefault, Hash},
6    time::{Duration, Instant},
7};
8
9use indexmap::IndexMap;
10use num_format::ToFormattedString;
11use rustc_hash::FxHasher;
12use tracing::{debug, trace, warn};
13
14pub struct Path<P, M>
15where
16    P: Eq + Hash + Copy + Debug,
17{
18    pub movements: Vec<Movement<P, M>>,
19    pub is_partial: bool,
20}
21
// Weights used to pick the best partial path when the search times out. Each
// coefficient divides the g_score when scoring a node, so every coefficient
// tracks its own candidate end node.
// see https://github.com/cabaletta/baritone/blob/1.19.4/src/main/java/baritone/pathing/calc/AbstractNodeCostSearch.java#L68
const COEFFICIENTS: [f32; 7] = [1.5, 2.0, 2.5, 3.0, 4.0, 5.0, 10.0];

// How much a node's score has to beat a coefficient's current best score by
// before that node replaces the current candidate.
const MIN_IMPROVEMENT: f32 = 0.01;
27
/// An `IndexMap` that hashes keys with `FxHasher` instead of the default
/// hasher. Entries are addressed by their insertion index (`get_index`,
/// `Entry::index`), which is how the search refers to nodes.
type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<FxHasher>>;
29
30// Sources:
31// - https://en.wikipedia.org/wiki/A*_search_algorithm
32// - https://github.com/evenfurther/pathfinding/blob/main/src/directed/astar.rs
33// - https://github.com/cabaletta/baritone/blob/1.19.4/src/main/java/baritone/pathing/calc/AbstractNodeCostSearch.java
34pub fn a_star<P, M, HeuristicFn, SuccessorsFn, SuccessFn>(
35    start: P,
36    heuristic: HeuristicFn,
37    mut successors: SuccessorsFn,
38    success: SuccessFn,
39    min_timeout: PathfinderTimeout,
40    max_timeout: PathfinderTimeout,
41) -> Path<P, M>
42where
43    P: Eq + Hash + Copy + Debug,
44    HeuristicFn: Fn(P) -> f32,
45    SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
46    SuccessFn: Fn(P) -> bool,
47{
48    let start_time = Instant::now();
49
50    let mut open_set = BinaryHeap::<WeightedNode>::new();
51    open_set.push(WeightedNode {
52        g_score: 0.,
53        f_score: 0.,
54        index: 0,
55    });
56    let mut nodes: FxIndexMap<P, Node> = IndexMap::default();
57    nodes.insert(
58        start,
59        Node {
60            came_from: usize::MAX,
61            g_score: 0.,
62        },
63    );
64
65    let mut best_paths: [usize; 7] = [0; 7];
66    let mut best_path_scores: [f32; 7] = [heuristic(start); 7];
67
68    let mut num_nodes = 0_usize;
69    let mut num_movements = 0;
70
71    while let Some(WeightedNode { index, g_score, .. }) = open_set.pop() {
72        num_nodes += 1;
73
74        let (&node, node_data) = nodes.get_index(index).unwrap();
75        if success(node) {
76            debug!("Nodes considered: {num_nodes}");
77
78            return Path {
79                movements: reconstruct_path(nodes, index, successors),
80                is_partial: false,
81            };
82        }
83
84        if g_score > node_data.g_score {
85            continue;
86        }
87
88        for neighbor in successors(node) {
89            let tentative_g_score = g_score + neighbor.cost;
90            // let neighbor_heuristic = heuristic(neighbor.movement.target);
91            let neighbor_heuristic;
92            let neighbor_index;
93
94            num_movements += 1;
95
96            match nodes.entry(neighbor.movement.target) {
97                indexmap::map::Entry::Occupied(mut e) => {
98                    if e.get().g_score > tentative_g_score {
99                        neighbor_heuristic = heuristic(*e.key());
100                        neighbor_index = e.index();
101                        e.insert(Node {
102                            came_from: index,
103                            g_score: tentative_g_score,
104                        });
105                    } else {
106                        continue;
107                    }
108                }
109                indexmap::map::Entry::Vacant(e) => {
110                    neighbor_heuristic = heuristic(*e.key());
111                    neighbor_index = e.index();
112                    e.insert(Node {
113                        came_from: index,
114                        g_score: tentative_g_score,
115                    });
116                }
117            }
118
119            open_set.push(WeightedNode {
120                index: neighbor_index,
121                g_score: tentative_g_score,
122                f_score: tentative_g_score + neighbor_heuristic,
123            });
124
125            for (coefficient_i, &coefficient) in COEFFICIENTS.iter().enumerate() {
126                let node_score = neighbor_heuristic + tentative_g_score / coefficient;
127                if best_path_scores[coefficient_i] - node_score > MIN_IMPROVEMENT {
128                    best_paths[coefficient_i] = neighbor_index;
129                    best_path_scores[coefficient_i] = node_score;
130                }
131            }
132        }
133
134        // check for timeout every ~10ms
135        if num_nodes.is_multiple_of(10_000) {
136            let min_timeout_reached = match min_timeout {
137                PathfinderTimeout::Time(max_duration) => start_time.elapsed() >= max_duration,
138                PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
139            };
140
141            if min_timeout_reached {
142                // means we have a non-empty path
143                if best_paths[6] != 0 {
144                    break;
145                }
146
147                if min_timeout_reached {
148                    let max_timeout_reached = match max_timeout {
149                        PathfinderTimeout::Time(max_duration) => {
150                            start_time.elapsed() >= max_duration
151                        }
152                        PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
153                    };
154
155                    if max_timeout_reached {
156                        // timeout, we're gonna be returning an empty path :(
157                        trace!("A* couldn't find a path in time, returning best path");
158                        break;
159                    }
160                }
161            }
162        }
163    }
164
165    let best_path = determine_best_path(best_paths, 0);
166
167    let elapsed_seconds = start_time.elapsed().as_secs_f64();
168    let nodes_per_second = (num_nodes as f64 / elapsed_seconds) as u64;
169    let num_movements_per_second = (num_movements as f64 / elapsed_seconds) as u64;
170    debug!(
171        "A* ran at {} nodes per second and {} movements per second",
172        nodes_per_second.to_formatted_string(&num_format::Locale::en),
173        num_movements_per_second.to_formatted_string(&num_format::Locale::en),
174    );
175
176    Path {
177        movements: reconstruct_path(nodes, best_path, successors),
178        is_partial: true,
179    }
180}
181
182fn determine_best_path(best_paths: [usize; 7], start: usize) -> usize {
183    // this basically makes sure we don't create a path that's really short
184
185    for node in best_paths {
186        if node != start {
187            return node;
188        }
189    }
190    warn!("No best node found, returning first node");
191    best_paths[0]
192}
193
194fn reconstruct_path<P, M, SuccessorsFn>(
195    nodes: FxIndexMap<P, Node>,
196    mut current_index: usize,
197    mut successors: SuccessorsFn,
198) -> Vec<Movement<P, M>>
199where
200    P: Eq + Hash + Copy + Debug,
201    SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
202{
203    let mut path = Vec::new();
204    while let Some((&node_position, node)) = nodes.get_index(current_index) {
205        if node.came_from == usize::MAX {
206            break;
207        }
208        let came_from_position = *nodes.get_index(node.came_from).unwrap().0;
209
210        // find the movement data for this successor, we have to do this again because
211        // we don't include the movement data in the Node (as an optimization)
212        let mut best_successor = None;
213        let mut best_successor_cost = f32::INFINITY;
214        for successor in successors(came_from_position) {
215            if successor.movement.target == node_position && successor.cost < best_successor_cost {
216                best_successor_cost = successor.cost;
217                best_successor = Some(successor);
218            }
219        }
220        let Some(found_successor) = best_successor else {
221            warn!(
222                "a successor stopped being possible while reconstructing the path, returning empty path"
223            );
224            return vec![];
225        };
226
227        path.push(Movement {
228            target: node_position,
229            data: found_successor.movement.data,
230        });
231
232        current_index = node.came_from;
233    }
234    path.reverse();
235    path
236}
237
/// Per-position bookkeeping kept by the search.
pub struct Node {
    /// Map index of the node this one was reached from, or `usize::MAX` for
    /// the start node.
    pub came_from: usize,
    /// Lowest cost found so far from the start to this node.
    pub g_score: f32,
}
242
243#[derive(Clone, Debug)]
244pub struct Edge<P: Hash + Copy, M> {
245    pub movement: Movement<P, M>,
246    pub cost: f32,
247}
248
/// One step of a path: the position it arrives at, plus the data the
/// successors function attached to that step.
pub struct Movement<P, M>
where
    P: Hash + Copy,
{
    /// Position this step ends at.
    pub target: P,
    /// Caller-defined data describing the step.
    pub data: M,
}

impl<P, M> Debug for Movement<P, M>
where
    P: Hash + Copy + Debug,
    M: Debug,
{
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { target, data } = self;
        formatter
            .debug_struct("Movement")
            .field("target", target)
            .field("data", data)
            .finish()
    }
}

impl<P, M> Clone for Movement<P, M>
where
    P: Hash + Copy,
    M: Clone,
{
    fn clone(&self) -> Self {
        Movement {
            target: self.target,
            data: M::clone(&self.data),
        }
    }
}
270
/// An entry in the A* open set.
///
/// The ordering makes `BinaryHeap` (a max-heap) pop the entry with the lowest
/// `f_score` first, breaking ties in favour of the higher `g_score`. `index`
/// does not take part in the ordering or in equality.
#[repr(C)]
pub struct WeightedNode {
    /// Sum of the g_score and heuristic
    pub f_score: f32,
    /// The actual cost to get to this node
    pub g_score: f32,
    /// Index of the node this entry refers to (in the search's node map).
    pub index: usize,
}

impl Ord for WeightedNode {
    #[inline]
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        // intentionally inverted to make the BinaryHeap a min-heap
        match other.f_score.total_cmp(&self.f_score) {
            cmp::Ordering::Equal => self.g_score.total_cmp(&other.g_score),
            s => s,
        }
    }
}
// Equality goes through `cmp` instead of being derived, so `==` agrees with
// `Ord`, as the `Eq`/`Ord` contracts require. The derived impl also compared
// `index`, and it used float `==` (NaN != NaN, -0.0 == 0.0), which disagrees
// with `total_cmp` and breaks `Eq`'s reflexivity.
impl PartialEq for WeightedNode {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == cmp::Ordering::Equal
    }
}
impl Eq for WeightedNode {}
impl PartialOrd for WeightedNode {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}
298
/// A limit that decides when the A* search should stop looking for a path.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PathfinderTimeout {
    /// Time out after a certain duration has passed. This is a good default so
    /// you don't waste too much time calculating a path if you're on a slow
    /// computer.
    Time(Duration),
    /// Time out after this many nodes have been considered.
    ///
    /// This is useful as an alternative to a time limit if you're doing
    /// something like running tests where you want consistent results.
    Nodes(usize),
}

impl Default for PathfinderTimeout {
    /// One second of wall-clock time.
    fn default() -> Self {
        PathfinderTimeout::Time(Duration::from_millis(1_000))
    }
}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an open-set entry with the given scores; `index` is irrelevant
    /// to the ordering under test.
    fn entry(f_score: f32, g_score: f32) -> WeightedNode {
        WeightedNode {
            f_score,
            g_score,
            index: 0,
        }
    }

    #[test]
    fn test_weighted_node_eq() {
        assert!(entry(0., 0.) == entry(0., 0.));
    }

    #[test]
    fn test_weighted_node_le() {
        // a higher f_score must compare lower so the max-heap pops it later
        let (higher_f, lower_f) = (entry(1., 0.), entry(0., 0.));
        assert_eq!(higher_f.cmp(&lower_f), cmp::Ordering::Less);
        assert!(higher_f.le(&lower_f));
    }

    #[test]
    fn test_weighted_node_le_g() {
        // with equal f_scores, the higher g_score compares greater
        let (higher_g, lower_g) = (entry(0., 1.), entry(0., 0.));
        assert_eq!(higher_g.cmp(&lower_g), cmp::Ordering::Greater);
        assert!(!higher_g.le(&lower_g));
    }
}