azalea/pathfinder/astar.rs

1use std::{
2    cmp::{self},
3    collections::BinaryHeap,
4    fmt::{self, Debug},
5    hash::{BuildHasherDefault, Hash},
6    time::{Duration, Instant},
7};
8
9use indexmap::IndexMap;
10use num_format::ToFormattedString;
11use rustc_hash::FxHasher;
12use tracing::{debug, trace, warn};
13
14pub struct Path<P, M>
15where
16    P: Eq + Hash + Copy + Debug,
17{
18    pub movements: Vec<Movement<P, M>>,
19    pub is_partial: bool,
20}
21
/// Divisors applied to the g-score when ranking candidate partial paths.
///
/// Every newly reached node is scored as `heuristic + g_score / coefficient`
/// once per coefficient, and the best node for each coefficient is remembered
/// so a useful partial path can still be returned when the search times out.
///
/// Taken from Baritone, see
/// <https://github.com/cabaletta/baritone/blob/1.19.4/src/main/java/baritone/pathing/calc/AbstractNodeCostSearch.java#L68>
const COEFFICIENTS: [f32; 7] = [1.5, 2.0, 2.5, 3.0, 4.0, 5.0, 10.0];

/// How much a coefficient's score must drop before its remembered node is
/// replaced by a new one.
const MIN_IMPROVEMENT: f32 = 0.01;
27
28type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<FxHasher>>;
29
30// Sources:
31// - https://en.wikipedia.org/wiki/A*_search_algorithm
32// - https://github.com/evenfurther/pathfinding/blob/main/src/directed/astar.rs
33// - https://github.com/cabaletta/baritone/blob/1.19.4/src/main/java/baritone/pathing/calc/AbstractNodeCostSearch.java
34pub fn a_star<P, M, HeuristicFn, SuccessorsFn, SuccessFn>(
35    start: P,
36    heuristic: HeuristicFn,
37    mut successors: SuccessorsFn,
38    success: SuccessFn,
39    min_timeout: PathfinderTimeout,
40    max_timeout: PathfinderTimeout,
41) -> Path<P, M>
42where
43    P: Eq + Hash + Copy + Debug,
44    HeuristicFn: Fn(P) -> f32,
45    SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
46    SuccessFn: Fn(P) -> bool,
47{
48    let start_time = Instant::now();
49
50    let mut open_set = BinaryHeap::<WeightedNode>::new();
51    open_set.push(WeightedNode {
52        g_score: 0.,
53        f_score: 0.,
54        index: 0,
55    });
56    let mut nodes: FxIndexMap<P, Node> = IndexMap::default();
57    nodes.insert(
58        start,
59        Node {
60            came_from: usize::MAX,
61            g_score: 0.,
62        },
63    );
64
65    let mut best_paths: [usize; 7] = [0; 7];
66    let mut best_path_scores: [f32; 7] = [heuristic(start); 7];
67
68    let mut num_nodes = 0_usize;
69    let mut num_movements = 0;
70
71    while let Some(WeightedNode { index, g_score, .. }) = open_set.pop() {
72        num_nodes += 1;
73
74        let (&node, node_data) = nodes.get_index(index).unwrap();
75        if success(node) {
76            let best_path = index;
77            log_perf_info(start_time, num_nodes, num_movements);
78
79            return Path {
80                movements: reconstruct_path(nodes, best_path, successors),
81                is_partial: false,
82            };
83        }
84
85        if g_score > node_data.g_score {
86            continue;
87        }
88
89        for neighbor in successors(node) {
90            let tentative_g_score = g_score + neighbor.cost;
91            // let neighbor_heuristic = heuristic(neighbor.movement.target);
92            let neighbor_heuristic;
93            let neighbor_index;
94
95            num_movements += 1;
96
97            match nodes.entry(neighbor.movement.target) {
98                indexmap::map::Entry::Occupied(mut e) => {
99                    if e.get().g_score > tentative_g_score {
100                        neighbor_heuristic = heuristic(*e.key());
101                        neighbor_index = e.index();
102                        e.insert(Node {
103                            came_from: index,
104                            g_score: tentative_g_score,
105                        });
106                    } else {
107                        continue;
108                    }
109                }
110                indexmap::map::Entry::Vacant(e) => {
111                    neighbor_heuristic = heuristic(*e.key());
112                    neighbor_index = e.index();
113                    e.insert(Node {
114                        came_from: index,
115                        g_score: tentative_g_score,
116                    });
117                }
118            }
119
120            open_set.push(WeightedNode {
121                index: neighbor_index,
122                g_score: tentative_g_score,
123                f_score: tentative_g_score + neighbor_heuristic,
124            });
125
126            for (coefficient_i, &coefficient) in COEFFICIENTS.iter().enumerate() {
127                let node_score = neighbor_heuristic + tentative_g_score / coefficient;
128                if best_path_scores[coefficient_i] - node_score > MIN_IMPROVEMENT {
129                    best_paths[coefficient_i] = neighbor_index;
130                    best_path_scores[coefficient_i] = node_score;
131                }
132            }
133        }
134
135        // check for timeout every ~10ms
136        if num_nodes.is_multiple_of(10_000) {
137            let min_timeout_reached = match min_timeout {
138                PathfinderTimeout::Time(max_duration) => start_time.elapsed() >= max_duration,
139                PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
140            };
141
142            if min_timeout_reached {
143                // means we have a non-empty path
144                if best_paths[6] != 0 {
145                    break;
146                }
147
148                if min_timeout_reached {
149                    let max_timeout_reached = match max_timeout {
150                        PathfinderTimeout::Time(max_duration) => {
151                            start_time.elapsed() >= max_duration
152                        }
153                        PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
154                    };
155
156                    if max_timeout_reached {
157                        // timeout, we're gonna be returning an empty path :(
158                        trace!("A* couldn't find a path in time, returning best path");
159                        break;
160                    }
161                }
162            }
163        }
164    }
165
166    let best_path = determine_best_path(best_paths, 0);
167    log_perf_info(start_time, num_nodes, num_movements);
168    Path {
169        movements: reconstruct_path(nodes, best_path, successors),
170        is_partial: true,
171    }
172}
173
174fn log_perf_info(start_time: Instant, num_nodes: usize, num_movements: usize) {
175    let elapsed_seconds = start_time.elapsed().as_secs_f64();
176    let nodes_per_second = (num_nodes as f64 / elapsed_seconds) as u64;
177    let num_movements_per_second = (num_movements as f64 / elapsed_seconds) as u64;
178    debug!(
179        "Nodes considered: {}",
180        num_nodes.to_formatted_string(&num_format::Locale::en)
181    );
182    debug!(
183        "A* ran at {} nodes per second and {} movements per second",
184        nodes_per_second.to_formatted_string(&num_format::Locale::en),
185        num_movements_per_second.to_formatted_string(&num_format::Locale::en),
186    );
187}
188
189fn determine_best_path(best_paths: [usize; 7], start: usize) -> usize {
190    // this basically makes sure we don't create a path that's really short
191
192    for node in best_paths {
193        if node != start {
194            return node;
195        }
196    }
197    warn!("No best node found, returning first node");
198    best_paths[0]
199}
200
201fn reconstruct_path<P, M, SuccessorsFn>(
202    nodes: FxIndexMap<P, Node>,
203    mut current_index: usize,
204    mut successors: SuccessorsFn,
205) -> Vec<Movement<P, M>>
206where
207    P: Eq + Hash + Copy + Debug,
208    SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
209{
210    let mut path = Vec::new();
211    while let Some((&node_position, node)) = nodes.get_index(current_index) {
212        if node.came_from == usize::MAX {
213            break;
214        }
215        let came_from_position = *nodes.get_index(node.came_from).unwrap().0;
216
217        // find the movement data for this successor, we have to do this again because
218        // we don't include the movement data in the Node (as an optimization)
219        let mut best_successor = None;
220        let mut best_successor_cost = f32::INFINITY;
221        for successor in successors(came_from_position) {
222            if successor.movement.target == node_position && successor.cost < best_successor_cost {
223                best_successor_cost = successor.cost;
224                best_successor = Some(successor);
225            }
226        }
227        let Some(found_successor) = best_successor else {
228            warn!(
229                "a successor stopped being possible while reconstructing the path, returning empty path"
230            );
231            return vec![];
232        };
233
234        path.push(Movement {
235            target: node_position,
236            data: found_successor.movement.data,
237        });
238
239        current_index = node.came_from;
240    }
241    path.reverse();
242    path
243}
244
/// Bookkeeping the search keeps for every position it has reached.
pub struct Node {
    /// Lowest known cost to reach this position from the start.
    pub g_score: f32,
    /// Map index of the predecessor on that cheapest route, or `usize::MAX`
    /// for the start node.
    pub came_from: usize,
}
249
250#[derive(Clone, Debug)]
251pub struct Edge<P: Hash + Copy, M> {
252    pub movement: Movement<P, M>,
253    pub cost: f32,
254}
255
/// One step of a path: the position it ends at plus caller-defined data.
pub struct Movement<P, M>
where
    P: Hash + Copy,
{
    /// The position this movement ends at.
    pub target: P,
    /// Caller-defined payload carried along with the movement.
    pub data: M,
}

impl<P, M> Debug for Movement<P, M>
where
    P: Hash + Copy + Debug,
    M: Debug,
{
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = formatter.debug_struct("Movement");
        builder.field("target", &self.target);
        builder.field("data", &self.data);
        builder.finish()
    }
}

impl<P, M> Clone for Movement<P, M>
where
    P: Hash + Copy + Clone,
    M: Clone,
{
    fn clone(&self) -> Self {
        let data = self.data.clone();
        Movement {
            target: self.target,
            data,
        }
    }
}
277
/// An entry in the A* open set (a `BinaryHeap`).
///
/// The ordering is inverted on `f_score` so that the max-heap pops the entry
/// with the *lowest* f-score first; ties go to the entry with the higher
/// `g_score`. Equality uses the same comparison, so `PartialEq`, `Eq` and
/// `Ord` agree with each other and `index` takes part in none of them.
#[repr(C)]
pub struct WeightedNode {
    /// Sum of the g_score and heuristic
    pub f_score: f32,
    /// The actual cost to get to this node
    pub g_score: f32,
    /// Index of the node's position in the search's node map.
    pub index: usize,
}

impl Ord for WeightedNode {
    #[inline]
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        // intentionally inverted to make the BinaryHeap a min-heap
        match other.f_score.total_cmp(&self.f_score) {
            cmp::Ordering::Equal => self.g_score.total_cmp(&other.g_score),
            s => s,
        }
    }
}
impl PartialEq for WeightedNode {
    // Defined via `cmp` (rather than derived) so that `a == b` exactly when
    // `a.cmp(b) == Equal`, as `Ord` requires, and so equality is reflexive
    // even for NaN scores, as `Eq` requires.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == cmp::Ordering::Equal
    }
}
impl Eq for WeightedNode {}
impl PartialOrd for WeightedNode {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}
305
/// A limit the pathfinder checks while it is calculating a path.
///
/// See [`PathfinderOpts::min_timeout`] and [`PathfinderOpts::max_timeout`] if
/// you want to modify this.
///
/// [`PathfinderOpts::min_timeout`]: super::goto_event::PathfinderOpts::min_timeout
/// [`PathfinderOpts::max_timeout`]: super::goto_event::PathfinderOpts::max_timeout
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum PathfinderTimeout {
    /// Stop once this much time has passed since the search started.
    ///
    /// A good default, because it keeps a slow computer from spending too long
    /// on a single path.
    Time(Duration),
    /// Stop once this many nodes have been considered.
    ///
    /// Handy instead of a time limit when you want consistent results, for
    /// example when running tests.
    Nodes(usize),
}

impl Default for PathfinderTimeout {
    /// One second of search time.
    fn default() -> Self {
        PathfinderTimeout::Time(Duration::from_millis(1_000))
    }
}
330
#[cfg(test)]
mod tests {
    use super::*;

    fn weighted_node(f: f32, g: f32) -> WeightedNode {
        WeightedNode {
            f_score: f,
            g_score: g,
            index: 0,
        }
    }

    #[test]
    fn test_weighted_node_eq() {
        let a = weighted_node(0., 0.);
        let b = weighted_node(0., 0.);
        assert!(a == b);
    }
    #[test]
    fn test_weighted_node_le() {
        let a = weighted_node(1., 0.);
        let b = weighted_node(0., 0.);
        assert_eq!(a.cmp(&b), cmp::Ordering::Less);
        assert!(a.le(&b));
    }
    #[test]
    fn test_weighted_node_le_g() {
        let a = weighted_node(0., 1.);
        let b = weighted_node(0., 0.);
        assert_eq!(a.cmp(&b), cmp::Ordering::Greater);
        assert!(!a.le(&b));
    }

    /// The open set relies on `WeightedNode`'s inverted `Ord` to pop the lowest
    /// f-score first, breaking ties in favour of the higher g-score.
    #[test]
    fn test_binary_heap_pops_lowest_f_first() {
        let cases: [(f32, f32); 4] = [(3., 0.), (1., 0.), (2., 0.), (1., 5.)];
        let mut heap = std::collections::BinaryHeap::new();
        for (index, &(f, g)) in cases.iter().enumerate() {
            heap.push(WeightedNode {
                f_score: f,
                g_score: g,
                index,
            });
        }
        let mut order = Vec::new();
        while let Some(node) = heap.pop() {
            order.push(node.index);
        }
        assert_eq!(order, vec![3, 1, 2, 0]);
    }
}