1use std::{error::Error, fmt};
2
/// Division-by-constant table used to locate the `u64` cell that holds an
/// element without a hardware division.
///
/// Entry `i` is the `(mul, add, shift)` triple for the divisor `d = i + 1`
/// (the number of values packed into one `u64`, so it is looked up as
/// `MAGIC[values_per_long - 1]`). With `mul` and `add` reinterpreted as `u32`
/// and zero-extended to `u64`, the triple satisfies
/// `(index * mul + add) >> shift == index / d` for every `index < 2^26`:
///
/// * powers of two use `(1, 0, log2(d))`, i.e. a plain right shift;
/// * every other divisor uses `mul == add == floor(2^32 / d)` with a shift of
///   32, which evaluates `((index + 1) * floor(2^32 / d)) >> 32`.
#[rustfmt::skip]
const MAGIC: [(i32, i32, u32); 64] = [
    // (mul, add, shift)              divisor
    (1, 0, 0),                     // 1
    (1, 0, 1),                     // 2
    (0x55555555, 0x55555555, 32),  // 3
    (1, 0, 2),                     // 4
    (0x33333333, 0x33333333, 32),  // 5
    (0x2AAAAAAA, 0x2AAAAAAA, 32),  // 6
    (0x24924924, 0x24924924, 32),  // 7
    (1, 0, 3),                     // 8
    (0x1C71C71C, 0x1C71C71C, 32),  // 9
    (0x19999999, 0x19999999, 32),  // 10
    (0x1745D174, 0x1745D174, 32),  // 11
    (0x15555555, 0x15555555, 32),  // 12
    (0x13B13B13, 0x13B13B13, 32),  // 13
    (0x12492492, 0x12492492, 32),  // 14
    (0x11111111, 0x11111111, 32),  // 15
    (1, 0, 4),                     // 16
    (0xF0F0F0F, 0xF0F0F0F, 32),    // 17
    (0xE38E38E, 0xE38E38E, 32),    // 18
    (0xD79435E, 0xD79435E, 32),    // 19
    (0xCCCCCCC, 0xCCCCCCC, 32),    // 20 (floor(2^32 / 20))
    (0xC30C30C, 0xC30C30C, 32),    // 21
    (0xBA2E8BA, 0xBA2E8BA, 32),    // 22
    (0xB21642C, 0xB21642C, 32),    // 23
    (0xAAAAAAA, 0xAAAAAAA, 32),    // 24
    (0xA3D70A3, 0xA3D70A3, 32),    // 25
    (0x9D89D89, 0x9D89D89, 32),    // 26
    (0x97B425E, 0x97B425E, 32),    // 27
    (0x9249249, 0x9249249, 32),    // 28
    (0x8D3DCB0, 0x8D3DCB0, 32),    // 29
    (0x8888888, 0x8888888, 32),    // 30
    (0x8421084, 0x8421084, 32),    // 31
    (1, 0, 5),                     // 32
    (0x7C1F07C, 0x7C1F07C, 32),    // 33
    (0x7878787, 0x7878787, 32),    // 34
    (0x7507507, 0x7507507, 32),    // 35
    (0x71C71C7, 0x71C71C7, 32),    // 36
    (0x6EB3E45, 0x6EB3E45, 32),    // 37
    (0x6BCA1AF, 0x6BCA1AF, 32),    // 38
    (0x6906906, 0x6906906, 32),    // 39
    (0x6666666, 0x6666666, 32),    // 40
    (0x63E7063, 0x63E7063, 32),    // 41
    (0x6186186, 0x6186186, 32),    // 42
    (0x5F417D0, 0x5F417D0, 32),    // 43
    (0x5D1745D, 0x5D1745D, 32),    // 44
    (0x5B05B05, 0x5B05B05, 32),    // 45
    (0x590B216, 0x590B216, 32),    // 46
    (0x572620A, 0x572620A, 32),    // 47
    (0x5555555, 0x5555555, 32),    // 48
    (0x5397829, 0x5397829, 32),    // 49
    (0x51EB851, 0x51EB851, 32),    // 50
    (0x5050505, 0x5050505, 32),    // 51
    (0x4EC4EC4, 0x4EC4EC4, 32),    // 52
    (0x4D4873E, 0x4D4873E, 32),    // 53
    (0x4BDA12F, 0x4BDA12F, 32),    // 54
    (0x4A7904A, 0x4A7904A, 32),    // 55
    (0x4924924, 0x4924924, 32),    // 56
    (0x47DC11F, 0x47DC11F, 32),    // 57
    (0x469EE58, 0x469EE58, 32),    // 58
    (0x456C797, 0x456C797, 32),    // 59
    (0x4444444, 0x4444444, 32),    // 60
    (0x4325C53, 0x4325C53, 32),    // 61
    (0x4210842, 0x4210842, 32),    // 62
    (0x4104104, 0x4104104, 32),    // 63
    (1, 0, 6),                     // 64
];
72
/// A fixed-length array of unsigned integers, each `bits` wide, packed into
/// `u64` words.
///
/// Every word holds `values_per_long = 64 / bits` values, the first one in the
/// least-significant bits. Values never straddle two words, so any leftover
/// high bits of a word are unused.
///
/// When `data` is empty (for example a zero-bit storage), every read returns 0
/// and writes are ignored.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct BitStorage {
    /// Packed backing words.
    pub data: Box<[u64]>,
    /// Width of a single value, in bits.
    bits: usize,
    /// The low `bits` bits set; used to extract and clear one value.
    mask: u64,
    /// Number of logical values stored.
    size: usize,
    /// How many values fit in one `u64` word (`64 / bits`).
    values_per_long: usize,
    /// Multiply/add/shift triple (taken from `MAGIC`) that replaces the
    /// division by `values_per_long` when locating a value's word.
    divide_mul: i32,
    divide_add: i32,
    divide_shift: u32,
}
85
/// Error returned by [`BitStorage::new`].
#[derive(Debug)]
pub enum BitStorageError {
    /// The supplied backing buffer does not have the number of `u64` words
    /// required for the requested bit width and size.
    InvalidLength {
        /// Length of the buffer that was passed in.
        got: usize,
        /// Length the storage needs.
        expected: usize,
    },
}

impl fmt::Display for BitStorageError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::InvalidLength { got, expected } => write!(
                f,
                "Invalid length given for storage, got: {got}, but expected: {expected}"
            ),
        }
    }
}

impl Error for BitStorageError {}
101
102impl BitStorage {
103 pub fn new(
106 bits: usize,
107 size: usize,
108 data: Option<Box<[u64]>>,
109 ) -> Result<Self, BitStorageError> {
110 if let Some(data) = &data {
111 if data.is_empty() {
113 return Ok(BitStorage {
114 data: Box::new([]),
115 bits,
116 size,
117 ..Default::default()
118 });
119 }
120 }
121
122 debug_assert!((1..=32).contains(&bits));
123
124 let values_per_long = 64 / bits;
125 let magic_index = values_per_long - 1;
126 let (divide_mul, divide_add, divide_shift) = MAGIC[magic_index];
127 let calculated_length = size.div_ceil(values_per_long);
128
129 let mask = (1 << bits) - 1;
130
131 let using_data = if let Some(data) = data {
132 if data.len() != calculated_length {
133 return Err(BitStorageError::InvalidLength {
134 got: data.len(),
135 expected: calculated_length,
136 });
137 }
138 data
139 } else {
140 vec![0; calculated_length].into()
141 };
142
143 Ok(BitStorage {
144 data: using_data,
145 bits,
146 mask,
147 size,
148 values_per_long,
149 divide_mul,
150 divide_add,
151 divide_shift,
152 })
153 }
154
155 #[inline]
156 fn cell_index(&self, index: u64) -> usize {
157 let mul = self.divide_mul as u32 as u64;
158 let add = self.divide_add as u32 as u64;
159 let shift = self.divide_shift;
160
161 (((index * mul) + add) >> shift) as usize
162 }
163
164 pub fn get(&self, index: usize) -> u64 {
171 assert!(
172 index < self.size,
173 "Index {index} out of bounds (must be less than {})",
174 self.size
175 );
176
177 if self.data.is_empty() {
179 return 0;
180 }
181
182 let cell_index = self.cell_index(index as u64);
183 let cell = &self.data[cell_index];
184 let bit_index = (index - cell_index * self.values_per_long) * self.bits;
185 (cell >> bit_index) & self.mask
186 }
187
188 pub fn get_and_set(&mut self, index: usize, value: u64) -> u64 {
189 if self.data.is_empty() {
191 return 0;
192 }
193
194 debug_assert!(index < self.size);
195 debug_assert!(value <= self.mask);
196 let cell_index = self.cell_index(index as u64);
197 let cell = &mut self.data[cell_index];
198 let bit_index = (index - cell_index * self.values_per_long) * self.bits;
199 let old_value = (*cell >> (bit_index as u64)) & self.mask;
200 *cell = (*cell & !(self.mask << bit_index)) | ((value & self.mask) << bit_index);
201 old_value
202 }
203
204 pub fn set(&mut self, index: usize, value: u64) {
205 if self.data.is_empty() {
207 return;
208 }
209
210 debug_assert!(index < self.size);
211 debug_assert!(
212 value <= self.mask,
213 "value {value} at {index} was outside of the mask for {self:?}"
214 );
215 let cell_index = self.cell_index(index as u64);
216 let cell = &mut self.data[cell_index];
217 let bit_index = (index - cell_index * self.values_per_long) * self.bits;
218 *cell = (*cell & !(self.mask << bit_index)) | ((value & self.mask) << bit_index);
219 }
220
221 #[inline]
223 pub fn size(&self) -> usize {
224 self.size
225 }
226
227 pub fn iter(&self) -> BitStorageIter<'_> {
228 BitStorageIter {
229 storage: self,
230 index: 0,
231 }
232 }
233}
234
235pub struct BitStorageIter<'a> {
236 storage: &'a BitStorage,
237 index: usize,
238}
239
240impl Iterator for BitStorageIter<'_> {
241 type Item = u64;
242
243 fn next(&mut self) -> Option<Self::Item> {
244 if self.index >= self.storage.size {
245 return None;
246 }
247
248 let value = self.storage.get(self.index);
249 self.index += 1;
250 Some(value)
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_protocol_wiki_example() {
        let data = [
            1, 2, 2, 3, 4, 4, 5, 6, 6, 4, 8, 0, 7, 4, 3, 13, 15, 16, 9, 14, 10, 12, 0, 2,
        ];
        let compact_data: [u64; 2] = [0x0020863148418841, 0x01018A7260F68C87];
        let storage = BitStorage::new(5, data.len(), Some(Box::new(compact_data))).unwrap();

        for (i, expected) in data.iter().enumerate() {
            assert_eq!(storage.get(i), *expected);
        }
    }

    /// Writes then reads back every slot for each supported bit width, which
    /// exercises `set`, `get`, `get_and_set` and `iter` together.
    #[test]
    fn test_set_get_round_trip_all_widths() {
        const SIZE: usize = 100;
        for bits in 1..=32 {
            let mut storage = BitStorage::new(bits, SIZE, None).unwrap();
            assert_eq!(storage.size(), SIZE);

            let mask = (1u64 << bits) - 1;
            let value_at = |i: usize| (i as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15) & mask;

            assert!(storage.iter().all(|v| v == 0));
            for i in 0..SIZE {
                storage.set(i, value_at(i));
            }
            for i in 0..SIZE {
                assert_eq!(storage.get(i), value_at(i), "bits {bits}, index {i}");
            }
            assert_eq!(
                storage.iter().collect::<Vec<_>>(),
                (0..SIZE).map(value_at).collect::<Vec<_>>()
            );

            let replacement = !value_at(7) & mask;
            assert_eq!(storage.get_and_set(7, replacement), value_at(7));
            assert_eq!(storage.get(7), replacement);
            // Neighbouring slots must be left untouched.
            assert_eq!(storage.get(6), value_at(6));
            assert_eq!(storage.get(8), value_at(8));
        }
    }

    #[test]
    fn test_new_rejects_wrong_length() {
        let result = BitStorage::new(5, 24, Some(vec![0u64; 3].into_boxed_slice()));
        assert!(matches!(
            result,
            Err(BitStorageError::InvalidLength {
                got: 3,
                expected: 2
            })
        ));
    }

    #[test]
    fn test_zero_bit_storage() {
        let mut storage = BitStorage::new(0, 10, Some(Vec::new().into_boxed_slice())).unwrap();
        assert_eq!(storage.size(), 10);

        storage.set(3, 5);
        assert_eq!(storage.get(3), 0);
        assert_eq!(storage.get_and_set(3, 5), 0);
        assert_eq!(storage.iter().count(), 10);
    }
}