1use std::{
2 cmp::{self, Reverse},
3 collections::BinaryHeap,
4 fmt::{self, Debug},
5 hash::{BuildHasherDefault, Hash},
6 time::{Duration, Instant},
7};
8
9use indexmap::IndexMap;
10use num_format::ToFormattedString;
11use radix_heap::RadixHeapMap;
12use rustc_hash::FxHasher;
13use tracing::{debug, trace, warn};
14
15pub struct Path<P, M>
16where
17 P: Eq + Hash + Copy + Debug,
18{
19 pub movements: Vec<Movement<P, M>>,
20 pub is_partial: bool,
21 pub cost: f32,
26}
27
/// Divisors applied to the g-score when `a_star` ranks fallback (partial-path)
/// candidates: a node's candidate score is `h + g / coefficient`, so larger
/// coefficients weigh the heuristic more heavily relative to the cost already
/// paid. One best candidate is tracked per entry.
const COEFFICIENTS: [f32; 7] = [1.5, 2., 2.5, 3., 4., 5., 10.];

/// A new candidate only replaces the current best for a coefficient when its
/// score is lower by more than this amount.
const MIN_IMPROVEMENT: f32 = 0.01;

/// `IndexMap` hashed with `FxHasher`; `a_star` stores every visited node in one
/// so that nodes can be referred to by their insertion index.
type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<FxHasher>>;
35
36pub fn a_star<P, M, HeuristicFn, SuccessorsFn, SuccessFn>(
41 start: P,
42 heuristic: HeuristicFn,
43 mut successors: SuccessorsFn,
44 success: SuccessFn,
45 min_timeout: PathfinderTimeout,
46 max_timeout: PathfinderTimeout,
47) -> Path<P, M>
48where
49 P: Eq + Hash + Copy + Debug,
50 HeuristicFn: Fn(P) -> f32,
51 SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
52 SuccessFn: Fn(P) -> bool,
53{
54 let start_time = Instant::now();
55
56 let mut open_set = PathfinderHeap::new();
57 open_set.push(WeightedNode {
58 g_score: 0.,
59 f_score: 0.,
60 index: 0,
61 });
62 let mut nodes: FxIndexMap<P, Node> = IndexMap::default();
63 nodes.insert(
64 start,
65 Node {
66 came_from: u32::MAX,
67 g_score: 0.,
68 },
69 );
70
71 let mut best_paths: [u32; 7] = [0; 7];
72 let mut best_path_scores: [f32; 7] = [heuristic(start); 7];
73
74 let mut num_nodes = 0_usize;
75 let mut num_movements = 0;
76
77 while let Some(WeightedNode { index, g_score, .. }) = open_set.pop() {
78 let (&node, node_data) = nodes.get_index(index as usize).unwrap();
79 if g_score > node_data.g_score {
80 continue;
81 }
82
83 num_nodes += 1;
84
85 if success(node) {
86 let best_path = index;
87 log_perf_info(start_time, num_nodes, num_movements);
88
89 return Path {
90 movements: reconstruct_path(nodes, best_path, successors),
91 is_partial: false,
92 cost: g_score,
93 };
94 }
95
96 for neighbor in successors(node) {
97 let tentative_g_score = g_score + neighbor.cost;
98 let neighbor_heuristic;
100 let neighbor_index;
101
102 num_movements += 1;
103
104 match nodes.entry(neighbor.movement.target) {
105 indexmap::map::Entry::Occupied(mut e) => {
106 if e.get().g_score > tentative_g_score {
107 neighbor_heuristic = heuristic(*e.key());
108 neighbor_index = e.index() as u32;
109 e.insert(Node {
110 came_from: index,
111 g_score: tentative_g_score,
112 });
113 } else {
114 continue;
115 }
116 }
117 indexmap::map::Entry::Vacant(e) => {
118 neighbor_heuristic = heuristic(*e.key());
119 neighbor_index = e.index() as u32;
120 e.insert(Node {
121 came_from: index,
122 g_score: tentative_g_score,
123 });
124 }
125 }
126
127 open_set.push(WeightedNode {
131 index: neighbor_index,
132 g_score: tentative_g_score,
133 f_score: tentative_g_score + neighbor_heuristic,
134 });
135
136 for (coefficient_i, &coefficient) in COEFFICIENTS.iter().enumerate() {
137 let node_score = neighbor_heuristic + tentative_g_score / coefficient;
138 if best_path_scores[coefficient_i] - node_score > MIN_IMPROVEMENT {
139 best_paths[coefficient_i] = neighbor_index;
140 best_path_scores[coefficient_i] = node_score;
141 }
142 }
143 }
144
145 if num_nodes.is_multiple_of(10_000) {
147 let min_timeout_reached = match min_timeout {
148 PathfinderTimeout::Time(max_duration) => start_time.elapsed() >= max_duration,
149 PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
150 };
151
152 if min_timeout_reached {
153 if best_paths[6] != 0 {
155 break;
156 }
157
158 if min_timeout_reached {
159 let max_timeout_reached = match max_timeout {
160 PathfinderTimeout::Time(max_duration) => {
161 start_time.elapsed() >= max_duration
162 }
163 PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
164 };
165
166 if max_timeout_reached {
167 trace!("A* couldn't find a path in time, returning best path");
169 break;
170 }
171 }
172 }
173 }
174 }
175
176 let best_path_idx = determine_best_path_idx(best_paths, 0);
177 log_perf_info(start_time, num_nodes, num_movements);
178 Path {
179 movements: reconstruct_path(nodes, best_paths[best_path_idx], successors),
180 is_partial: true,
181 cost: best_path_scores[best_path_idx],
182 }
183}
184
185fn log_perf_info(start_time: Instant, num_nodes: usize, num_movements: usize) {
186 let elapsed = start_time.elapsed();
187 let elapsed_seconds = elapsed.as_secs_f64();
188 let nodes_per_second = (num_nodes as f64 / elapsed_seconds) as u64;
189 let num_movements_per_second = (num_movements as f64 / elapsed_seconds) as u64;
190 debug!(
191 "Considered {} nodes in {elapsed:?}",
192 num_nodes.to_formatted_string(&num_format::Locale::en)
193 );
194 debug!(
195 "A* ran at {} nodes per second and {} movements per second",
196 nodes_per_second.to_formatted_string(&num_format::Locale::en),
197 num_movements_per_second.to_formatted_string(&num_format::Locale::en),
198 );
199}
200
201fn determine_best_path_idx(best_paths: [u32; 7], start: u32) -> usize {
202 for (i, &node) in best_paths.iter().enumerate() {
205 if node != start {
206 return i;
207 }
208 }
209 warn!("No best node found, returning first node");
210 0
211}
212
213fn reconstruct_path<P, M, SuccessorsFn>(
214 nodes: FxIndexMap<P, Node>,
215 mut current_index: u32,
216 mut successors: SuccessorsFn,
217) -> Vec<Movement<P, M>>
218where
219 P: Eq + Hash + Copy + Debug,
220 SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
221{
222 let mut path = Vec::new();
223 while let Some((&node_position, node)) = nodes.get_index(current_index as usize) {
224 if node.came_from == u32::MAX {
225 break;
226 }
227 let came_from_position = *nodes.get_index(node.came_from as usize).unwrap().0;
228
229 let mut best_successor = None;
232 let mut best_successor_cost = f32::INFINITY;
233 for successor in successors(came_from_position) {
234 if successor.movement.target == node_position && successor.cost < best_successor_cost {
235 best_successor_cost = successor.cost;
236 best_successor = Some(successor);
237 }
238 }
239 let Some(found_successor) = best_successor else {
240 warn!(
241 "a successor stopped being possible while reconstructing the path, returning empty path"
242 );
243 return vec![];
244 };
245
246 path.push(Movement {
247 target: node_position,
248 data: found_successor.movement.data,
249 });
250
251 current_index = node.came_from;
252 }
253 path.reverse();
254 path
255}
256
/// Bookkeeping stored for every position in the A* node map.
pub struct Node {
    /// Index (in the node map) of the node this one was reached from, or
    /// `u32::MAX` for the start node.
    pub came_from: u32,
    /// Cheapest known cost from the start to this node.
    pub g_score: f32,
}
261
/// An edge produced by a successors function: a movement to a neighboring
/// position together with what it costs.
#[derive(Clone, Debug)]
pub struct Edge<P: Hash + Copy, M> {
    /// Where the edge leads (`movement.target`) plus its caller-defined data.
    pub movement: Movement<P, M>,
    /// Added to the current node's g-score when this edge is followed.
    pub cost: f32,
}
267
/// One step of a path: the position it ends at plus caller-defined data (`M`).
///
/// `Debug` and `Clone` are derived; they require `P: Debug`/`M: Debug` and
/// `M: Clone` respectively, exactly like the former hand-written impls.
#[derive(Debug, Clone)]
pub struct Movement<P: Hash + Copy, M> {
    /// Position reached by this movement.
    pub target: P,
    /// Caller-defined payload, taken from the successors function's edge.
    pub data: M,
}
289
290#[derive(Default)]
291struct PathfinderHeap {
292 binary_heap: BinaryHeap<WeightedNode>,
293 radix_heap: RadixHeapMap<Reverse<u32>, (f32, u32)>,
298}
299impl PathfinderHeap {
300 pub fn new() -> Self {
301 Self::default()
302 }
303
304 pub fn push(&mut self, item: WeightedNode) {
305 if let Some(top) = self.radix_heap.top() {
306 if item.f_score < f32::from_bits(top.0) {
309 self.binary_heap.push(item);
310 return;
311 }
312 }
313 self.radix_heap
314 .push(Reverse(item.f_score.to_bits()), (item.g_score, item.index))
315 }
316 pub fn pop(&mut self) -> Option<WeightedNode> {
317 self.binary_heap.pop().or_else(|| {
318 self.radix_heap
319 .pop()
320 .map(|(f_score, (g_score, index))| WeightedNode {
321 f_score: f32::from_bits(f_score.0),
322 g_score,
323 index,
324 })
325 })
326 }
327}
328
/// An entry in the A* open set.
///
/// The ordering is built for a max-heap such as `BinaryHeap`:
/// - A node with a lower `f_score` compares as greater, so it is popped first.
/// - Among equal `f_score`s, the node with the higher `g_score` is popped first.
/// - `index` takes no part in ordering or equality.
#[repr(C)]
pub struct WeightedNode {
    /// Priority of the entry; `a_star` sets it to `g + h` for discovered nodes.
    pub f_score: f32,
    /// g-score of the node when this entry was pushed; `a_star` compares it
    /// against the node map to skip stale entries.
    pub g_score: f32,
    /// Index of the node in `a_star`'s node map.
    pub index: u32,
}

// Written by hand instead of derived so that `==` agrees with `Ord`, as the
// `Ord` contract requires. `index` is ignored, and floats compare via
// `total_cmp`, which keeps `==` reflexive for NaN as `Eq` promises.
impl PartialEq for WeightedNode {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other).is_eq()
    }
}

impl Ord for WeightedNode {
    #[inline]
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        // `other` before `self` reverses the `f_score` order, so a max-heap
        // yields the smallest `f_score` first; ties prefer the larger `g_score`.
        other
            .f_score
            .total_cmp(&self.f_score)
            .then_with(|| self.g_score.total_cmp(&other.g_score))
    }
}
impl Eq for WeightedNode {}
impl PartialOrd for WeightedNode {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}
356
/// Budget for an A* search, expressed either as wall-clock time or as the
/// number of nodes considered. `a_star` only checks it every 10,000 nodes.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum PathfinderTimeout {
    /// Reached once this much time has passed since the search started.
    Time(Duration),
    /// Reached once this many nodes have been considered.
    Nodes(usize),
}
impl Default for PathfinderTimeout {
    /// One second of wall-clock time.
    fn default() -> Self {
        PathfinderTimeout::Time(Duration::from_millis(1_000))
    }
}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `WeightedNode` with the given f- and g-scores and index 0.
    fn weighted_node(f: f32, g: f32) -> WeightedNode {
        WeightedNode {
            f_score: f,
            g_score: g,
            index: 0,
        }
    }

    #[test]
    fn test_weighted_node_eq() {
        assert!(weighted_node(0., 0.) == weighted_node(0., 0.));
    }

    /// A lower f-score must compare as greater so the `BinaryHeap` pops it first.
    #[test]
    fn test_weighted_node_le() {
        let (higher_f, lower_f) = (weighted_node(1., 0.), weighted_node(0., 0.));
        assert_eq!(higher_f.cmp(&lower_f), cmp::Ordering::Less);
        assert!(higher_f.le(&lower_f));
    }

    /// With equal f-scores, the larger g-score compares as greater.
    #[test]
    fn test_weighted_node_le_g() {
        let (higher_g, lower_g) = (weighted_node(0., 1.), weighted_node(0., 0.));
        assert_eq!(higher_g.cmp(&lower_g), cmp::Ordering::Greater);
        assert!(!higher_g.le(&lower_g));
    }
}