1use std::{
2 cmp::{self},
3 collections::BinaryHeap,
4 fmt::{self, Debug},
5 hash::{BuildHasherDefault, Hash},
6 time::{Duration, Instant},
7};
8
9use indexmap::IndexMap;
10use num_format::ToFormattedString;
11use rustc_hash::FxHasher;
12use tracing::{debug, trace, warn};
13
14pub struct Path<P, M>
15where
16 P: Eq + Hash + Copy + Debug,
17{
18 pub movements: Vec<Movement<P, M>>,
19 pub is_partial: bool,
20}
21
/// Divisors applied to a node's g-score when ranking partial-path candidates.
///
/// Each candidate is scored as `heuristic + g_score / coefficient`, so larger
/// coefficients weight the heuristic estimate more heavily relative to the cost
/// already travelled.
const COEFFICIENTS: [f32; 7] = [1.5, 2.0, 2.5, 3.0, 4.0, 5.0, 10.0];

/// How much a candidate's score must drop before it replaces the current best
/// candidate for a coefficient.
const MIN_IMPROVEMENT: f32 = 0.01;
27
28type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<FxHasher>>;
29
30pub fn a_star<P, M, HeuristicFn, SuccessorsFn, SuccessFn>(
35 start: P,
36 heuristic: HeuristicFn,
37 mut successors: SuccessorsFn,
38 success: SuccessFn,
39 min_timeout: PathfinderTimeout,
40 max_timeout: PathfinderTimeout,
41) -> Path<P, M>
42where
43 P: Eq + Hash + Copy + Debug,
44 HeuristicFn: Fn(P) -> f32,
45 SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
46 SuccessFn: Fn(P) -> bool,
47{
48 let start_time = Instant::now();
49
50 let mut open_set = BinaryHeap::<WeightedNode>::new();
51 open_set.push(WeightedNode {
52 g_score: 0.,
53 f_score: 0.,
54 index: 0,
55 });
56 let mut nodes: FxIndexMap<P, Node> = IndexMap::default();
57 nodes.insert(
58 start,
59 Node {
60 came_from: usize::MAX,
61 g_score: 0.,
62 },
63 );
64
65 let mut best_paths: [usize; 7] = [0; 7];
66 let mut best_path_scores: [f32; 7] = [heuristic(start); 7];
67
68 let mut num_nodes = 0_usize;
69 let mut num_movements = 0;
70
71 while let Some(WeightedNode { index, g_score, .. }) = open_set.pop() {
72 num_nodes += 1;
73
74 let (&node, node_data) = nodes.get_index(index).unwrap();
75 if success(node) {
76 debug!("Nodes considered: {num_nodes}");
77
78 return Path {
79 movements: reconstruct_path(nodes, index, successors),
80 is_partial: false,
81 };
82 }
83
84 if g_score > node_data.g_score {
85 continue;
86 }
87
88 for neighbor in successors(node) {
89 let tentative_g_score = g_score + neighbor.cost;
90 let neighbor_heuristic;
92 let neighbor_index;
93
94 num_movements += 1;
95
96 match nodes.entry(neighbor.movement.target) {
97 indexmap::map::Entry::Occupied(mut e) => {
98 if e.get().g_score > tentative_g_score {
99 neighbor_heuristic = heuristic(*e.key());
100 neighbor_index = e.index();
101 e.insert(Node {
102 came_from: index,
103 g_score: tentative_g_score,
104 });
105 } else {
106 continue;
107 }
108 }
109 indexmap::map::Entry::Vacant(e) => {
110 neighbor_heuristic = heuristic(*e.key());
111 neighbor_index = e.index();
112 e.insert(Node {
113 came_from: index,
114 g_score: tentative_g_score,
115 });
116 }
117 }
118
119 open_set.push(WeightedNode {
120 index: neighbor_index,
121 g_score: tentative_g_score,
122 f_score: tentative_g_score + neighbor_heuristic,
123 });
124
125 for (coefficient_i, &coefficient) in COEFFICIENTS.iter().enumerate() {
126 let node_score = neighbor_heuristic + tentative_g_score / coefficient;
127 if best_path_scores[coefficient_i] - node_score > MIN_IMPROVEMENT {
128 best_paths[coefficient_i] = neighbor_index;
129 best_path_scores[coefficient_i] = node_score;
130 }
131 }
132 }
133
134 if num_nodes.is_multiple_of(10_000) {
136 let min_timeout_reached = match min_timeout {
137 PathfinderTimeout::Time(max_duration) => start_time.elapsed() >= max_duration,
138 PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
139 };
140
141 if min_timeout_reached {
142 if best_paths[6] != 0 {
144 break;
145 }
146
147 if min_timeout_reached {
148 let max_timeout_reached = match max_timeout {
149 PathfinderTimeout::Time(max_duration) => {
150 start_time.elapsed() >= max_duration
151 }
152 PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
153 };
154
155 if max_timeout_reached {
156 trace!("A* couldn't find a path in time, returning best path");
158 break;
159 }
160 }
161 }
162 }
163 }
164
165 let best_path = determine_best_path(best_paths, 0);
166
167 let elapsed_seconds = start_time.elapsed().as_secs_f64();
168 let nodes_per_second = (num_nodes as f64 / elapsed_seconds) as u64;
169 let num_movements_per_second = (num_movements as f64 / elapsed_seconds) as u64;
170 debug!(
171 "A* ran at {} nodes per second and {} movements per second",
172 nodes_per_second.to_formatted_string(&num_format::Locale::en),
173 num_movements_per_second.to_formatted_string(&num_format::Locale::en),
174 );
175
176 Path {
177 movements: reconstruct_path(nodes, best_path, successors),
178 is_partial: true,
179 }
180}
181
182fn determine_best_path(best_paths: [usize; 7], start: usize) -> usize {
183 for node in best_paths {
186 if node != start {
187 return node;
188 }
189 }
190 warn!("No best node found, returning first node");
191 best_paths[0]
192}
193
194fn reconstruct_path<P, M, SuccessorsFn>(
195 nodes: FxIndexMap<P, Node>,
196 mut current_index: usize,
197 mut successors: SuccessorsFn,
198) -> Vec<Movement<P, M>>
199where
200 P: Eq + Hash + Copy + Debug,
201 SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
202{
203 let mut path = Vec::new();
204 while let Some((&node_position, node)) = nodes.get_index(current_index) {
205 if node.came_from == usize::MAX {
206 break;
207 }
208 let came_from_position = *nodes.get_index(node.came_from).unwrap().0;
209
210 let mut best_successor = None;
213 let mut best_successor_cost = f32::INFINITY;
214 for successor in successors(came_from_position) {
215 if successor.movement.target == node_position && successor.cost < best_successor_cost {
216 best_successor_cost = successor.cost;
217 best_successor = Some(successor);
218 }
219 }
220 let Some(found_successor) = best_successor else {
221 warn!(
222 "a successor stopped being possible while reconstructing the path, returning empty path"
223 );
224 return vec![];
225 };
226
227 path.push(Movement {
228 target: node_position,
229 data: found_successor.movement.data,
230 });
231
232 current_index = node.came_from;
233 }
234 path.reverse();
235 path
236}
237
/// Per-node bookkeeping kept in the search's node map.
///
/// Derives the common traits so the public type can be debug-printed and
/// copied like the other small value types in this module.
#[derive(Debug, Clone, Copy)]
pub struct Node {
    /// Index, in the node map, of the node this one was reached from.
    /// `usize::MAX` for the start node.
    pub came_from: usize,
    /// Cost of the cheapest known route from the start to this node.
    pub g_score: f32,
}
242
243#[derive(Clone, Debug)]
244pub struct Edge<P: Hash + Copy, M> {
245 pub movement: Movement<P, M>,
246 pub cost: f32,
247}
248
/// One step of a path: the position it ends at, plus caller-defined data
/// describing the step.
pub struct Movement<P: Hash + Copy, M> {
    pub target: P,
    pub data: M,
}

impl<P, M> Debug for Movement<P, M>
where
    P: Hash + Copy + Debug,
    M: Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { target, data } = self;
        let mut builder = f.debug_struct("Movement");
        builder.field("target", target);
        builder.field("data", data);
        builder.finish()
    }
}

impl<P, M> Clone for Movement<P, M>
where
    P: Hash + Copy + Clone,
    M: Clone,
{
    // Written by hand so `target` is copied bitwise rather than going through
    // `P::clone`.
    fn clone(&self) -> Self {
        let data = self.data.clone();
        Movement {
            target: self.target,
            data,
        }
    }
}
270
/// An entry in the A* open set.
///
/// The ordering is reversed on `f_score`, so `BinaryHeap` (a max-heap) pops the
/// entry with the lowest `f_score` first. Ties go to the entry with the higher
/// `g_score`. `index` is not part of the ordering.
///
/// Equality is defined through [`Ord::cmp`], so `==` agrees with the ordering
/// as the `PartialEq`/`PartialOrd`/`Ord` contracts require. Entries that differ
/// only in `index` compare equal. Scores are compared with `f32::total_cmp`, so
/// an entry with a NaN score is equal to itself, consistent with the `Eq` impl.
#[repr(C)]
pub struct WeightedNode {
    /// Estimated total cost: g-score plus heuristic.
    pub f_score: f32,
    /// Cost of the route to the node when this entry was queued.
    pub g_score: f32,
    /// Index of the node in the search's node map.
    pub index: usize,
}

impl Ord for WeightedNode {
    #[inline]
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        // `other` before `self`: a lower f_score sorts as "greater" and is
        // popped first.
        other
            .f_score
            .total_cmp(&self.f_score)
            .then_with(|| self.g_score.total_cmp(&other.g_score))
    }
}
impl PartialEq for WeightedNode {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == cmp::Ordering::Equal
    }
}
impl Eq for WeightedNode {}
impl PartialOrd for WeightedNode {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}
298
/// Limit on how long a pathfinder search may run.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PathfinderTimeout {
    /// Limit on the wall-clock time elapsed since the search started.
    Time(Duration),
    /// Limit on the number of nodes the search has considered.
    Nodes(usize),
}

impl Default for PathfinderTimeout {
    /// One second of wall-clock time.
    fn default() -> Self {
        PathfinderTimeout::Time(Duration::from_millis(1_000))
    }
}
316
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Ordering;

    /// Builds an open-set entry with the given scores and a dummy index.
    fn entry(f_score: f32, g_score: f32) -> WeightedNode {
        WeightedNode {
            f_score,
            g_score,
            index: 0,
        }
    }

    #[test]
    fn test_weighted_node_eq() {
        let (left, right) = (entry(0., 0.), entry(0., 0.));
        assert!(left.eq(&right));
    }

    #[test]
    fn test_weighted_node_le() {
        // A higher f_score sorts lower, so the heap pops it later.
        let (worse, better) = (entry(1., 0.), entry(0., 0.));
        assert_eq!(worse.cmp(&better), Ordering::Less);
        assert!(worse <= better);
    }

    #[test]
    fn test_weighted_node_le_g() {
        // With equal f_scores, the higher g_score sorts higher.
        let (deeper, shallower) = (entry(0., 1.), entry(0., 0.));
        assert_eq!(deeper.cmp(&shallower), Ordering::Greater);
        assert!(!deeper.le(&shallower));
    }
}