1use std::{
2 cmp::{self},
3 collections::BinaryHeap,
4 fmt::{self, Debug},
5 hash::{BuildHasherDefault, Hash},
6 time::{Duration, Instant},
7};
8
9use indexmap::IndexMap;
10use num_format::ToFormattedString;
11use rustc_hash::FxHasher;
12use tracing::{debug, trace, warn};
13
14pub struct Path<P, M>
15where
16 P: Eq + Hash + Copy + Debug,
17{
18 pub movements: Vec<Movement<P, M>>,
19 pub is_partial: bool,
20}
21
/// Divisors applied to the g-score when ranking candidate partial paths:
/// `score = heuristic + g / coefficient`. A larger coefficient gives the heuristic
/// relatively more weight. One best candidate is tracked per coefficient.
const COEFFICIENTS: [f32; 7] = [1.5, 2.0, 2.5, 3.0, 4.0, 5.0, 10.0];

/// Smallest g-score increase an edge must add to be explored, and the smallest score
/// improvement needed to replace a tracked best partial-path candidate.
const MIN_IMPROVEMENT: f32 = 0.01;
27
28type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<FxHasher>>;
29
30pub fn a_star<P, M, HeuristicFn, SuccessorsFn, SuccessFn>(
35 start: P,
36 heuristic: HeuristicFn,
37 mut successors: SuccessorsFn,
38 success: SuccessFn,
39 min_timeout: PathfinderTimeout,
40 max_timeout: PathfinderTimeout,
41) -> Path<P, M>
42where
43 P: Eq + Hash + Copy + Debug,
44 HeuristicFn: Fn(P) -> f32,
45 SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
46 SuccessFn: Fn(P) -> bool,
47{
48 let start_time = Instant::now();
49
50 let mut open_set = BinaryHeap::<WeightedNode>::new();
51 open_set.push(WeightedNode {
52 g_score: 0.,
53 f_score: 0.,
54 index: 0,
55 });
56 let mut nodes: FxIndexMap<P, Node> = IndexMap::default();
57 nodes.insert(
58 start,
59 Node {
60 came_from: usize::MAX,
61 g_score: 0.,
62 },
63 );
64
65 let mut best_paths: [usize; 7] = [0; 7];
66 let mut best_path_scores: [f32; 7] = [heuristic(start); 7];
67
68 let mut num_nodes = 0;
69 let mut num_movements = 0;
70
71 while let Some(WeightedNode { index, g_score, .. }) = open_set.pop() {
72 num_nodes += 1;
73
74 let (&node, node_data) = nodes.get_index(index).unwrap();
75 if success(node) {
76 debug!("Nodes considered: {num_nodes}");
77
78 return Path {
79 movements: reconstruct_path(nodes, index, successors),
80 is_partial: false,
81 };
82 }
83
84 if g_score > node_data.g_score {
85 continue;
86 }
87
88 for neighbor in successors(node) {
89 let tentative_g_score = g_score + neighbor.cost;
90 let neighbor_heuristic;
92 let neighbor_index;
93
94 if tentative_g_score - g_score < MIN_IMPROVEMENT {
96 continue;
97 }
98 num_movements += 1;
99
100 match nodes.entry(neighbor.movement.target) {
101 indexmap::map::Entry::Occupied(mut e) => {
102 if e.get().g_score > tentative_g_score {
103 neighbor_heuristic = heuristic(*e.key());
104 neighbor_index = e.index();
105 e.insert(Node {
106 came_from: index,
107 g_score: tentative_g_score,
108 });
109 } else {
110 continue;
111 }
112 }
113 indexmap::map::Entry::Vacant(e) => {
114 neighbor_heuristic = heuristic(*e.key());
115 neighbor_index = e.index();
116 e.insert(Node {
117 came_from: index,
118 g_score: tentative_g_score,
119 });
120 }
121 }
122
123 open_set.push(WeightedNode {
124 index: neighbor_index,
125 g_score: tentative_g_score,
126 f_score: tentative_g_score + neighbor_heuristic,
127 });
128
129 for (coefficient_i, &coefficient) in COEFFICIENTS.iter().enumerate() {
130 let node_score = neighbor_heuristic + tentative_g_score / coefficient;
131 if best_path_scores[coefficient_i] - node_score > MIN_IMPROVEMENT {
132 best_paths[coefficient_i] = neighbor_index;
133 best_path_scores[coefficient_i] = node_score;
134 }
135 }
136 }
137
138 if num_nodes % 10000 == 0 {
140 let min_timeout_reached = match min_timeout {
141 PathfinderTimeout::Time(max_duration) => start_time.elapsed() >= max_duration,
142 PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
143 };
144
145 if min_timeout_reached {
146 if best_paths[6] != 0 {
148 break;
149 }
150
151 if min_timeout_reached {
152 let max_timeout_reached = match max_timeout {
153 PathfinderTimeout::Time(max_duration) => {
154 start_time.elapsed() >= max_duration
155 }
156 PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
157 };
158
159 if max_timeout_reached {
160 trace!("A* couldn't find a path in time, returning best path");
162 break;
163 }
164 }
165 }
166 }
167 }
168
169 let best_path = determine_best_path(best_paths, 0);
170
171 let elapsed_seconds = start_time.elapsed().as_secs_f64();
172 let nodes_per_second = (num_nodes as f64 / elapsed_seconds) as u64;
173 let num_movements_per_second = (num_movements as f64 / elapsed_seconds) as u64;
174 debug!(
175 "A* ran at {} nodes per second and {} movements per second",
176 nodes_per_second.to_formatted_string(&num_format::Locale::en),
177 num_movements_per_second.to_formatted_string(&num_format::Locale::en),
178 );
179
180 Path {
181 movements: reconstruct_path(nodes, best_path, successors),
182 is_partial: true,
183 }
184}
185
186fn determine_best_path(best_paths: [usize; 7], start: usize) -> usize {
187 for node in best_paths {
190 if node != start {
191 return node;
192 }
193 }
194 warn!("No best node found, returning first node");
195 best_paths[0]
196}
197
198fn reconstruct_path<P, M, SuccessorsFn>(
199 nodes: FxIndexMap<P, Node>,
200 mut current_index: usize,
201 mut successors: SuccessorsFn,
202) -> Vec<Movement<P, M>>
203where
204 P: Eq + Hash + Copy + Debug,
205 SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
206{
207 let mut path = Vec::new();
208 while let Some((&node_position, node)) = nodes.get_index(current_index) {
209 if node.came_from == usize::MAX {
210 break;
211 }
212 let came_from_position = *nodes.get_index(node.came_from).unwrap().0;
213
214 let mut best_successor = None;
217 let mut best_successor_cost = f32::INFINITY;
218 for successor in successors(came_from_position) {
219 if successor.movement.target == node_position && successor.cost < best_successor_cost {
220 best_successor_cost = successor.cost;
221 best_successor = Some(successor);
222 }
223 }
224 let Some(found_successor) = best_successor else {
225 warn!(
226 "a successor stopped being possible while reconstructing the path, returning empty path"
227 );
228 return vec![];
229 };
230
231 path.push(Movement {
232 target: node_position,
233 data: found_successor.movement.data,
234 });
235
236 current_index = node.came_from;
237 }
238 path.reverse();
239 path
240}
241
/// Search bookkeeping for one visited position.
#[derive(Debug, Clone, Copy)]
pub struct Node {
    /// Index of the node this one was reached from; `usize::MAX` marks the start node.
    pub came_from: usize,
    /// Lowest known cost from the start to this node.
    pub g_score: f32,
}
246
247#[derive(Clone, Debug)]
248pub struct Edge<P: Hash + Copy, M> {
249 pub movement: Movement<P, M>,
250 pub cost: f32,
251}
252
/// One step of a path: the position it ends at, plus caller-defined data.
///
/// `Debug` and `Clone` are derived (matching [`Edge`]); the derives produce the same
/// bounds and output as hand-written impls would (`P: Debug`/`M: Debug` for `Debug`,
/// `M: Clone` for `Clone`).
#[derive(Clone, Debug)]
pub struct Movement<P: Hash + Copy, M> {
    /// The position reached by this movement.
    pub target: P,
    /// Data attached to this movement by the successor function that produced it.
    pub data: M,
}
274
/// An entry in the A* open set.
///
/// Ordering is arranged so that a max-heap such as `BinaryHeap` pops the entry with the
/// *lowest* `f_score` first; among equal `f_score`s, the entry with the *highest*
/// `g_score` comes first. `index` takes no part in comparisons or equality.
// NOTE(review): `repr(C)` pins the declared field order; the reason is not visible here.
#[repr(C)]
pub struct WeightedNode {
    /// `g_score` plus the heuristic estimate for this node.
    pub f_score: f32,
    /// Cost from the start to this node at the time it was queued.
    pub g_score: f32,
    /// Index of the node in the search's node map.
    pub index: usize,
}

impl Ord for WeightedNode {
    #[inline]
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        // Reversed on purpose: a smaller f_score must compare as "greater" so the
        // max-heap yields it first. Ties prefer the larger g_score.
        other
            .f_score
            .total_cmp(&self.f_score)
            .then_with(|| self.g_score.total_cmp(&other.g_score))
    }
}

// Equality is defined through `cmp` so that `a == b` holds exactly when
// `a.cmp(b) == Equal`, as `Ord` requires. A derived `PartialEq` would also compare
// `index` and use IEEE `==` (NaN != NaN, 0.0 == -0.0), contradicting `Ord` and `Eq`.
impl PartialEq for WeightedNode {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == cmp::Ordering::Equal
    }
}
impl Eq for WeightedNode {}
impl PartialOrd for WeightedNode {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}
302
/// A budget after which the pathfinder may give up and return a partial path.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PathfinderTimeout {
    /// Wall-clock budget, measured from when the search started.
    Time(Duration),
    /// Budget expressed as a number of expanded nodes.
    Nodes(usize),
}

impl Default for PathfinderTimeout {
    /// One second of wall-clock time.
    fn default() -> Self {
        PathfinderTimeout::Time(Duration::from_millis(1_000))
    }
}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an open-set entry with the given scores; `index` does not affect ordering.
    fn entry(f_score: f32, g_score: f32) -> WeightedNode {
        WeightedNode {
            f_score,
            g_score,
            index: 0,
        }
    }

    #[test]
    fn test_weighted_node_eq() {
        assert!(entry(0., 0.) == entry(0., 0.));
    }

    #[test]
    fn test_weighted_node_le() {
        // A higher f_score compares lower, so the max-heap pops it later.
        let (higher_f, lower_f) = (entry(1., 0.), entry(0., 0.));
        assert_eq!(higher_f.cmp(&lower_f), cmp::Ordering::Less);
        assert!(higher_f <= lower_f);
    }

    #[test]
    fn test_weighted_node_le_g() {
        // With equal f_scores, the larger g_score compares higher.
        let (deeper, shallower) = (entry(0., 1.), entry(0., 0.));
        assert_eq!(deeper.cmp(&shallower), cmp::Ordering::Greater);
        assert!(!(deeper <= shallower));
    }
}