1use std::{
2 cmp::{self},
3 collections::BinaryHeap,
4 fmt::Debug,
5 hash::{BuildHasherDefault, Hash},
6 time::{Duration, Instant},
7};
8
9use indexmap::IndexMap;
10use num_format::ToFormattedString;
11use rustc_hash::FxHasher;
12use tracing::{debug, trace, warn};
13
14pub struct Path<P, M>
15where
16 P: Eq + Hash + Copy + Debug,
17{
18 pub movements: Vec<Movement<P, M>>,
19 pub is_partial: bool,
20}
21
22const COEFFICIENTS: [f32; 7] = [1.5, 2., 2.5, 3., 4., 5., 10.];
25
26const MIN_IMPROVEMENT: f32 = 0.01;
27
28type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<FxHasher>>;
29
30pub fn a_star<P, M, HeuristicFn, SuccessorsFn, SuccessFn>(
35 start: P,
36 heuristic: HeuristicFn,
37 mut successors: SuccessorsFn,
38 success: SuccessFn,
39 min_timeout: PathfinderTimeout,
40 max_timeout: PathfinderTimeout,
41) -> Path<P, M>
42where
43 P: Eq + Hash + Copy + Debug,
44 HeuristicFn: Fn(P) -> f32,
45 SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
46 SuccessFn: Fn(P) -> bool,
47{
48 let start_time = Instant::now();
49
50 let mut open_set = BinaryHeap::<WeightedNode>::new();
51 open_set.push(WeightedNode {
52 g_score: 0.,
53 f_score: 0.,
54 index: 0,
55 });
56 let mut nodes: FxIndexMap<P, Node> = IndexMap::default();
57 nodes.insert(
58 start,
59 Node {
60 came_from: usize::MAX,
61 g_score: 0.,
62 },
63 );
64
65 let mut best_paths: [usize; 7] = [0; 7];
66 let mut best_path_scores: [f32; 7] = [heuristic(start); 7];
67
68 let mut num_nodes = 0;
69
70 while let Some(WeightedNode { index, g_score, .. }) = open_set.pop() {
71 num_nodes += 1;
72
73 let (&node, node_data) = nodes.get_index(index).unwrap();
74 if success(node) {
75 debug!("Nodes considered: {num_nodes}");
76
77 return Path {
78 movements: reconstruct_path(nodes, index, successors),
79 is_partial: false,
80 };
81 }
82
83 if g_score > node_data.g_score {
84 continue;
85 }
86
87 for neighbor in successors(node) {
88 let tentative_g_score = g_score + neighbor.cost;
89 let neighbor_heuristic;
91 let neighbor_index;
92
93 if tentative_g_score - g_score < MIN_IMPROVEMENT {
95 continue;
96 }
97
98 match nodes.entry(neighbor.movement.target) {
99 indexmap::map::Entry::Occupied(mut e) => {
100 if e.get().g_score > tentative_g_score {
101 neighbor_heuristic = heuristic(*e.key());
102 neighbor_index = e.index();
103 e.insert(Node {
104 came_from: index,
105 g_score: tentative_g_score,
106 });
107 } else {
108 continue;
109 }
110 }
111 indexmap::map::Entry::Vacant(e) => {
112 neighbor_heuristic = heuristic(*e.key());
113 neighbor_index = e.index();
114 e.insert(Node {
115 came_from: index,
116 g_score: tentative_g_score,
117 });
118 }
119 }
120
121 open_set.push(WeightedNode {
122 index: neighbor_index,
123 g_score: tentative_g_score,
124 f_score: tentative_g_score + neighbor_heuristic,
125 });
126
127 for (coefficient_i, &coefficient) in COEFFICIENTS.iter().enumerate() {
128 let node_score = neighbor_heuristic + tentative_g_score / coefficient;
129 if best_path_scores[coefficient_i] - node_score > MIN_IMPROVEMENT {
130 best_paths[coefficient_i] = neighbor_index;
131 best_path_scores[coefficient_i] = node_score;
132 }
133 }
134 }
135
136 if num_nodes % 10000 == 0 {
138 let min_timeout_reached = match min_timeout {
139 PathfinderTimeout::Time(max_duration) => start_time.elapsed() >= max_duration,
140 PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
141 };
142
143 if min_timeout_reached {
144 if best_paths[6] != 0 {
146 break;
147 }
148
149 if min_timeout_reached {
150 let max_timeout_reached = match max_timeout {
151 PathfinderTimeout::Time(max_duration) => {
152 start_time.elapsed() >= max_duration
153 }
154 PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
155 };
156
157 if max_timeout_reached {
158 trace!("A* couldn't find a path in time, returning best path");
160 break;
161 }
162 }
163 }
164 }
165 }
166
167 let best_path = determine_best_path(best_paths, 0);
168
169 debug!(
170 "A* ran at {} nodes per second",
171 ((num_nodes as f64 / start_time.elapsed().as_secs_f64()) as u64)
172 .to_formatted_string(&num_format::Locale::en)
173 );
174
175 Path {
176 movements: reconstruct_path(nodes, best_path, successors),
177 is_partial: true,
178 }
179}
180
181fn determine_best_path(best_paths: [usize; 7], start: usize) -> usize {
182 for node in best_paths {
185 if node != start {
186 return node;
187 }
188 }
189 warn!("No best node found, returning first node");
190 best_paths[0]
191}
192
193fn reconstruct_path<P, M, SuccessorsFn>(
194 nodes: FxIndexMap<P, Node>,
195 mut current_index: usize,
196 mut successors: SuccessorsFn,
197) -> Vec<Movement<P, M>>
198where
199 P: Eq + Hash + Copy + Debug,
200 SuccessorsFn: FnMut(P) -> Vec<Edge<P, M>>,
201{
202 let mut path = Vec::new();
203 while let Some((&node_position, node)) = nodes.get_index(current_index) {
204 if node.came_from == usize::MAX {
205 break;
206 }
207 let came_from_position = *nodes.get_index(node.came_from).unwrap().0;
208
209 let mut best_successor = None;
212 let mut best_successor_cost = f32::INFINITY;
213 for successor in successors(came_from_position) {
214 if successor.movement.target == node_position && successor.cost < best_successor_cost {
215 best_successor_cost = successor.cost;
216 best_successor = Some(successor);
217 }
218 }
219 let found_successor = best_successor.expect("No successor found");
220
221 path.push(Movement {
222 target: node_position,
223 data: found_successor.movement.data,
224 });
225
226 current_index = node.came_from;
227 }
228 path.reverse();
229 path
230}
231
/// Per-node bookkeeping stored in the search's node table.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Node {
    /// Index (in the node table) of the node this one was reached from.
    /// `usize::MAX` marks a node without a predecessor, i.e. the start node.
    pub came_from: usize,
    /// Cost of the cheapest route from the start node found so far.
    pub g_score: f32,
}
236
237pub struct Edge<P: Hash + Copy, M> {
238 pub movement: Movement<P, M>,
239 pub cost: f32,
240}
241
/// A single step of a path: the node it arrives at plus caller-defined data
/// describing the step.
///
/// `Debug` and `Clone` are derived; the derives produce the same bounds
/// (`P: Debug` / `M: Debug`, `M: Clone`) and the same output as the previous
/// hand-written implementations.
#[derive(Debug, Clone)]
pub struct Movement<P: Hash + Copy, M> {
    /// The node this step ends on.
    pub target: P,
    /// Caller-defined payload taken from the edge that was followed.
    pub data: M,
}
263
/// An entry of the A* open set.
///
/// Ordering is defined so that a max-heap ([`std::collections::BinaryHeap`])
/// pops the entry with the lowest `f_score` first; ties are broken in favor
/// of the higher `g_score`. `index` does not take part in comparisons.
#[derive(Debug)]
pub struct WeightedNode {
    /// Index of the node in the search's node table.
    index: usize,
    /// Cost from the start node at the time this entry was pushed.
    g_score: f32,
    /// `g_score` plus the heuristic estimate for the node.
    f_score: f32,
}

impl Ord for WeightedNode {
    #[inline]
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        // `other` vs `self` is intentional: reversing the f_score comparison
        // turns the max-heap into a min-heap on f_score.
        other
            .f_score
            .total_cmp(&self.f_score)
            .then_with(|| self.g_score.total_cmp(&other.g_score))
    }
}

impl PartialOrd for WeightedNode {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}

// Equality is defined through `cmp` so that `==` agrees with the ordering, as
// `Ord`/`PartialOrd` require. A derived `PartialEq` would compare `index` and
// use float `==`, making NaN scores non-reflexive and making nodes that order
// as equal compare unequal.
impl PartialEq for WeightedNode {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == cmp::Ordering::Equal
    }
}

impl Eq for WeightedNode {}
290
/// A limit on how long an A* search may run.
///
/// `a_star` only evaluates limits when its expanded-node counter is a
/// multiple of 10000, so a limit can be overshot.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PathfinderTimeout {
    /// Reached once at least this much wall-clock time has elapsed since the
    /// search started.
    Time(Duration),
    /// Reached once at least this many nodes have been expanded.
    Nodes(usize),
}
303impl Default for PathfinderTimeout {
304 fn default() -> Self {
305 Self::Time(Duration::from_secs(1))
306 }
307}