// azalea/pathfinder/astar.rs

1use std::{
2    cmp::{self},
3    collections::BinaryHeap,
4    fmt::Debug,
5    hash::{BuildHasherDefault, Hash},
6    time::{Duration, Instant},
7};
8
9use indexmap::IndexMap;
10use num_format::ToFormattedString;
11use rustc_hash::FxHasher;
12use tracing::{debug, trace, warn};
13
14pub struct Path<P, M>
15where
16    P: Eq + Hash + Copy + Debug,
17{
18    pub movements: Vec<Movement<P, M>>,
19    pub is_partial: bool,
20}
21
/// Weights used to choose a fallback target when the search times out: for
/// each coefficient `c`, the node with the lowest `heuristic + g_score / c`
/// is remembered. Ported from baritone, see
/// <https://github.com/cabaletta/baritone/blob/1.19.4/src/main/java/baritone/pathing/calc/AbstractNodeCostSearch.java#L68>
const COEFFICIENTS: [f32; 7] = [1.5, 2.0, 2.5, 3.0, 4.0, 5.0, 10.0];

/// Minimum score difference the search counts as a meaningful improvement.
const MIN_IMPROVEMENT: f32 = 0.01;
27
28type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<FxHasher>>;
29
30// Sources:
31// - https://en.wikipedia.org/wiki/A*_search_algorithm
32// - https://github.com/evenfurther/pathfinding/blob/main/src/directed/astar.rs
33// - https://github.com/cabaletta/baritone/blob/1.19.4/src/main/java/baritone/pathing/calc/AbstractNodeCostSearch.java
34pub fn a_star<P, M, HeuristicFn, SuccessorsFn, SuccessFn>(
35    start: P,
36    heuristic: HeuristicFn,
37    mut successors: SuccessorsFn,
38    success: SuccessFn,
39    min_timeout: PathfinderTimeout,
40    max_timeout: PathfinderTimeout,
41) -> Path<P, M>
42where
43    P: Eq + Hash + Copy + Debug,
44    HeuristicFn: Fn(P) -> f32,
45    SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
46    SuccessFn: Fn(P) -> bool,
47{
48    let start_time = Instant::now();
49
50    let mut open_set = BinaryHeap::<WeightedNode>::new();
51    open_set.push(WeightedNode {
52        g_score: 0.,
53        f_score: 0.,
54        index: 0,
55    });
56    let mut nodes: FxIndexMap<P, Node> = IndexMap::default();
57    nodes.insert(
58        start,
59        Node {
60            came_from: usize::MAX,
61            g_score: 0.,
62        },
63    );
64
65    let mut best_paths: [usize; 7] = [0; 7];
66    let mut best_path_scores: [f32; 7] = [heuristic(start); 7];
67
68    let mut num_nodes = 0;
69
70    while let Some(WeightedNode { index, g_score, .. }) = open_set.pop() {
71        num_nodes += 1;
72
73        let (&node, node_data) = nodes.get_index(index).unwrap();
74        if success(node) {
75            debug!("Nodes considered: {num_nodes}");
76
77            return Path {
78                movements: reconstruct_path(nodes, index, successors),
79                is_partial: false,
80            };
81        }
82
83        if g_score > node_data.g_score {
84            continue;
85        }
86
87        for neighbor in successors(node) {
88            let tentative_g_score = g_score + neighbor.cost;
89            // let neighbor_heuristic = heuristic(neighbor.movement.target);
90            let neighbor_heuristic;
91            let neighbor_index;
92
93            // skip neighbors that don't result in a big enough improvement
94            if tentative_g_score - g_score < MIN_IMPROVEMENT {
95                continue;
96            }
97
98            match nodes.entry(neighbor.movement.target) {
99                indexmap::map::Entry::Occupied(mut e) => {
100                    if e.get().g_score > tentative_g_score {
101                        neighbor_heuristic = heuristic(*e.key());
102                        neighbor_index = e.index();
103                        e.insert(Node {
104                            came_from: index,
105                            g_score: tentative_g_score,
106                        });
107                    } else {
108                        continue;
109                    }
110                }
111                indexmap::map::Entry::Vacant(e) => {
112                    neighbor_heuristic = heuristic(*e.key());
113                    neighbor_index = e.index();
114                    e.insert(Node {
115                        came_from: index,
116                        g_score: tentative_g_score,
117                    });
118                }
119            }
120
121            open_set.push(WeightedNode {
122                index: neighbor_index,
123                g_score: tentative_g_score,
124                f_score: tentative_g_score + neighbor_heuristic,
125            });
126
127            for (coefficient_i, &coefficient) in COEFFICIENTS.iter().enumerate() {
128                let node_score = neighbor_heuristic + tentative_g_score / coefficient;
129                if best_path_scores[coefficient_i] - node_score > MIN_IMPROVEMENT {
130                    best_paths[coefficient_i] = neighbor_index;
131                    best_path_scores[coefficient_i] = node_score;
132                }
133            }
134        }
135
136        // check for timeout every ~10ms
137        if num_nodes % 10000 == 0 {
138            let min_timeout_reached = match min_timeout {
139                PathfinderTimeout::Time(max_duration) => start_time.elapsed() >= max_duration,
140                PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
141            };
142
143            if min_timeout_reached {
144                // means we have a non-empty path
145                if best_paths[6] != 0 {
146                    break;
147                }
148
149                if min_timeout_reached {
150                    let max_timeout_reached = match max_timeout {
151                        PathfinderTimeout::Time(max_duration) => {
152                            start_time.elapsed() >= max_duration
153                        }
154                        PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
155                    };
156
157                    if max_timeout_reached {
158                        // timeout, we're gonna be returning an empty path :(
159                        trace!("A* couldn't find a path in time, returning best path");
160                        break;
161                    }
162                }
163            }
164        }
165    }
166
167    let best_path = determine_best_path(best_paths, 0);
168
169    debug!(
170        "A* ran at {} nodes per second",
171        ((num_nodes as f64 / start_time.elapsed().as_secs_f64()) as u64)
172            .to_formatted_string(&num_format::Locale::en)
173    );
174
175    Path {
176        movements: reconstruct_path(nodes, best_path, successors),
177        is_partial: true,
178    }
179}
180
181fn determine_best_path(best_paths: [usize; 7], start: usize) -> usize {
182    // this basically makes sure we don't create a path that's really short
183
184    for node in best_paths {
185        if node != start {
186            return node;
187        }
188    }
189    warn!("No best node found, returning first node");
190    best_paths[0]
191}
192
193fn reconstruct_path<P, M, SuccessorsFn>(
194    nodes: FxIndexMap<P, Node>,
195    mut current_index: usize,
196    mut successors: SuccessorsFn,
197) -> Vec<Movement<P, M>>
198where
199    P: Eq + Hash + Copy + Debug,
200    SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
201{
202    let mut path = Vec::new();
203    while let Some((&node_position, node)) = nodes.get_index(current_index) {
204        if node.came_from == usize::MAX {
205            break;
206        }
207        let came_from_position = *nodes.get_index(node.came_from).unwrap().0;
208
209        // find the movement data for this successor, we have to do this again because
210        // we don't include the movement data in the Node (as an optimization)
211        let mut best_successor = None;
212        let mut best_successor_cost = f32::INFINITY;
213        for successor in successors(came_from_position) {
214            if successor.movement.target == node_position && successor.cost < best_successor_cost {
215                best_successor_cost = successor.cost;
216                best_successor = Some(successor);
217            }
218        }
219        let found_successor = best_successor.expect("No successor found");
220
221        path.push(Movement {
222            target: node_position,
223            data: found_successor.movement.data,
224        });
225
226        current_index = node.came_from;
227    }
228    path.reverse();
229    path
230}
231
/// Search bookkeeping for a position that has been reached.
pub struct Node {
    /// Index of the node this one was reached from, or `usize::MAX` for the
    /// start node (which has no predecessor).
    pub came_from: usize,
    /// Cheapest cost found so far to reach this node from the start.
    pub g_score: f32,
}
236
237pub struct Edge<P: Hash + Copy, M> {
238    pub movement: Movement<P, M>,
239    pub cost: f32,
240}
241
/// One step of a path: the position it ends at and caller-defined data
/// describing how to get there.
pub struct Movement<P: Hash + Copy, M> {
    /// Position reached by this step.
    pub target: P,
    /// Payload supplied by the successors function for this step.
    pub data: M,
}
246
247impl<P: Hash + Copy + Debug, M: Debug> Debug for Movement<P, M> {
248    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        f.debug_struct("Movement")
250            .field("target", &self.target)
251            .field("data", &self.data)
252            .finish()
253    }
254}
255impl<P: Hash + Copy + Clone, M: Clone> Clone for Movement<P, M> {
256    fn clone(&self) -> Self {
257        Self {
258            target: self.target,
259            data: self.data.clone(),
260        }
261    }
262}
263
/// An entry in the A* open set.
///
/// The ordering is reversed on `f_score`, so a [`std::collections::BinaryHeap`]
/// (a max-heap) pops the entry with the lowest `f_score` first. Ties are
/// broken in favour of the higher `g_score`.
///
/// Equality is defined through that same ordering: only the scores are
/// compared and `index` is ignored. This keeps `PartialEq`, `Eq`,
/// `PartialOrd` and `Ord` consistent with each other, as the `Ord` contract
/// requires. A derived `PartialEq` would compare `index` and use IEEE float
/// equality, which disagrees with `Ord` and is not reflexive for NaN.
pub struct WeightedNode {
    /// Index of the node this entry refers to
    index: usize,
    /// The actual cost to get to this node
    g_score: f32,
    /// Sum of the g_score and heuristic
    f_score: f32,
}

impl Ord for WeightedNode {
    #[inline]
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        // intentionally inverted to make the BinaryHeap a min-heap
        other
            .f_score
            .total_cmp(&self.f_score)
            .then_with(|| self.g_score.total_cmp(&other.g_score))
    }
}
impl PartialEq for WeightedNode {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == cmp::Ordering::Equal
    }
}
impl Eq for WeightedNode {}
impl PartialOrd for WeightedNode {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}
290
/// A budget limiting how long a pathfinding search may run.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PathfinderTimeout {
    /// Stop once this much wall-clock time has elapsed. This makes a sensible
    /// default, since it keeps a slow computer from spending too long on a
    /// single path.
    Time(Duration),
    /// Stop once this many nodes have been considered.
    ///
    /// Handy instead of a time limit when results must be reproducible, for
    /// example in tests.
    Nodes(usize),
}

impl Default for PathfinderTimeout {
    /// One second of wall-clock time.
    fn default() -> Self {
        PathfinderTimeout::Time(Duration::from_millis(1000))
    }
}