azalea/pathfinder/
astar.rs

1use std::{
2    cmp::{self},
3    collections::BinaryHeap,
4    fmt::{self, Debug},
5    hash::{BuildHasherDefault, Hash},
6    time::{Duration, Instant},
7};
8
9use indexmap::IndexMap;
10use num_format::ToFormattedString;
11use rustc_hash::FxHasher;
12use tracing::{debug, trace, warn};
13
14pub struct Path<P, M>
15where
16    P: Eq + Hash + Copy + Debug,
17{
18    pub movements: Vec<Movement<P, M>>,
19    pub is_partial: bool,
20}
21
/// Weights used to pick a fallback node when the search times out.
///
/// For every coefficient, `a_star` remembers the node with the lowest
/// `heuristic + g_score / coefficient`, so that a reasonable partial path can
/// still be returned.
///
/// See <https://github.com/cabaletta/baritone/blob/1.19.4/src/main/java/baritone/pathing/calc/AbstractNodeCostSearch.java#L68>
const COEFFICIENTS: [f32; 7] = [1.5, 2.0, 2.5, 3.0, 4.0, 5.0, 10.0];
25
/// Smallest score difference that `a_star` treats as meaningful, both when
/// filtering edges and when deciding whether a node beats the currently
/// tracked best node for a coefficient.
const MIN_IMPROVEMENT: f32 = 0.01;
27
28type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<FxHasher>>;
29
30// Sources:
31// - https://en.wikipedia.org/wiki/A*_search_algorithm
32// - https://github.com/evenfurther/pathfinding/blob/main/src/directed/astar.rs
33// - https://github.com/cabaletta/baritone/blob/1.19.4/src/main/java/baritone/pathing/calc/AbstractNodeCostSearch.java
34pub fn a_star<P, M, HeuristicFn, SuccessorsFn, SuccessFn>(
35    start: P,
36    heuristic: HeuristicFn,
37    mut successors: SuccessorsFn,
38    success: SuccessFn,
39    min_timeout: PathfinderTimeout,
40    max_timeout: PathfinderTimeout,
41) -> Path<P, M>
42where
43    P: Eq + Hash + Copy + Debug,
44    HeuristicFn: Fn(P) -> f32,
45    SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
46    SuccessFn: Fn(P) -> bool,
47{
48    let start_time = Instant::now();
49
50    let mut open_set = BinaryHeap::<WeightedNode>::new();
51    open_set.push(WeightedNode {
52        g_score: 0.,
53        f_score: 0.,
54        index: 0,
55    });
56    let mut nodes: FxIndexMap<P, Node> = IndexMap::default();
57    nodes.insert(
58        start,
59        Node {
60            came_from: usize::MAX,
61            g_score: 0.,
62        },
63    );
64
65    let mut best_paths: [usize; 7] = [0; 7];
66    let mut best_path_scores: [f32; 7] = [heuristic(start); 7];
67
68    let mut num_nodes = 0;
69    let mut num_movements = 0;
70
71    while let Some(WeightedNode { index, g_score, .. }) = open_set.pop() {
72        num_nodes += 1;
73
74        let (&node, node_data) = nodes.get_index(index).unwrap();
75        if success(node) {
76            debug!("Nodes considered: {num_nodes}");
77
78            return Path {
79                movements: reconstruct_path(nodes, index, successors),
80                is_partial: false,
81            };
82        }
83
84        if g_score > node_data.g_score {
85            continue;
86        }
87
88        for neighbor in successors(node) {
89            let tentative_g_score = g_score + neighbor.cost;
90            // let neighbor_heuristic = heuristic(neighbor.movement.target);
91            let neighbor_heuristic;
92            let neighbor_index;
93
94            // skip neighbors that don't result in a big enough improvement
95            if tentative_g_score - g_score < MIN_IMPROVEMENT {
96                continue;
97            }
98            num_movements += 1;
99
100            match nodes.entry(neighbor.movement.target) {
101                indexmap::map::Entry::Occupied(mut e) => {
102                    if e.get().g_score > tentative_g_score {
103                        neighbor_heuristic = heuristic(*e.key());
104                        neighbor_index = e.index();
105                        e.insert(Node {
106                            came_from: index,
107                            g_score: tentative_g_score,
108                        });
109                    } else {
110                        continue;
111                    }
112                }
113                indexmap::map::Entry::Vacant(e) => {
114                    neighbor_heuristic = heuristic(*e.key());
115                    neighbor_index = e.index();
116                    e.insert(Node {
117                        came_from: index,
118                        g_score: tentative_g_score,
119                    });
120                }
121            }
122
123            open_set.push(WeightedNode {
124                index: neighbor_index,
125                g_score: tentative_g_score,
126                f_score: tentative_g_score + neighbor_heuristic,
127            });
128
129            for (coefficient_i, &coefficient) in COEFFICIENTS.iter().enumerate() {
130                let node_score = neighbor_heuristic + tentative_g_score / coefficient;
131                if best_path_scores[coefficient_i] - node_score > MIN_IMPROVEMENT {
132                    best_paths[coefficient_i] = neighbor_index;
133                    best_path_scores[coefficient_i] = node_score;
134                }
135            }
136        }
137
138        // check for timeout every ~10ms
139        if num_nodes % 10000 == 0 {
140            let min_timeout_reached = match min_timeout {
141                PathfinderTimeout::Time(max_duration) => start_time.elapsed() >= max_duration,
142                PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
143            };
144
145            if min_timeout_reached {
146                // means we have a non-empty path
147                if best_paths[6] != 0 {
148                    break;
149                }
150
151                if min_timeout_reached {
152                    let max_timeout_reached = match max_timeout {
153                        PathfinderTimeout::Time(max_duration) => {
154                            start_time.elapsed() >= max_duration
155                        }
156                        PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
157                    };
158
159                    if max_timeout_reached {
160                        // timeout, we're gonna be returning an empty path :(
161                        trace!("A* couldn't find a path in time, returning best path");
162                        break;
163                    }
164                }
165            }
166        }
167    }
168
169    let best_path = determine_best_path(best_paths, 0);
170
171    let elapsed_seconds = start_time.elapsed().as_secs_f64();
172    let nodes_per_second = (num_nodes as f64 / elapsed_seconds) as u64;
173    let num_movements_per_second = (num_movements as f64 / elapsed_seconds) as u64;
174    debug!(
175        "A* ran at {} nodes per second and {} movements per second",
176        nodes_per_second.to_formatted_string(&num_format::Locale::en),
177        num_movements_per_second.to_formatted_string(&num_format::Locale::en),
178    );
179
180    Path {
181        movements: reconstruct_path(nodes, best_path, successors),
182        is_partial: true,
183    }
184}
185
186fn determine_best_path(best_paths: [usize; 7], start: usize) -> usize {
187    // this basically makes sure we don't create a path that's really short
188
189    for node in best_paths {
190        if node != start {
191            return node;
192        }
193    }
194    warn!("No best node found, returning first node");
195    best_paths[0]
196}
197
198fn reconstruct_path<P, M, SuccessorsFn>(
199    nodes: FxIndexMap<P, Node>,
200    mut current_index: usize,
201    mut successors: SuccessorsFn,
202) -> Vec<Movement<P, M>>
203where
204    P: Eq + Hash + Copy + Debug,
205    SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
206{
207    let mut path = Vec::new();
208    while let Some((&node_position, node)) = nodes.get_index(current_index) {
209        if node.came_from == usize::MAX {
210            break;
211        }
212        let came_from_position = *nodes.get_index(node.came_from).unwrap().0;
213
214        // find the movement data for this successor, we have to do this again because
215        // we don't include the movement data in the Node (as an optimization)
216        let mut best_successor = None;
217        let mut best_successor_cost = f32::INFINITY;
218        for successor in successors(came_from_position) {
219            if successor.movement.target == node_position && successor.cost < best_successor_cost {
220                best_successor_cost = successor.cost;
221                best_successor = Some(successor);
222            }
223        }
224        let Some(found_successor) = best_successor else {
225            warn!(
226                "a successor stopped being possible while reconstructing the path, returning empty path"
227            );
228            return vec![];
229        };
230
231        path.push(Movement {
232            target: node_position,
233            data: found_successor.movement.data,
234        });
235
236        current_index = node.came_from;
237    }
238    path.reverse();
239    path
240}
241
/// Bookkeeping stored for every position the search has discovered.
pub struct Node {
    /// Index (in the node map) of the node this one was reached from.
    /// `usize::MAX` marks the start node, which has no predecessor.
    pub came_from: usize,
    /// Cheapest known cost of reaching this node from the start.
    pub g_score: f32,
}
246
247#[derive(Clone, Debug)]
248pub struct Edge<P: Hash + Copy, M> {
249    pub movement: Movement<P, M>,
250    pub cost: f32,
251}
252
/// A single step of a path.
pub struct Movement<P, M>
where
    P: Hash + Copy,
{
    /// The position this step ends at.
    pub target: P,
    /// Caller-defined payload attached to the step.
    pub data: M,
}
257
258impl<P: Hash + Copy + Debug, M: Debug> Debug for Movement<P, M> {
259    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260        f.debug_struct("Movement")
261            .field("target", &self.target)
262            .field("data", &self.data)
263            .finish()
264    }
265}
266impl<P: Hash + Copy + Clone, M: Clone> Clone for Movement<P, M> {
267    fn clone(&self) -> Self {
268        Self {
269            target: self.target,
270            data: self.data.clone(),
271        }
272    }
273}
274
/// An entry of the A* open set.
///
/// Ordering and equality only look at `f_score` and `g_score` (compared with
/// `f32::total_cmp`); `index` does not take part in either. The ordering is
/// inverted on `f_score` so that a `BinaryHeap` pops the lowest `f_score`
/// first, and ties are broken in favour of the higher `g_score`.
#[repr(C)]
pub struct WeightedNode {
    /// Sum of the g_score and heuristic
    pub f_score: f32,
    /// The actual cost to get to this node
    pub g_score: f32,
    /// Index of the node in the search's node map.
    pub index: usize,
}

impl PartialEq for WeightedNode {
    /// Equality is defined through `cmp` so that it stays consistent with
    /// the ordering and is reflexive even for NaN scores, as `Eq` requires.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == cmp::Ordering::Equal
    }
}
impl Eq for WeightedNode {}

impl Ord for WeightedNode {
    #[inline]
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        // intentionally inverted to make the BinaryHeap a min-heap
        match other.f_score.total_cmp(&self.f_score) {
            // same f_score: the node with the higher g_score is popped first
            cmp::Ordering::Equal => self.g_score.total_cmp(&other.g_score),
            s => s,
        }
    }
}
impl PartialOrd for WeightedNode {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}
302
/// A limit on how much work a search may do before it gives up.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum PathfinderTimeout {
    /// Give up once this much wall-clock time has elapsed.
    ///
    /// A sensible default: it keeps slow computers from spending too long
    /// on a single path calculation.
    Time(Duration),
    /// Give up once this many nodes have been considered.
    ///
    /// An alternative to a time limit for cases where results must be
    /// consistent between runs, such as tests.
    Nodes(usize),
}
315impl Default for PathfinderTimeout {
316    fn default() -> Self {
317        Self::Time(Duration::from_secs(1))
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node with the given scores; every test node uses index 0.
    fn weighted_node(f: f32, g: f32) -> WeightedNode {
        WeightedNode {
            index: 0,
            g_score: g,
            f_score: f,
        }
    }

    #[test]
    fn test_weighted_node_eq() {
        let (a, b) = (weighted_node(0., 0.), weighted_node(0., 0.));
        assert!(a == b);
    }

    #[test]
    fn test_weighted_node_le() {
        // the ordering is inverted on f_score, so the higher f_score is "less"
        let higher_f = weighted_node(1., 0.);
        let lower_f = weighted_node(0., 0.);
        assert_eq!(higher_f.cmp(&lower_f), cmp::Ordering::Less);
        assert!(higher_f <= lower_f);
    }

    #[test]
    fn test_weighted_node_le_g() {
        // with equal f_scores, the higher g_score is "greater"
        let higher_g = weighted_node(0., 1.);
        let lower_g = weighted_node(0., 0.);
        assert_eq!(higher_g.cmp(&lower_g), cmp::Ordering::Greater);
        assert!(!(higher_g <= lower_g));
    }
}
353}