// azalea/pathfinder/astar.rs

1use std::{
2    cmp::{self, Reverse},
3    collections::BinaryHeap,
4    fmt::{self, Debug},
5    hash::{BuildHasherDefault, Hash},
6    time::{Duration, Instant},
7};
8
9use indexmap::IndexMap;
10use num_format::ToFormattedString;
11use radix_heap::RadixHeapMap;
12use rustc_hash::FxHasher;
13use tracing::{debug, trace, warn};
14
15pub struct Path<P, M>
16where
17    P: Eq + Hash + Copy + Debug,
18{
19    pub movements: Vec<Movement<P, M>>,
20    pub is_partial: bool,
21    /// The A* cost for executing the path.
22    ///
23    /// For Azalea's pathfinder, this is generally the estimated amount of time
24    /// that it takes to complete the path, in ticks.
25    pub cost: f32,
26}
27
// Used for better results when timing out. For each coefficient `c`, the
// search tracks the node with the lowest `heuristic + g_score / c`; larger
// coefficients shrink the g-score term, so they favor nodes that look closer
// to the goal by heuristic.
// see https://github.com/cabaletta/baritone/blob/1.19.4/src/main/java/baritone/pathing/calc/AbstractNodeCostSearch.java#L68
const COEFFICIENTS: [f32; 7] = [1.5, 2.0, 2.5, 3.0, 4.0, 5.0, 10.0];

/// How much a weighted score has to drop before a node replaces the current
/// best partial path for that coefficient.
const MIN_IMPROVEMENT: f32 = 0.01;
33
34type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<FxHasher>>;
35
36// Sources:
37// - https://en.wikipedia.org/wiki/A*_search_algorithm
38// - https://github.com/evenfurther/pathfinding/blob/main/src/directed/astar.rs
39// - https://github.com/cabaletta/baritone/blob/1.19.4/src/main/java/baritone/pathing/calc/AbstractNodeCostSearch.java
40pub fn a_star<P, M, HeuristicFn, SuccessorsFn, SuccessFn>(
41    start: P,
42    heuristic: HeuristicFn,
43    mut successors: SuccessorsFn,
44    success: SuccessFn,
45    min_timeout: PathfinderTimeout,
46    max_timeout: PathfinderTimeout,
47) -> Path<P, M>
48where
49    P: Eq + Hash + Copy + Debug,
50    HeuristicFn: Fn(P) -> f32,
51    SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
52    SuccessFn: Fn(P) -> bool,
53{
54    let start_time = Instant::now();
55
56    let mut open_set = PathfinderHeap::new();
57    open_set.push(WeightedNode {
58        g_score: 0.,
59        f_score: 0.,
60        index: 0,
61    });
62    let mut nodes: FxIndexMap<P, Node> = IndexMap::default();
63    nodes.insert(
64        start,
65        Node {
66            came_from: u32::MAX,
67            g_score: 0.,
68        },
69    );
70
71    let mut best_paths: [u32; 7] = [0; 7];
72    let mut best_path_scores: [f32; 7] = [heuristic(start); 7];
73
74    let mut num_nodes = 0_usize;
75    let mut num_movements = 0;
76
77    while let Some(WeightedNode { index, g_score, .. }) = open_set.pop() {
78        let (&node, node_data) = nodes.get_index(index as usize).unwrap();
79        if g_score > node_data.g_score {
80            continue;
81        }
82
83        num_nodes += 1;
84
85        if success(node) {
86            let best_path = index;
87            log_perf_info(start_time, num_nodes, num_movements);
88
89            return Path {
90                movements: reconstruct_path(nodes, best_path, successors),
91                is_partial: false,
92                cost: g_score,
93            };
94        }
95
96        for neighbor in successors(node) {
97            let tentative_g_score = g_score + neighbor.cost;
98            // let neighbor_heuristic = heuristic(neighbor.movement.target);
99            let neighbor_heuristic;
100            let neighbor_index;
101
102            num_movements += 1;
103
104            match nodes.entry(neighbor.movement.target) {
105                indexmap::map::Entry::Occupied(mut e) => {
106                    if e.get().g_score > tentative_g_score {
107                        neighbor_heuristic = heuristic(*e.key());
108                        neighbor_index = e.index() as u32;
109                        e.insert(Node {
110                            came_from: index,
111                            g_score: tentative_g_score,
112                        });
113                    } else {
114                        continue;
115                    }
116                }
117                indexmap::map::Entry::Vacant(e) => {
118                    neighbor_heuristic = heuristic(*e.key());
119                    neighbor_index = e.index() as u32;
120                    e.insert(Node {
121                        came_from: index,
122                        g_score: tentative_g_score,
123                    });
124                }
125            }
126
127            // we don't update the existing node, which means that the same node might be
128            // present in the open_set multiple times. this is fine because at the start of
129            // the loop we check `g_score > node_data.g_score`.
130            open_set.push(WeightedNode {
131                index: neighbor_index,
132                g_score: tentative_g_score,
133                f_score: tentative_g_score + neighbor_heuristic,
134            });
135
136            for (coefficient_i, &coefficient) in COEFFICIENTS.iter().enumerate() {
137                let node_score = neighbor_heuristic + tentative_g_score / coefficient;
138                if best_path_scores[coefficient_i] - node_score > MIN_IMPROVEMENT {
139                    best_paths[coefficient_i] = neighbor_index;
140                    best_path_scores[coefficient_i] = node_score;
141                }
142            }
143        }
144
145        // check for timeout every ~10ms
146        if num_nodes.is_multiple_of(10_000) {
147            let min_timeout_reached = match min_timeout {
148                PathfinderTimeout::Time(max_duration) => start_time.elapsed() >= max_duration,
149                PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
150            };
151
152            if min_timeout_reached {
153                // means we have a non-empty path
154                if best_paths[6] != 0 {
155                    break;
156                }
157
158                if min_timeout_reached {
159                    let max_timeout_reached = match max_timeout {
160                        PathfinderTimeout::Time(max_duration) => {
161                            start_time.elapsed() >= max_duration
162                        }
163                        PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
164                    };
165
166                    if max_timeout_reached {
167                        // timeout, we're gonna be returning an empty path :(
168                        trace!("A* couldn't find a path in time, returning best path");
169                        break;
170                    }
171                }
172            }
173        }
174    }
175
176    let best_path_idx = determine_best_path_idx(best_paths, 0);
177    log_perf_info(start_time, num_nodes, num_movements);
178    Path {
179        movements: reconstruct_path(nodes, best_paths[best_path_idx], successors),
180        is_partial: true,
181        cost: best_path_scores[best_path_idx],
182    }
183}
184
185fn log_perf_info(start_time: Instant, num_nodes: usize, num_movements: usize) {
186    let elapsed = start_time.elapsed();
187    let elapsed_seconds = elapsed.as_secs_f64();
188    let nodes_per_second = (num_nodes as f64 / elapsed_seconds) as u64;
189    let num_movements_per_second = (num_movements as f64 / elapsed_seconds) as u64;
190    debug!(
191        "Considered {} nodes in {elapsed:?}",
192        num_nodes.to_formatted_string(&num_format::Locale::en)
193    );
194    debug!(
195        "A* ran at {} nodes per second and {} movements per second",
196        nodes_per_second.to_formatted_string(&num_format::Locale::en),
197        num_movements_per_second.to_formatted_string(&num_format::Locale::en),
198    );
199}
200
201fn determine_best_path_idx(best_paths: [u32; 7], start: u32) -> usize {
202    // this basically makes sure we don't create a path that's really short
203
204    for (i, &node) in best_paths.iter().enumerate() {
205        if node != start {
206            return i;
207        }
208    }
209    warn!("No best node found, returning first node");
210    0
211}
212
213fn reconstruct_path<P, M, SuccessorsFn>(
214    nodes: FxIndexMap<P, Node>,
215    mut current_index: u32,
216    mut successors: SuccessorsFn,
217) -> Vec<Movement<P, M>>
218where
219    P: Eq + Hash + Copy + Debug,
220    SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
221{
222    let mut path = Vec::new();
223    while let Some((&node_position, node)) = nodes.get_index(current_index as usize) {
224        if node.came_from == u32::MAX {
225            break;
226        }
227        let came_from_position = *nodes.get_index(node.came_from as usize).unwrap().0;
228
229        // find the movement data for this successor, we have to do this again because
230        // we don't include the movement data in the Node (as an optimization)
231        let mut best_successor = None;
232        let mut best_successor_cost = f32::INFINITY;
233        for successor in successors(came_from_position) {
234            if successor.movement.target == node_position && successor.cost < best_successor_cost {
235                best_successor_cost = successor.cost;
236                best_successor = Some(successor);
237            }
238        }
239        let Some(found_successor) = best_successor else {
240            warn!(
241                "a successor stopped being possible while reconstructing the path, returning empty path"
242            );
243            return vec![];
244        };
245
246        path.push(Movement {
247            target: node_position,
248            data: found_successor.movement.data,
249        });
250
251        current_index = node.came_from;
252    }
253    path.reverse();
254    path
255}
256
/// Search state kept for every node discovered by [`a_star`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Node {
    /// Index of the node this one was reached from, or `u32::MAX` for the start
    /// node (which has no parent).
    pub came_from: u32,
    /// Cost of the cheapest known route from the start to this node.
    pub g_score: f32,
}
261
/// An outgoing edge produced by a successors function: a [`Movement`] plus the
/// cost of taking it.
#[derive(Debug, Clone)]
pub struct Edge<P: Hash + Copy, M> {
    /// Where the edge leads and the caller-defined data for executing it.
    pub movement: Movement<P, M>,
    /// The cost of taking this edge; the search adds it to the g-score.
    pub cost: f32,
}

/// One step of a path: the node it leads to plus caller-defined data.
pub struct Movement<P: Hash + Copy, M> {
    pub target: P,
    pub data: M,
}

impl<P, M> Debug for Movement<P, M>
where
    P: Hash + Copy + Debug,
    M: Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { target, data } = self;
        f.debug_struct("Movement")
            .field("target", target)
            .field("data", data)
            .finish()
    }
}

impl<P, M> Clone for Movement<P, M>
where
    P: Hash + Copy + Clone,
    M: Clone,
{
    fn clone(&self) -> Self {
        let Self { target, data } = self;
        Self {
            target: *target,
            data: data.clone(),
        }
    }
}
288}
289
290#[derive(Default)]
291struct PathfinderHeap {
292    binary_heap: BinaryHeap<WeightedNode>,
293    /// Key is f_score.to_bits(), value is (g_score, index)
294    ///
295    /// As long as the f_score is positive, comparing it as bits is fine. Also,
296    /// it has to be `Reverse`d to make it a min-heap.
297    radix_heap: RadixHeapMap<Reverse<u32>, (f32, u32)>,
298}
299impl PathfinderHeap {
300    pub fn new() -> Self {
301        Self::default()
302    }
303
304    pub fn push(&mut self, item: WeightedNode) {
305        if let Some(top) = self.radix_heap.top() {
306            // this can happen when the heuristic wasn't an underestimate, so just fall back
307            // to a binary heap in those cases
308            if item.f_score < f32::from_bits(top.0) {
309                self.binary_heap.push(item);
310                return;
311            }
312        }
313        self.radix_heap
314            .push(Reverse(item.f_score.to_bits()), (item.g_score, item.index))
315    }
316    pub fn pop(&mut self) -> Option<WeightedNode> {
317        self.binary_heap.pop().or_else(|| {
318            self.radix_heap
319                .pop()
320                .map(|(f_score, (g_score, index))| WeightedNode {
321                    f_score: f32::from_bits(f_score.0),
322                    g_score,
323                    index,
324                })
325        })
326    }
327}
328
/// An entry in the open set: a node index plus the scores it was pushed with.
///
/// The ordering is inverted on `f_score` so that `BinaryHeap` (a max-heap)
/// pops the lowest `f_score` first. Ties go to the higher `g_score`, then to
/// the lower `index`. Equality is defined by the same ordering, so `==` and
/// `cmp` always agree, as `Ord`/`Eq` require.
#[repr(C)]
pub struct WeightedNode {
    /// Sum of the g_score and heuristic
    pub f_score: f32,
    /// The actual cost to get to this node
    pub g_score: f32,
    pub index: u32,
}

impl Ord for WeightedNode {
    #[inline]
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        // intentionally inverted on f_score to make the BinaryHeap a min-heap
        other
            .f_score
            .total_cmp(&self.f_score)
            .then_with(|| self.g_score.total_cmp(&other.g_score))
            // final tiebreak so that `Equal` only happens for identical entries
            .then_with(|| other.index.cmp(&self.index))
    }
}
impl PartialOrd for WeightedNode {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl PartialEq for WeightedNode {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        // derived float `==` would disagree with `total_cmp` for NaN and ±0.0
        self.cmp(other) == cmp::Ordering::Equal
    }
}
impl Eq for WeightedNode {}
356
/// A timeout that the pathfinder will consider when calculating a path.
///
/// The A* search only checks timeouts once every 10,000 expanded nodes, so
/// they are approximate.
///
/// See [`PathfinderOpts::min_timeout`] and [`PathfinderOpts::max_timeout`] if
/// you want to modify this.
///
/// [`PathfinderOpts::min_timeout`]: super::goto_event::PathfinderOpts::min_timeout
/// [`PathfinderOpts::max_timeout`]: super::goto_event::PathfinderOpts::max_timeout
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum PathfinderTimeout {
    /// Time out after a certain duration has passed.
    ///
    /// This is a good default so you don't waste too much time calculating a
    /// path if you're on a slow computer.
    Time(Duration),
    /// Time out after this many nodes have been considered.
    ///
    /// This is useful as an alternative to a time limit if you're doing
    /// something like running tests where you want consistent results.
    Nodes(usize),
}

impl Default for PathfinderTimeout {
    /// One second of wall-clock time.
    fn default() -> Self {
        let one_second = Duration::from_millis(1_000);
        PathfinderTimeout::Time(one_second)
    }
}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a [`WeightedNode`] at index 0 with the given scores.
    fn weighted_node(f: f32, g: f32) -> WeightedNode {
        WeightedNode {
            f_score: f,
            g_score: g,
            index: 0,
        }
    }

    #[test]
    fn test_weighted_node_eq() {
        let (lhs, rhs) = (weighted_node(0., 0.), weighted_node(0., 0.));
        assert!(lhs == rhs);
    }

    /// The ordering is inverted on f-score so `BinaryHeap` acts as a min-heap:
    /// the node with the higher f-score compares as less.
    #[test]
    fn test_weighted_node_le() {
        let higher_f = weighted_node(1., 0.);
        let lower_f = weighted_node(0., 0.);
        assert_eq!(higher_f.cmp(&lower_f), cmp::Ordering::Less);
        assert!(higher_f.le(&lower_f));
    }

    /// With equal f-scores, the node with the higher g-score compares as greater.
    #[test]
    fn test_weighted_node_le_g() {
        let higher_g = weighted_node(0., 1.);
        let lower_g = weighted_node(0., 0.);
        assert_eq!(higher_g.cmp(&lower_g), cmp::Ordering::Greater);
        assert!(!higher_g.le(&lower_g));
    }
}