azalea_protocol_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    DeriveInput, Ident, Token, bracketed,
5    parse::{Parse, ParseStream, Result},
6    parse_macro_input,
7};
8
9fn as_packet_derive(input: TokenStream, state: proc_macro2::TokenStream) -> TokenStream {
10    let DeriveInput { ident, data, .. } = parse_macro_input!(input);
11
12    // technically it would still work with enums and non-named structs but for
13    // consistency in the api it's nicer if they are all just structs, which is why
14    // we enforce this here
15    let syn::Data::Struct(syn::DataStruct { fields, .. }) = &data else {
16        panic!("#[derive(*Packet)] can only be used on structs")
17    };
18    let (syn::Fields::Named(_) | syn::Fields::Unit) = fields else {
19        panic!("#[derive(*Packet)] can only be used on structs with named fields")
20    };
21
22    let variant_name = variant_name_from(&ident);
23
24    let contents = quote! {
25        impl #ident {
26            pub fn write(&self, buf: &mut impl std::io::Write) -> Result<(), std::io::Error> {
27                azalea_buf::AzaleaWrite::azalea_write(self, buf)
28            }
29
30            pub fn read(
31                buf: &mut std::io::Cursor<&[u8]>,
32            ) -> Result<#state, azalea_buf::BufReadError> {
33                use azalea_buf::AzaleaRead;
34                Ok(crate::packets::Packet::into_variant(Self::azalea_read(buf)?))
35            }
36
37        }
38
39        impl crate::packets::Packet<#state> for #ident {
40            fn into_variant(self) -> #state {
41                #state::#variant_name(self)
42            }
43        }
44    };
45
46    contents.into()
47}
48
49#[proc_macro_derive(ServerboundGamePacket, attributes(var))]
50pub fn derive_s_game_packet(input: TokenStream) -> TokenStream {
51    as_packet_derive(input, quote! {crate::packets::game::ServerboundGamePacket})
52}
53#[proc_macro_derive(ServerboundHandshakePacket, attributes(var))]
54pub fn derive_s_handshake_packet(input: TokenStream) -> TokenStream {
55    as_packet_derive(
56        input,
57        quote! {crate::packets::handshake::ServerboundHandshakePacket},
58    )
59}
60#[proc_macro_derive(ServerboundLoginPacket, attributes(var))]
61pub fn derive_s_login_packet(input: TokenStream) -> TokenStream {
62    as_packet_derive(
63        input,
64        quote! {crate::packets::login::ServerboundLoginPacket},
65    )
66}
67#[proc_macro_derive(ServerboundStatusPacket, attributes(var))]
68pub fn derive_s_status_packet(input: TokenStream) -> TokenStream {
69    as_packet_derive(
70        input,
71        quote! {crate::packets::status::ServerboundStatusPacket},
72    )
73}
74#[proc_macro_derive(ServerboundConfigPacket, attributes(var))]
75pub fn derive_s_config_packet(input: TokenStream) -> TokenStream {
76    as_packet_derive(
77        input,
78        quote! {crate::packets::config::ServerboundConfigPacket},
79    )
80}
81
82#[proc_macro_derive(ClientboundGamePacket, attributes(var))]
83pub fn derive_c_game_packet(input: TokenStream) -> TokenStream {
84    as_packet_derive(input, quote! {crate::packets::game::ClientboundGamePacket})
85}
86#[proc_macro_derive(ClientboundHandshakePacket, attributes(var))]
87pub fn derive_c_handshake_packet(input: TokenStream) -> TokenStream {
88    as_packet_derive(
89        input,
90        quote! {crate::packets::handshake::ClientboundHandshakePacket},
91    )
92}
93#[proc_macro_derive(ClientboundLoginPacket, attributes(var))]
94pub fn derive_c_login_packet(input: TokenStream) -> TokenStream {
95    as_packet_derive(
96        input,
97        quote! {crate::packets::login::ClientboundLoginPacket},
98    )
99}
100#[proc_macro_derive(ClientboundStatusPacket, attributes(var))]
101pub fn derive_c_status_packet(input: TokenStream) -> TokenStream {
102    as_packet_derive(
103        input,
104        quote! {crate::packets::status::ClientboundStatusPacket},
105    )
106}
107#[proc_macro_derive(ClientboundConfigPacket, attributes(var))]
108pub fn derive_c_config_packet(input: TokenStream) -> TokenStream {
109    as_packet_derive(
110        input,
111        quote! {crate::packets::config::ClientboundConfigPacket},
112    )
113}
114
115#[derive(Debug)]
116struct PacketList {
117    packets: Vec<Ident>,
118}
119
120impl Parse for PacketList {
121    fn parse(input: ParseStream) -> Result<Self> {
122        let mut packets = vec![];
123
124        // example:
125        // change_difficulty,
126        // keep_alive,
127        while let Ok(packet_name) = input.parse::<Ident>() {
128            packets.push(packet_name);
129            if input.parse::<Token![,]>().is_err() {
130                break;
131            }
132        }
133
134        Ok(PacketList { packets })
135    }
136}
137
138#[derive(Debug)]
139struct DeclareStatePackets {
140    name: Ident,
141    clientbound: PacketList,
142    serverbound: PacketList,
143}
144
145impl Parse for DeclareStatePackets {
146    fn parse(input: ParseStream) -> Result<Self> {
147        let name = input.parse()?;
148        input.parse::<Token![,]>()?;
149
150        let clientbound_token: Ident = input.parse()?;
151        if clientbound_token != "Clientbound" {
152            return Err(syn::Error::new(
153                clientbound_token.span(),
154                "Expected `Clientbound`",
155            ));
156        }
157        input.parse::<Token![=>]>()?;
158        let content;
159        bracketed!(content in input);
160        let clientbound = content.parse()?;
161
162        input.parse::<Token![,]>()?;
163
164        let serverbound_token: Ident = input.parse()?;
165        if serverbound_token != "Serverbound" {
166            return Err(syn::Error::new(
167                serverbound_token.span(),
168                "Expected `Serverbound`",
169            ));
170        }
171        input.parse::<Token![=>]>()?;
172        let content;
173        bracketed!(content in input);
174        let serverbound = content.parse()?;
175
176        Ok(DeclareStatePackets {
177            name,
178            serverbound,
179            clientbound,
180        })
181    }
182}
183#[proc_macro]
184pub fn declare_state_packets(input: TokenStream) -> TokenStream {
185    let input = parse_macro_input!(input as DeclareStatePackets);
186
187    let clientbound_state_name =
188        Ident::new(&format!("Clientbound{}", input.name), input.name.span());
189    let serverbound_state_name =
190        Ident::new(&format!("Serverbound{}", input.name), input.name.span());
191
192    let state_name_litstr = syn::LitStr::new(&input.name.to_string(), input.name.span());
193
194    let has_clientbound_packets = !input.clientbound.packets.is_empty();
195    let has_serverbound_packets = !input.serverbound.packets.is_empty();
196
197    let mut mod_and_use_statements_contents = quote!();
198    let mut clientbound_enum_contents = quote!();
199    let mut serverbound_enum_contents = quote!();
200    let mut clientbound_id_match_contents = quote!();
201    let mut serverbound_id_match_contents = quote!();
202    let mut clientbound_name_match_contents = quote!();
203    let mut serverbound_name_match_contents = quote!();
204    let mut clientbound_write_match_contents = quote!();
205    let mut serverbound_write_match_contents = quote!();
206    let mut clientbound_read_match_contents = quote!();
207    let mut serverbound_read_match_contents = quote!();
208
209    for (id, packet_name) in input.clientbound.packets.iter().enumerate() {
210        let id = id as u32;
211
212        let struct_name = packet_name_to_struct_name(packet_name, "clientbound");
213        let module_name = packet_name_to_module_name(packet_name, "clientbound");
214        let variant_name = packet_name_to_variant_name(packet_name);
215        let packet_name_litstr = syn::LitStr::new(&packet_name.to_string(), packet_name.span());
216
217        mod_and_use_statements_contents.extend(quote! {
218            pub mod #module_name;
219            pub use #module_name::#struct_name;
220        });
221
222        clientbound_enum_contents.extend(quote! {
223            #variant_name(#module_name::#struct_name),
224        });
225        clientbound_id_match_contents.extend(quote! {
226            #clientbound_state_name::#variant_name(..) => #id,
227        });
228        clientbound_name_match_contents.extend(quote! {
229            #clientbound_state_name::#variant_name(..) => #packet_name_litstr,
230        });
231        clientbound_write_match_contents.extend(quote! {
232            #clientbound_state_name::#variant_name(packet) => packet.write(buf),
233        });
234        clientbound_read_match_contents.extend(quote! {
235            #id => {
236                let data = #module_name::#struct_name::read(buf).map_err(|e| crate::read::ReadPacketError::Parse {
237                    source: e,
238                    packet_id: #id,
239                    backtrace: Box::new(std::backtrace::Backtrace::capture()),
240                    packet_name: #packet_name_litstr.to_string(),
241                })?;
242                #[cfg(debug_assertions)]
243                {
244                    let mut leftover = Vec::new();
245                    let _ = std::io::Read::read_to_end(buf, &mut leftover);
246                    if !leftover.is_empty() {
247                        return Err(
248                            Box::new(
249                                crate::read::ReadPacketError::LeftoverData {
250                                    packet_name: #packet_name_litstr.to_string(),
251                                    data: leftover
252                                }
253                            )
254                        );
255                    }
256                }
257                data
258            },
259        });
260    }
261    for (id, packet_name) in input.serverbound.packets.iter().enumerate() {
262        let id = id as u32;
263
264        let struct_name = packet_name_to_struct_name(packet_name, "serverbound");
265        let module_name = packet_name_to_module_name(packet_name, "serverbound");
266        let variant_name = packet_name_to_variant_name(packet_name);
267        let packet_name_litstr = syn::LitStr::new(&packet_name.to_string(), packet_name.span());
268
269        mod_and_use_statements_contents.extend(quote! {
270            pub mod #module_name;
271            pub use #module_name::#struct_name;
272        });
273
274        serverbound_enum_contents.extend(quote! {
275            #variant_name(#module_name::#struct_name),
276        });
277        serverbound_id_match_contents.extend(quote! {
278            #serverbound_state_name::#variant_name(..) => #id,
279        });
280        serverbound_name_match_contents.extend(quote! {
281            #serverbound_state_name::#variant_name(..) => #packet_name_litstr,
282        });
283        serverbound_write_match_contents.extend(quote! {
284            #serverbound_state_name::#variant_name(packet) => packet.write(buf),
285        });
286        serverbound_read_match_contents.extend(quote! {
287            #id => {
288                let data = #module_name::#struct_name::read(buf).map_err(|e| crate::read::ReadPacketError::Parse {
289                    source: e,
290                    packet_id: #id,
291                    backtrace: Box::new(std::backtrace::Backtrace::capture()),
292                    packet_name: #packet_name_litstr.to_string(),
293                })?;
294                #[cfg(debug_assertions)]
295                {
296                    let mut leftover = Vec::new();
297                    let _ = std::io::Read::read_to_end(buf, &mut leftover);
298                    if !leftover.is_empty() {
299                        return Err(Box::new(crate::read::ReadPacketError::LeftoverData { packet_name: #packet_name_litstr.to_string(), data: leftover }));
300                    }
301                }
302                data
303            },
304        });
305    }
306
307    if !has_serverbound_packets {
308        serverbound_id_match_contents.extend(quote! {
309            _ => unreachable!("This enum is empty and can't exist.")
310        });
311        serverbound_name_match_contents.extend(quote! {
312            _ => unreachable!("This enum is empty and can't exist.")
313        });
314        serverbound_write_match_contents.extend(quote! {
315            _ => unreachable!("This enum is empty and can't exist.")
316        });
317    }
318    if !has_clientbound_packets {
319        clientbound_id_match_contents.extend(quote! {
320            _ => unreachable!("This enum is empty and can't exist.")
321        });
322        clientbound_name_match_contents.extend(quote! {
323            _ => unreachable!("This enum is empty and can't exist.")
324        });
325        clientbound_write_match_contents.extend(quote! {
326            _ => unreachable!("This enum is empty and can't exist.")
327        });
328    }
329
330    let mut contents = quote! {
331        #mod_and_use_statements_contents
332
333        #[derive(Clone, Debug)]
334        pub enum #clientbound_state_name
335        where
336            Self: Sized,
337        {
338            #clientbound_enum_contents
339        }
340        #[derive(Clone, Debug)]
341        pub enum #serverbound_state_name
342        where
343        Self: Sized,
344        {
345            #serverbound_enum_contents
346        }
347    };
348
349    contents.extend(quote! {
350        #[allow(unreachable_code)]
351        impl crate::packets::ProtocolPacket for #serverbound_state_name {
352            fn id(&self) -> u32 {
353                match self {
354                    #serverbound_id_match_contents
355                }
356            }
357
358            fn name(&self) -> &'static str {
359                match self {
360                    #serverbound_name_match_contents
361                }
362            }
363
364            fn write(&self, buf: &mut impl std::io::Write) -> Result<(), std::io::Error> {
365                match self {
366                    #serverbound_write_match_contents
367                }
368            }
369
370            /// Read a packet by its id, ConnectionProtocol, and flow.
371            fn read(
372                id: u32,
373                buf: &mut std::io::Cursor<&[u8]>,
374            ) -> Result<#serverbound_state_name, Box<crate::read::ReadPacketError>>
375            where
376                Self: Sized,
377            {
378                Ok(match id {
379                    #serverbound_read_match_contents
380                    _ => return Err(Box::new(crate::read::ReadPacketError::UnknownPacketId { state_name: #state_name_litstr.to_string(), id })),
381                })
382            }
383        }
384
385        impl crate::packets::Packet<#serverbound_state_name> for #serverbound_state_name {
386            /// No-op, exists so you can pass a packet enum when a Packet<> is expected.
387            fn into_variant(self) -> #serverbound_state_name {
388                self
389            }
390        }
391    });
392
393    contents.extend(quote! {
394        #[allow(unreachable_code)]
395        impl crate::packets::ProtocolPacket for #clientbound_state_name {
396            fn id(&self) -> u32 {
397                match self {
398                    #clientbound_id_match_contents
399                }
400            }
401
402            fn name(&self) -> &'static str {
403                match self {
404                    #clientbound_name_match_contents
405                }
406            }
407
408            fn write(&self, buf: &mut impl std::io::Write) -> Result<(), std::io::Error> {
409                match self {
410                    #clientbound_write_match_contents
411                }
412            }
413
414            /// Read a packet by its id, ConnectionProtocol, and flow.
415            fn read(
416                id: u32,
417                buf: &mut std::io::Cursor<&[u8]>,
418            ) -> Result<#clientbound_state_name, Box<crate::read::ReadPacketError>>
419            where
420                Self: Sized,
421            {
422                Ok(match id {
423                    #clientbound_read_match_contents
424                    _ => return Err(Box::new(crate::read::ReadPacketError::UnknownPacketId { state_name: #state_name_litstr.to_string(), id })),
425                })
426            }
427        }
428
429        impl crate::packets::Packet<#clientbound_state_name> for #clientbound_state_name {
430            /// No-op, exists so you can pass a packet enum when a Packet<> is expected.
431            fn into_variant(self) -> #clientbound_state_name {
432                self
433            }
434        }
435    });
436
437    contents.into()
438}
439
440fn variant_name_from(name: &syn::Ident) -> syn::Ident {
441    // remove "<direction>Bound" from the start and "Packet" from the end
442    let mut variant_name = name.to_string();
443    if variant_name.starts_with("Clientbound") {
444        variant_name = variant_name["Clientbound".len()..].to_string();
445    } else if variant_name.starts_with("Serverbound") {
446        variant_name = variant_name["Serverbound".len()..].to_string();
447    }
448    syn::Ident::new(&variant_name, name.span())
449}
450
451fn packet_name_to_struct_name(name: &syn::Ident, direction: &str) -> syn::Ident {
452    let struct_name_snake = format!("{direction}_{name}");
453    let struct_name = to_camel_case(&struct_name_snake);
454    syn::Ident::new(&struct_name, name.span())
455}
456fn packet_name_to_module_name(name: &syn::Ident, direction: &str) -> syn::Ident {
457    let module_name_snake = format!("{}_{name}", direction.chars().next().unwrap());
458    let module_name = to_snake_case(&module_name_snake);
459    syn::Ident::new(&module_name, name.span())
460}
461fn packet_name_to_variant_name(name: &syn::Ident) -> syn::Ident {
462    let variant_name = to_camel_case(&name.to_string());
463    syn::Ident::new(&variant_name, name.span())
464}
465
/// Converts a snake_case string to UpperCamelCase.
///
/// Every `_` is dropped, and the first character plus each character that follows an
/// underscore is ASCII-uppercased. All other characters are kept as-is, so
/// `keep_alive` -> `KeepAlive` and `a__b` -> `AB`.
fn to_camel_case(snake_case: &str) -> String {
    snake_case
        .split('_')
        .flat_map(|word| {
            let mut chars = word.chars();
            // Only the first character of each `_`-separated word is capitalized.
            let head = chars.next().map(|first| first.to_ascii_uppercase());
            head.into_iter().chain(chars)
        })
        .collect()
}
/// Converts a CamelCase (or mixed) identifier to snake_case.
///
/// Each ASCII uppercase letter is lowercased and preceded by an underscore. No underscore is
/// added at the very start or directly after an existing underscore, so
/// `KeepAlive` -> `keep_alive` and `c_KeepAlive` -> `c_keep_alive`. Input that is already
/// lowercase snake_case (e.g. `c_keep_alive`) is returned unchanged.
fn to_snake_case(camel_case: &str) -> String {
    let mut snake_case = String::with_capacity(camel_case.len());
    for c in camel_case.chars() {
        if c.is_ascii_uppercase() {
            // Avoid a leading `_` and doubled `__` at existing word boundaries.
            if !snake_case.is_empty() && !snake_case.ends_with('_') {
                snake_case.push('_');
            }
            snake_case.push(c.to_ascii_lowercase());
        } else {
            snake_case.push(c);
        }
    }
    snake_case
}