1use std::{
2 cmp::{self},
3 collections::BinaryHeap,
4 fmt::{self, Debug},
5 hash::{BuildHasherDefault, Hash},
6 time::{Duration, Instant},
7};
8
9use indexmap::IndexMap;
10use num_format::ToFormattedString;
11use rustc_hash::FxHasher;
12use tracing::{debug, trace, warn};
13
14pub struct Path<P, M>
15where
16 P: Eq + Hash + Copy + Debug,
17{
18 pub movements: Vec<Movement<P, M>>,
19 pub is_partial: bool,
20}
21
/// Weights used to rank candidate end nodes for a partial path.
///
/// For each coefficient `c`, a node is scored as `heuristic + g_score / c`.
/// Larger coefficients discount the distance already travelled more heavily.
const COEFFICIENTS: [f32; 7] = [1.5, 2.0, 2.5, 3.0, 4.0, 5.0, 10.0];

/// How much lower a node's score must be than the current best for a
/// coefficient before it replaces that best node.
const MIN_IMPROVEMENT: f32 = 0.01;
27
28type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<FxHasher>>;
29
30pub fn a_star<P, M, HeuristicFn, SuccessorsFn, SuccessFn>(
35 start: P,
36 heuristic: HeuristicFn,
37 mut successors: SuccessorsFn,
38 success: SuccessFn,
39 min_timeout: PathfinderTimeout,
40 max_timeout: PathfinderTimeout,
41) -> Path<P, M>
42where
43 P: Eq + Hash + Copy + Debug,
44 HeuristicFn: Fn(P) -> f32,
45 SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
46 SuccessFn: Fn(P) -> bool,
47{
48 let start_time = Instant::now();
49
50 let mut open_set = BinaryHeap::<WeightedNode>::new();
51 open_set.push(WeightedNode {
52 g_score: 0.,
53 f_score: 0.,
54 index: 0,
55 });
56 let mut nodes: FxIndexMap<P, Node> = IndexMap::default();
57 nodes.insert(
58 start,
59 Node {
60 came_from: usize::MAX,
61 g_score: 0.,
62 },
63 );
64
65 let mut best_paths: [usize; 7] = [0; 7];
66 let mut best_path_scores: [f32; 7] = [heuristic(start); 7];
67
68 let mut num_nodes = 0_usize;
69 let mut num_movements = 0;
70
71 while let Some(WeightedNode { index, g_score, .. }) = open_set.pop() {
72 num_nodes += 1;
73
74 let (&node, node_data) = nodes.get_index(index).unwrap();
75 if success(node) {
76 let best_path = index;
77 log_perf_info(start_time, num_nodes, num_movements);
78
79 return Path {
80 movements: reconstruct_path(nodes, best_path, successors),
81 is_partial: false,
82 };
83 }
84
85 if g_score > node_data.g_score {
86 continue;
87 }
88
89 for neighbor in successors(node) {
90 let tentative_g_score = g_score + neighbor.cost;
91 let neighbor_heuristic;
93 let neighbor_index;
94
95 num_movements += 1;
96
97 match nodes.entry(neighbor.movement.target) {
98 indexmap::map::Entry::Occupied(mut e) => {
99 if e.get().g_score > tentative_g_score {
100 neighbor_heuristic = heuristic(*e.key());
101 neighbor_index = e.index();
102 e.insert(Node {
103 came_from: index,
104 g_score: tentative_g_score,
105 });
106 } else {
107 continue;
108 }
109 }
110 indexmap::map::Entry::Vacant(e) => {
111 neighbor_heuristic = heuristic(*e.key());
112 neighbor_index = e.index();
113 e.insert(Node {
114 came_from: index,
115 g_score: tentative_g_score,
116 });
117 }
118 }
119
120 open_set.push(WeightedNode {
121 index: neighbor_index,
122 g_score: tentative_g_score,
123 f_score: tentative_g_score + neighbor_heuristic,
124 });
125
126 for (coefficient_i, &coefficient) in COEFFICIENTS.iter().enumerate() {
127 let node_score = neighbor_heuristic + tentative_g_score / coefficient;
128 if best_path_scores[coefficient_i] - node_score > MIN_IMPROVEMENT {
129 best_paths[coefficient_i] = neighbor_index;
130 best_path_scores[coefficient_i] = node_score;
131 }
132 }
133 }
134
135 if num_nodes.is_multiple_of(10_000) {
137 let min_timeout_reached = match min_timeout {
138 PathfinderTimeout::Time(max_duration) => start_time.elapsed() >= max_duration,
139 PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
140 };
141
142 if min_timeout_reached {
143 if best_paths[6] != 0 {
145 break;
146 }
147
148 if min_timeout_reached {
149 let max_timeout_reached = match max_timeout {
150 PathfinderTimeout::Time(max_duration) => {
151 start_time.elapsed() >= max_duration
152 }
153 PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
154 };
155
156 if max_timeout_reached {
157 trace!("A* couldn't find a path in time, returning best path");
159 break;
160 }
161 }
162 }
163 }
164 }
165
166 let best_path = determine_best_path(best_paths, 0);
167 log_perf_info(start_time, num_nodes, num_movements);
168 Path {
169 movements: reconstruct_path(nodes, best_path, successors),
170 is_partial: true,
171 }
172}
173
174fn log_perf_info(start_time: Instant, num_nodes: usize, num_movements: usize) {
175 let elapsed_seconds = start_time.elapsed().as_secs_f64();
176 let nodes_per_second = (num_nodes as f64 / elapsed_seconds) as u64;
177 let num_movements_per_second = (num_movements as f64 / elapsed_seconds) as u64;
178 debug!(
179 "Nodes considered: {}",
180 num_nodes.to_formatted_string(&num_format::Locale::en)
181 );
182 debug!(
183 "A* ran at {} nodes per second and {} movements per second",
184 nodes_per_second.to_formatted_string(&num_format::Locale::en),
185 num_movements_per_second.to_formatted_string(&num_format::Locale::en),
186 );
187}
188
189fn determine_best_path(best_paths: [usize; 7], start: usize) -> usize {
190 for node in best_paths {
193 if node != start {
194 return node;
195 }
196 }
197 warn!("No best node found, returning first node");
198 best_paths[0]
199}
200
201fn reconstruct_path<P, M, SuccessorsFn>(
202 nodes: FxIndexMap<P, Node>,
203 mut current_index: usize,
204 mut successors: SuccessorsFn,
205) -> Vec<Movement<P, M>>
206where
207 P: Eq + Hash + Copy + Debug,
208 SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
209{
210 let mut path = Vec::new();
211 while let Some((&node_position, node)) = nodes.get_index(current_index) {
212 if node.came_from == usize::MAX {
213 break;
214 }
215 let came_from_position = *nodes.get_index(node.came_from).unwrap().0;
216
217 let mut best_successor = None;
220 let mut best_successor_cost = f32::INFINITY;
221 for successor in successors(came_from_position) {
222 if successor.movement.target == node_position && successor.cost < best_successor_cost {
223 best_successor_cost = successor.cost;
224 best_successor = Some(successor);
225 }
226 }
227 let Some(found_successor) = best_successor else {
228 warn!(
229 "a successor stopped being possible while reconstructing the path, returning empty path"
230 );
231 return vec![];
232 };
233
234 path.push(Movement {
235 target: node_position,
236 data: found_successor.movement.data,
237 });
238
239 current_index = node.came_from;
240 }
241 path.reverse();
242 path
243}
244
/// Search bookkeeping for one position that has been reached.
pub struct Node {
    /// Index, in the search's node map, of the node this one was reached from.
    /// The start node uses `usize::MAX` because it has no predecessor.
    pub came_from: usize,
    /// Cost of the cheapest known route from the start to this node.
    pub g_score: f32,
}
249
250#[derive(Clone, Debug)]
251pub struct Edge<P: Hash + Copy, M> {
252 pub movement: Movement<P, M>,
253 pub cost: f32,
254}
255
/// One step of a path: the position it ends at plus the data attached to it by
/// the successors function.
pub struct Movement<P, M>
where
    P: Hash + Copy,
{
    /// Position reached by this step.
    pub target: P,
    /// Data associated with the step.
    pub data: M,
}

impl<P, M> Debug for Movement<P, M>
where
    P: Hash + Copy + Debug,
    M: Debug,
{
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { target, data } = self;
        formatter
            .debug_struct("Movement")
            .field("target", target)
            .field("data", data)
            .finish()
    }
}

// Written by hand so that `P` only needs to be `Copy`, and `target` is copied
// rather than cloned.
impl<P, M> Clone for Movement<P, M>
where
    P: Hash + Copy,
    M: Clone,
{
    fn clone(&self) -> Self {
        let data = self.data.clone();
        Movement {
            target: self.target,
            data,
        }
    }
}
277
/// An entry in the A* open set.
///
/// The ordering is designed for [`BinaryHeap`], which pops the *greatest*
/// element:
/// * A lower `f_score` compares as greater.
/// * Among equal `f_score`s, a higher `g_score` compares as greater.
/// * `index` plays no part in the ordering or in equality.
#[repr(C)]
pub struct WeightedNode {
    /// Estimated total cost through this node (`g_score` + heuristic).
    pub f_score: f32,
    /// Cost from the start to this node at the time it was pushed.
    pub g_score: f32,
    /// Index of this node's position in the search's node map.
    pub index: usize,
}

impl Ord for WeightedNode {
    #[inline]
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        // Reversed on `f_score` so the max-heap yields the lowest estimate first.
        match other.f_score.total_cmp(&self.f_score) {
            // Tie-break in favour of the node that is further along its path.
            cmp::Ordering::Equal => self.g_score.total_cmp(&other.g_score),
            s => s,
        }
    }
}

// Defined via `cmp` instead of derived so that `==` agrees with the ordering, as
// `Eq`/`Ord` require. A derived impl would also compare `index` and use IEEE float
// equality, which disagrees with `total_cmp` for NaN and for `0.0` vs `-0.0`.
impl PartialEq for WeightedNode {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == cmp::Ordering::Equal
    }
}
impl Eq for WeightedNode {}
impl PartialOrd for WeightedNode {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}
305
/// A budget that limits how long a pathfinding search may run.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PathfinderTimeout {
    /// Limit by wall-clock time elapsed since the search started.
    Time(Duration),
    /// Limit by the number of nodes the search has considered.
    Nodes(usize),
}

impl Default for PathfinderTimeout {
    /// One second of wall-clock time.
    fn default() -> Self {
        PathfinderTimeout::Time(Duration::from_millis(1_000))
    }
}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `WeightedNode` with the given scores and a placeholder index.
    fn weighted_node(f_score: f32, g_score: f32) -> WeightedNode {
        WeightedNode {
            index: 0,
            f_score,
            g_score,
        }
    }

    #[test]
    fn test_weighted_node_eq() {
        assert!(weighted_node(0., 0.) == weighted_node(0., 0.));
    }

    /// A higher `f_score` compares as smaller, so the max-heap pops it later.
    #[test]
    fn test_weighted_node_le() {
        let worse = weighted_node(1., 0.);
        let better = weighted_node(0., 0.);
        assert_eq!(worse.cmp(&better), cmp::Ordering::Less);
        assert!(worse.le(&better));
    }

    /// With equal `f_score`s, a higher `g_score` compares as greater.
    #[test]
    fn test_weighted_node_le_g() {
        let further = weighted_node(0., 1.);
        let nearer = weighted_node(0., 0.);
        assert_eq!(further.cmp(&nearer), cmp::Ordering::Greater);
        assert!(!further.le(&nearer));
    }
}