1use std::io::{self, Cursor, Write};
2
3use azalea_buf::{AzBuf, AzaleaRead, AzaleaWrite, BufReadError};
4
/// A bit set backed by a `Vec` of 64-bit words.
///
/// Bit `i` lives in `data[i / 64]` at position `i % 64` (least-significant bit first).
/// The number of words is fixed when the set is created. Methods such as `index` and `set`
/// index the vector directly, so they panic on out-of-range bits rather than growing storage.
///
/// Serialization comes from the `AzBuf` derive over the single `data` field.
/// NOTE(review): presumably a length-prefixed array of longs, like Java's
/// `BitSet.toLongArray()` — confirm against azalea_buf.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, AzBuf)]
pub struct BitSet {
    /// Backing words; bit `i` of the set is bit `i % 64` of `data[i / 64]`.
    data: Vec<u64>,
}
10
/// log2 of the word width: `bit_index >> ADDRESS_BITS_PER_WORD` is the index of the `u64`
/// word that holds `bit_index`.
const ADDRESS_BITS_PER_WORD: usize = u64::BITS.trailing_zeros() as usize;
12
13impl BitSet {
15 pub fn new(num_bits: usize) -> Self {
16 BitSet {
17 data: vec![0; num_bits.div_ceil(64)],
18 }
19 }
20
21 pub fn index(&self, index: usize) -> bool {
22 (self.data[index / 64] & (1u64 << (index % 64))) != 0
23 }
24
25 fn check_range(&self, from_index: usize, to_index: usize) {
26 assert!(
27 from_index <= to_index,
28 "fromIndex: {from_index} > toIndex: {to_index}",
29 );
30 }
31
32 fn word_index(&self, bit_index: usize) -> usize {
33 bit_index >> ADDRESS_BITS_PER_WORD
34 }
35
36 pub fn clear(&mut self, from_index: usize, mut to_index: usize) {
37 self.check_range(from_index, to_index);
38
39 if from_index == to_index {
40 return;
41 }
42
43 let start_word_index = self.word_index(from_index);
44 if start_word_index >= self.data.len() {
45 return;
46 }
47
48 let mut end_word_index = self.word_index(to_index - 1);
49 if end_word_index >= self.data.len() {
50 to_index = self.len();
51 end_word_index = self.data.len() - 1;
52 }
53
54 let first_word_mask = u64::MAX.wrapping_shl(
55 from_index
56 .try_into()
57 .expect("from_index shouldn't be larger than u32"),
58 );
59 let last_word_mask = u64::MAX.wrapping_shr((64 - (to_index % 64)) as u32);
60 if start_word_index == end_word_index {
61 self.data[start_word_index] &= !(first_word_mask & last_word_mask);
63 } else {
64 self.data[start_word_index] &= !first_word_mask;
67
68 for i in start_word_index + 1..end_word_index {
70 self.data[i] = 0;
71 }
72
73 self.data[end_word_index] &= !last_word_mask;
75 }
76 }
77
78 fn len(&self) -> usize {
81 self.data.len() * 64
82 }
83
84 pub fn next_clear_bit(&self, from_index: usize) -> usize {
87 let mut u = self.word_index(from_index);
88 if u >= self.data.len() {
89 return from_index;
90 }
91
92 let mut word = !self.data[u] & (u64::MAX.wrapping_shl(from_index.try_into().unwrap()));
93
94 loop {
95 if word != 0 {
96 return (u * 64) + word.trailing_zeros() as usize;
97 }
98 u += 1;
99 if u == self.data.len() {
100 return self.data.len() * 64;
101 }
102 word = !self.data[u];
103 }
104 }
105
106 pub fn set(&mut self, bit_index: usize) {
107 self.data[bit_index / 64] |= 1u64 << (bit_index % 64);
108 }
109}
110
111impl From<Vec<u64>> for BitSet {
112 fn from(data: Vec<u64>) -> Self {
113 BitSet { data }
114 }
115}
116
117impl From<Vec<u8>> for BitSet {
118 fn from(data: Vec<u8>) -> Self {
119 let mut words = vec![0; data.len().div_ceil(8)];
120 for (i, byte) in data.iter().enumerate() {
121 words[i / 8] |= (*byte as u64) << ((i % 8) * 8);
122 }
123 BitSet { data: words }
124 }
125}
126
/// A bit set sized for `N` bits (rounded up to whole bytes), stored inline as a byte array.
///
/// Bit `i` lives in `data[i / 8]` at position `i % 8` (least-significant bit first).
/// The `AzaleaRead`/`AzaleaWrite` impls in this file (de)serialize exactly
/// `bits_to_bytes(N)` raw bytes, with no length prefix.
///
/// The `[u8; bits_to_bytes(N)]: Sized` bound appears to exist only to make the array length
/// expression well-formed for a generic `N`. NOTE(review): this relies on the nightly
/// `generic_const_exprs` feature, presumably enabled at the crate root — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FixedBitSet<const N: usize>
where
    [u8; bits_to_bytes(N)]: Sized,
{
    /// Raw bytes; bit `i` is bit `i % 8` of `data[i / 8]`.
    data: [u8; bits_to_bytes(N)],
}
143
144impl<const N: usize> FixedBitSet<N>
145where
146 [u8; bits_to_bytes(N)]: Sized,
147{
148 pub const fn new() -> Self {
149 FixedBitSet {
150 data: [0; bits_to_bytes(N)],
151 }
152 }
153
154 pub const fn new_with_data(data: [u8; bits_to_bytes(N)]) -> Self {
155 FixedBitSet { data }
156 }
157
158 #[inline]
159 pub fn index(&self, index: usize) -> bool {
160 (self.data[index / 8] & (1u8 << (index % 8))) != 0
161 }
162
163 #[inline]
164 pub fn set(&mut self, bit_index: usize) {
165 self.data[bit_index / 8] |= 1u8 << (bit_index % 8);
166 }
167}
168
169impl<const N: usize> AzaleaRead for FixedBitSet<N>
170where
171 [u8; bits_to_bytes(N)]: Sized,
172{
173 fn azalea_read(buf: &mut Cursor<&[u8]>) -> Result<Self, BufReadError> {
174 let mut data = [0; bits_to_bytes(N)];
175 for item in data.iter_mut().take(bits_to_bytes(N)) {
176 *item = u8::azalea_read(buf)?;
177 }
178 Ok(FixedBitSet { data })
179 }
180}
181impl<const N: usize> AzaleaWrite for FixedBitSet<N>
182where
183 [u8; bits_to_bytes(N)]: Sized,
184{
185 fn azalea_write(&self, buf: &mut impl Write) -> io::Result<()> {
186 for i in 0..bits_to_bytes(N) {
187 self.data[i].azalea_write(buf)?;
188 }
189 Ok(())
190 }
191}
192impl<const N: usize> Default for FixedBitSet<N>
193where
194 [u8; bits_to_bytes(N)]: Sized,
195{
196 fn default() -> Self {
197 Self::new()
198 }
199}
200
/// Returns how many bytes are needed to hold `n` bits, i.e. `ceil(n / 8)`.
///
/// Computed as quotient plus a remainder flag rather than `(n + 7) / 8`, so it cannot
/// overflow for `n` close to `usize::MAX`.
pub const fn bits_to_bytes(n: usize) -> usize {
    n / 8 + (n % 8 != 0) as usize
}
204
/// A bit set sized for `N` bits (rounded up to whole 64-bit words), stored inline as an
/// array of `u64`.
///
/// Bit `i` lives in `data[i / 64]` at position `i % 64`. Unlike `FixedBitSet`, no
/// `AzaleaRead`/`AzaleaWrite` impls are provided for it in this file.
///
/// NOTE(review): like `FixedBitSet`, the `[u64; bits_to_longs(N)]: Sized` bound presumably
/// relies on the nightly `generic_const_exprs` feature — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FastFixedBitSet<const N: usize>
where
    [u64; bits_to_longs(N)]: Sized,
{
    /// Backing words; bit `i` is bit `i % 64` of `data[i / 64]`.
    data: [u64; bits_to_longs(N)],
}
219impl<const N: usize> FastFixedBitSet<N>
220where
221 [u64; bits_to_longs(N)]: Sized,
222{
223 pub const fn new() -> Self {
224 FastFixedBitSet {
225 data: [0; bits_to_longs(N)],
226 }
227 }
228
229 #[inline]
230 pub fn index(&self, index: usize) -> bool {
231 (self.data[index / 64] & (1u64 << (index % 64))) != 0
232 }
233
234 #[inline]
235 pub fn set(&mut self, bit_index: usize) {
236 self.data[bit_index / 64] |= 1u64 << (bit_index % 64);
237 }
238}
239impl<const N: usize> Default for FastFixedBitSet<N>
240where
241 [u64; bits_to_longs(N)]: Sized,
242{
243 fn default() -> Self {
244 Self::new()
245 }
246}
/// Returns how many 64-bit words are needed to hold `n` bits, i.e. `ceil(n / 64)`.
///
/// Computed as quotient plus a remainder flag rather than `(n + 63) / 64`, so it cannot
/// overflow for `n` close to `usize::MAX`.
pub const fn bits_to_longs(n: usize) -> usize {
    n / 64 + (n % 64 != 0) as usize
}
250
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bitset() {
        let mut bitset = BitSet::new(64);
        assert!(!bitset.index(0));
        assert!(!bitset.index(1));
        assert!(!bitset.index(2));
        bitset.set(1);
        assert!(!bitset.index(0));
        assert!(bitset.index(1));
        assert!(!bitset.index(2));
    }

    #[test]
    fn test_clear() {
        let mut bitset = BitSet::new(128);
        bitset.set(62);
        bitset.set(63);
        bitset.set(64);
        bitset.set(65);
        bitset.set(66);

        bitset.clear(63, 65);

        assert!(bitset.index(62));
        assert!(!bitset.index(63));
        assert!(!bitset.index(64));
        assert!(bitset.index(65));
        assert!(bitset.index(66));
    }

    #[test]
    fn test_clear_2() {
        let mut bitset = BitSet::new(128);
        bitset.set(64);
        bitset.set(65);
        bitset.set(66);
        bitset.set(67);
        bitset.set(68);

        bitset.clear(65, 67);

        assert!(bitset.index(64));
        assert!(!bitset.index(65));
        assert!(!bitset.index(66));
        assert!(bitset.index(67));
        assert!(bitset.index(68));
    }

    #[test]
    fn test_clear_spanning_multiple_words() {
        let mut bitset = BitSet::new(192);
        for i in 0..192 {
            bitset.set(i);
        }

        // Partial first word, whole middle word, partial last word.
        bitset.clear(10, 150);

        for i in 0..192 {
            assert_eq!(bitset.index(i), !(10..150).contains(&i), "bit {i}");
        }
    }

    #[test]
    fn test_clear_clamps_to_capacity() {
        let mut bitset = BitSet::new(128);
        for i in 0..128 {
            bitset.set(i);
        }

        // A range starting past the end is ignored, and an empty range is a no-op.
        bitset.clear(200, 300);
        bitset.clear(5, 5);
        for i in 0..128 {
            assert!(bitset.index(i), "bit {i}");
        }

        // A range running past the end is clamped to the capacity.
        bitset.clear(100, 1000);
        for i in 0..128 {
            assert_eq!(bitset.index(i), i < 100, "bit {i}");
        }
    }

    #[test]
    #[should_panic]
    fn test_clear_rejects_inverted_range() {
        let mut bitset = BitSet::new(64);
        bitset.clear(10, 5);
    }

    #[test]
    fn test_next_clear_bit() {
        let mut bitset = BitSet::new(128);
        assert_eq!(bitset.next_clear_bit(0), 0);

        for i in 0..3 {
            bitset.set(i);
        }
        assert_eq!(bitset.next_clear_bit(0), 3);
        assert_eq!(bitset.next_clear_bit(1), 3);
        assert_eq!(bitset.next_clear_bit(5), 5);

        // The search continues into the next word when the first is full.
        for i in 0..64 {
            bitset.set(i);
        }
        assert_eq!(bitset.next_clear_bit(10), 64);

        // A completely full set reports its capacity.
        for i in 64..128 {
            bitset.set(i);
        }
        assert_eq!(bitset.next_clear_bit(0), 128);

        // Past the end, the starting index is returned unchanged.
        assert_eq!(bitset.next_clear_bit(300), 300);
    }

    #[test]
    fn test_from_bytes_is_little_endian() {
        let bitset = BitSet::from(vec![0x01u8, 0, 0, 0, 0, 0, 0, 0x80, 0x03]);
        assert!(bitset.index(0));
        assert!(!bitset.index(1));
        assert!(bitset.index(63));
        assert!(bitset.index(64));
        assert!(bitset.index(65));
        assert!(!bitset.index(66));
        assert_eq!(bitset, BitSet::from(vec![0x8000_0000_0000_0001u64, 0x03]));
    }

    #[test]
    fn test_fixed_bitsets() {
        let mut fixed = FixedBitSet::<20>::new();
        assert!(!fixed.index(19));
        fixed.set(19);
        assert!(fixed.index(19));
        assert!(!fixed.index(18));

        let mut fast = FastFixedBitSet::<100>::new();
        assert!(!fast.index(99));
        fast.set(99);
        assert!(fast.index(99));
        assert!(!fast.index(98));
    }

    #[test]
    fn test_storage_size_helpers() {
        assert_eq!(bits_to_bytes(0), 0);
        assert_eq!(bits_to_bytes(1), 1);
        assert_eq!(bits_to_bytes(8), 1);
        assert_eq!(bits_to_bytes(9), 2);
        assert_eq!(bits_to_longs(0), 0);
        assert_eq!(bits_to_longs(64), 1);
        assert_eq!(bits_to_longs(65), 2);
    }
}