Skip to main content

azalea_protocol_macros/
lib.rs

1mod utils;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{
6    DeriveInput, Expr, Ident, Token, bracketed,
7    parse::{Parse, ParseStream, Result},
8    parse_macro_input,
9};
10
11use crate::utils::{to_camel_case, to_snake_case};
12
13fn as_packet_derive(input: TokenStream, state: proc_macro2::TokenStream) -> TokenStream {
14    let DeriveInput { ident, data, .. } = parse_macro_input!(input);
15
16    // technically it would still work with enums and non-named structs but for
17    // consistency in the api it's nicer if they are all just structs, which is why
18    // we enforce this here
19    let syn::Data::Struct(syn::DataStruct { fields, .. }) = &data else {
20        panic!("#[derive(*Packet)] can only be used on structs")
21    };
22    let (syn::Fields::Named(_) | syn::Fields::Unit) = fields else {
23        panic!("#[derive(*Packet)] can only be used on structs with named fields")
24    };
25
26    let variant_name = variant_name_from(&ident);
27
28    let contents = quote! {
29        impl #ident {
30            pub fn write(&self, buf: &mut impl std::io::Write) -> std::io::Result<()> {
31                azalea_buf::AzBuf::azalea_write(self, buf)
32            }
33
34            pub fn read(
35                buf: &mut std::io::Cursor<&[u8]>,
36            ) -> Result<#state, azalea_buf::BufReadError> {
37                use azalea_buf::AzBuf;
38                Ok(crate::packets::Packet::into_variant(Self::azalea_read(buf)?))
39            }
40
41        }
42
43        impl crate::packets::Packet<#state> for #ident {
44            fn into_variant(self) -> #state {
45                #state::#variant_name(self)
46            }
47        }
48    };
49
50    contents.into()
51}
52
53#[proc_macro_derive(ServerboundGamePacket, attributes(var))]
54pub fn derive_s_game_packet(input: TokenStream) -> TokenStream {
55    as_packet_derive(input, quote! {crate::packets::game::ServerboundGamePacket})
56}
57#[proc_macro_derive(ServerboundHandshakePacket, attributes(var))]
58pub fn derive_s_handshake_packet(input: TokenStream) -> TokenStream {
59    as_packet_derive(
60        input,
61        quote! {crate::packets::handshake::ServerboundHandshakePacket},
62    )
63}
64#[proc_macro_derive(ServerboundLoginPacket, attributes(var))]
65pub fn derive_s_login_packet(input: TokenStream) -> TokenStream {
66    as_packet_derive(
67        input,
68        quote! {crate::packets::login::ServerboundLoginPacket},
69    )
70}
71#[proc_macro_derive(ServerboundStatusPacket, attributes(var))]
72pub fn derive_s_status_packet(input: TokenStream) -> TokenStream {
73    as_packet_derive(
74        input,
75        quote! {crate::packets::status::ServerboundStatusPacket},
76    )
77}
78#[proc_macro_derive(ServerboundConfigPacket, attributes(var))]
79pub fn derive_s_config_packet(input: TokenStream) -> TokenStream {
80    as_packet_derive(
81        input,
82        quote! {crate::packets::config::ServerboundConfigPacket},
83    )
84}
85
86#[proc_macro_derive(ClientboundGamePacket, attributes(var))]
87pub fn derive_c_game_packet(input: TokenStream) -> TokenStream {
88    as_packet_derive(input, quote! {crate::packets::game::ClientboundGamePacket})
89}
90#[proc_macro_derive(ClientboundHandshakePacket, attributes(var))]
91pub fn derive_c_handshake_packet(input: TokenStream) -> TokenStream {
92    as_packet_derive(
93        input,
94        quote! {crate::packets::handshake::ClientboundHandshakePacket},
95    )
96}
97#[proc_macro_derive(ClientboundLoginPacket, attributes(var))]
98pub fn derive_c_login_packet(input: TokenStream) -> TokenStream {
99    as_packet_derive(
100        input,
101        quote! {crate::packets::login::ClientboundLoginPacket},
102    )
103}
104#[proc_macro_derive(ClientboundStatusPacket, attributes(var))]
105pub fn derive_c_status_packet(input: TokenStream) -> TokenStream {
106    as_packet_derive(
107        input,
108        quote! {crate::packets::status::ClientboundStatusPacket},
109    )
110}
111#[proc_macro_derive(ClientboundConfigPacket, attributes(var))]
112pub fn derive_c_config_packet(input: TokenStream) -> TokenStream {
113    as_packet_derive(
114        input,
115        quote! {crate::packets::config::ClientboundConfigPacket},
116    )
117}
118
119#[derive(Debug)]
120struct PacketList {
121    packets: Vec<Ident>,
122}
123
124impl Parse for PacketList {
125    fn parse(input: ParseStream) -> Result<Self> {
126        let mut packets = vec![];
127
128        // example:
129        // change_difficulty,
130        // keep_alive,
131        while let Ok(packet_name) = input.parse::<Ident>() {
132            packets.push(packet_name);
133            if input.parse::<Token![,]>().is_err() {
134                break;
135            }
136        }
137
138        Ok(PacketList { packets })
139    }
140}
141
142#[derive(Debug)]
143struct DeclareStatePackets {
144    name: Ident,
145    clientbound: PacketList,
146    serverbound: PacketList,
147}
148
149impl Parse for DeclareStatePackets {
150    fn parse(input: ParseStream) -> Result<Self> {
151        let name = input.parse()?;
152        input.parse::<Token![,]>()?;
153
154        let clientbound_token: Ident = input.parse()?;
155        if clientbound_token != "Clientbound" {
156            return Err(syn::Error::new(
157                clientbound_token.span(),
158                "Expected `Clientbound`",
159            ));
160        }
161        input.parse::<Token![=>]>()?;
162        let content;
163        bracketed!(content in input);
164        let clientbound = content.parse()?;
165
166        input.parse::<Token![,]>()?;
167
168        let serverbound_token: Ident = input.parse()?;
169        if serverbound_token != "Serverbound" {
170            return Err(syn::Error::new(
171                serverbound_token.span(),
172                "Expected `Serverbound`",
173            ));
174        }
175        input.parse::<Token![=>]>()?;
176        let content;
177        bracketed!(content in input);
178        let serverbound = content.parse()?;
179
180        Ok(DeclareStatePackets {
181            name,
182            serverbound,
183            clientbound,
184        })
185    }
186}
187#[proc_macro]
188pub fn declare_state_packets(input: TokenStream) -> TokenStream {
189    let input = parse_macro_input!(input as DeclareStatePackets);
190
191    let clientbound_state_name =
192        Ident::new(&format!("Clientbound{}", input.name), input.name.span());
193    let serverbound_state_name =
194        Ident::new(&format!("Serverbound{}", input.name), input.name.span());
195
196    let state_name_litstr = syn::LitStr::new(&input.name.to_string(), input.name.span());
197
198    let has_clientbound_packets = !input.clientbound.packets.is_empty();
199    let has_serverbound_packets = !input.serverbound.packets.is_empty();
200
201    let mut mod_and_use_statements_contents = quote! {};
202    let mut clientbound_enum_contents = quote! {};
203    let mut serverbound_enum_contents = quote! {};
204    let mut clientbound_id_match_contents = quote! {};
205    let mut serverbound_id_match_contents = quote! {};
206    let mut clientbound_name_match_contents = quote! {};
207    let mut serverbound_name_match_contents = quote! {};
208    let mut clientbound_write_match_contents = quote! {};
209    let mut serverbound_write_match_contents = quote! {};
210    let mut clientbound_read_match_contents = quote! {};
211    let mut serverbound_read_match_contents = quote! {};
212
213    for (id, packet_name) in input.clientbound.packets.iter().enumerate() {
214        let id = id as u32;
215
216        let struct_name = packet_name_to_struct_name(packet_name, "clientbound");
217        let module_name = packet_name_to_module_name(packet_name, "clientbound");
218        let variant_name = packet_name_to_variant_name(packet_name);
219        let packet_name_litstr = syn::LitStr::new(&packet_name.to_string(), packet_name.span());
220
221        mod_and_use_statements_contents.extend(quote! {
222            pub mod #module_name;
223            pub use #module_name::#struct_name;
224        });
225
226        clientbound_enum_contents.extend(quote! {
227            #variant_name(#module_name::#struct_name),
228        });
229        clientbound_id_match_contents.extend(quote! {
230            #clientbound_state_name::#variant_name(..) => #id,
231        });
232        clientbound_name_match_contents.extend(quote! {
233            #clientbound_state_name::#variant_name(..) => #packet_name_litstr,
234        });
235        clientbound_write_match_contents.extend(quote! {
236            #clientbound_state_name::#variant_name(packet) => packet.write(buf),
237        });
238        clientbound_read_match_contents.extend(quote! {
239            #id => {
240                let data = #module_name::#struct_name::read(buf).map_err(|e| crate::read::ReadPacketError::Parse {
241                    source: e,
242                    packet_id: #id,
243                    backtrace: Box::new(std::backtrace::Backtrace::capture()),
244                    packet_name: #packet_name_litstr.to_string(),
245                })?;
246                #[cfg(debug_assertions)]
247                {
248                    let mut leftover = Vec::new();
249                    let _ = std::io::Read::read_to_end(buf, &mut leftover);
250                    if !leftover.is_empty() {
251                        return Err(
252                            Box::new(
253                                crate::read::ReadPacketError::LeftoverData {
254                                    packet_name: #packet_name_litstr.to_string(),
255                                    data: leftover
256                                }
257                            )
258                        );
259                    }
260                }
261                data
262            },
263        });
264    }
265    for (id, packet_name) in input.serverbound.packets.iter().enumerate() {
266        let id = id as u32;
267
268        let struct_name = packet_name_to_struct_name(packet_name, "serverbound");
269        let module_name = packet_name_to_module_name(packet_name, "serverbound");
270        let variant_name = packet_name_to_variant_name(packet_name);
271        let packet_name_litstr = syn::LitStr::new(&packet_name.to_string(), packet_name.span());
272
273        mod_and_use_statements_contents.extend(quote! {
274            pub mod #module_name;
275            pub use #module_name::#struct_name;
276        });
277
278        serverbound_enum_contents.extend(quote! {
279            #variant_name(#module_name::#struct_name),
280        });
281        serverbound_id_match_contents.extend(quote! {
282            #serverbound_state_name::#variant_name(..) => #id,
283        });
284        serverbound_name_match_contents.extend(quote! {
285            #serverbound_state_name::#variant_name(..) => #packet_name_litstr,
286        });
287        serverbound_write_match_contents.extend(quote! {
288            #serverbound_state_name::#variant_name(packet) => packet.write(buf),
289        });
290        serverbound_read_match_contents.extend(quote! {
291            #id => {
292                let data = #module_name::#struct_name::read(buf).map_err(|e| crate::read::ReadPacketError::Parse {
293                    source: e,
294                    packet_id: #id,
295                    backtrace: Box::new(std::backtrace::Backtrace::capture()),
296                    packet_name: #packet_name_litstr.to_string(),
297                })?;
298                #[cfg(debug_assertions)]
299                {
300                    let mut leftover = Vec::new();
301                    let _ = std::io::Read::read_to_end(buf, &mut leftover);
302                    if !leftover.is_empty() {
303                        return Err(Box::new(crate::read::ReadPacketError::LeftoverData { packet_name: #packet_name_litstr.to_string(), data: leftover }));
304                    }
305                }
306                data
307            },
308        });
309    }
310
311    if !has_serverbound_packets {
312        serverbound_id_match_contents.extend(quote! {
313            _ => unreachable!("This enum is empty and can't exist.")
314        });
315        serverbound_name_match_contents.extend(quote! {
316            _ => unreachable!("This enum is empty and can't exist.")
317        });
318        serverbound_write_match_contents.extend(quote! {
319            _ => unreachable!("This enum is empty and can't exist.")
320        });
321    }
322    if !has_clientbound_packets {
323        clientbound_id_match_contents.extend(quote! {
324            _ => unreachable!("This enum is empty and can't exist.")
325        });
326        clientbound_name_match_contents.extend(quote! {
327            _ => unreachable!("This enum is empty and can't exist.")
328        });
329        clientbound_write_match_contents.extend(quote! {
330            _ => unreachable!("This enum is empty and can't exist.")
331        });
332    }
333
334    let mut contents = quote! {
335        #mod_and_use_statements_contents
336
337        #[derive(Clone, Debug, PartialEq)]
338        pub enum #clientbound_state_name
339        where
340            Self: Sized,
341        {
342            #clientbound_enum_contents
343        }
344        #[derive(Clone, Debug, PartialEq)]
345        pub enum #serverbound_state_name
346        where
347        Self: Sized,
348        {
349            #serverbound_enum_contents
350        }
351    };
352
353    contents.extend(quote! {
354        #[allow(unreachable_code)]
355        impl crate::packets::ProtocolPacket for #serverbound_state_name {
356            fn id(&self) -> u32 {
357                match self {
358                    #serverbound_id_match_contents
359                }
360            }
361
362            fn name(&self) -> &'static str {
363                match self {
364                    #serverbound_name_match_contents
365                }
366            }
367
368            fn write(&self, buf: &mut impl std::io::Write) -> std::io::Result<()> {
369                match self {
370                    #serverbound_write_match_contents
371                }
372            }
373
374            /// Read a packet by its ID, `ConnectionProtocol`, and flow.
375            fn read(
376                id: u32,
377                buf: &mut std::io::Cursor<&[u8]>,
378            ) -> Result<#serverbound_state_name, Box<crate::read::ReadPacketError>>
379            where
380                Self: Sized,
381            {
382                Ok(match id {
383                    #serverbound_read_match_contents
384                    _ => return Err(Box::new(crate::read::ReadPacketError::UnknownPacketId { state_name: #state_name_litstr.to_string(), id })),
385                })
386            }
387        }
388
389        impl crate::packets::Packet<#serverbound_state_name> for #serverbound_state_name {
390            /// No-op, exists so you can pass a packet enum when a Packet<> is expected.
391            fn into_variant(self) -> #serverbound_state_name {
392                self
393            }
394        }
395    });
396
397    contents.extend(quote! {
398        #[allow(unreachable_code)]
399        impl crate::packets::ProtocolPacket for #clientbound_state_name {
400            fn id(&self) -> u32 {
401                match self {
402                    #clientbound_id_match_contents
403                }
404            }
405
406            fn name(&self) -> &'static str {
407                match self {
408                    #clientbound_name_match_contents
409                }
410            }
411
412            fn write(&self, buf: &mut impl std::io::Write) -> std::io::Result<()> {
413                match self {
414                    #clientbound_write_match_contents
415                }
416            }
417
418            /// Read a packet by its ID, `ConnectionProtocol`, and flow.
419            fn read(
420                id: u32,
421                buf: &mut std::io::Cursor<&[u8]>,
422            ) -> Result<#clientbound_state_name, Box<crate::read::ReadPacketError>>
423            where
424                Self: Sized,
425            {
426                Ok(match id {
427                    #clientbound_read_match_contents
428                    _ => return Err(Box::new(crate::read::ReadPacketError::UnknownPacketId { state_name: #state_name_litstr.to_string(), id })),
429                })
430            }
431        }
432
433        impl crate::packets::Packet<#clientbound_state_name> for #clientbound_state_name {
434            /// No-op, exists so you can pass a packet enum when a Packet<> is expected.
435            fn into_variant(self) -> #clientbound_state_name {
436                self
437            }
438        }
439    });
440
441    contents.into()
442}
443
444fn variant_name_from(name: &syn::Ident) -> syn::Ident {
445    // remove "<direction>Bound" from the start and "Packet" from the end
446    let mut variant_name = name.to_string();
447    if variant_name.starts_with("Clientbound") {
448        variant_name = variant_name["Clientbound".len()..].to_string();
449    } else if variant_name.starts_with("Serverbound") {
450        variant_name = variant_name["Serverbound".len()..].to_string();
451    }
452    syn::Ident::new(&variant_name, name.span())
453}
454
455fn packet_name_to_struct_name(name: &syn::Ident, direction: &str) -> syn::Ident {
456    let struct_name_snake = format!("{direction}_{name}");
457    let struct_name = to_camel_case(&struct_name_snake);
458    syn::Ident::new(&struct_name, name.span())
459}
460fn packet_name_to_module_name(name: &syn::Ident, direction: &str) -> syn::Ident {
461    let module_name_snake = format!("{}_{name}", direction.chars().next().unwrap());
462    let module_name = to_snake_case(&module_name_snake);
463    syn::Ident::new(&module_name, name.span())
464}
465fn packet_name_to_variant_name(name: &syn::Ident) -> syn::Ident {
466    let variant_name = to_camel_case(&name.to_string());
467    syn::Ident::new(&variant_name, name.span())
468}
469
470struct DeclarePacketHandlers {
471    packet_enum: Ident,
472    packet_var: Expr,
473    handler: Ident,
474    packets: Box<[syn::Path]>,
475}
476impl Parse for DeclarePacketHandlers {
477    fn parse(input: ParseStream) -> Result<Self> {
478        let packet_enum = input.parse()?;
479        input.parse::<Token![,]>()?;
480        let packet_var = input.parse()?;
481        input.parse::<Token![,]>()?;
482        let handler = input.parse()?;
483        input.parse::<Token![,]>()?;
484
485        let packets_inner;
486        bracketed!(packets_inner in input);
487        let packets = packets_inner
488            .parse_terminated(syn::Path::parse, Token![,])?
489            .into_iter()
490            .collect::<Box<[_]>>();
491
492        Ok(Self {
493            packet_enum,
494            packet_var,
495            handler,
496            packets,
497        })
498    }
499}
500#[proc_macro]
501pub fn declare_packet_handlers(input: TokenStream) -> TokenStream {
502    let DeclarePacketHandlers {
503        packet_enum,
504        packet_var,
505        handler,
506        packets,
507    } = parse_macro_input!(input as DeclarePacketHandlers);
508    // needs to be a proc macro to be able to convert snake_case to camelCase
509
510    let mut inner = quote! {};
511    for packet in packets {
512        let mut packet_camel = packet.clone();
513        let packet_last_ident = &mut packet_camel.segments.last_mut().unwrap().ident;
514        *packet_last_ident = Ident::new(
515            &to_camel_case(&packet_last_ident.to_string()),
516            packet_last_ident.span(),
517        );
518
519        inner.extend(quote! {
520            #packet_enum::#packet_camel(p) => #handler.#packet(p),
521        });
522    }
523
524    quote! {
525        match #packet_var {
526            #inner
527        }
528    }
529    .into()
530}