Skip to main content

azalea_buf_macros/
read.rs

1use quote::{ToTokens, quote};
2use syn::{Data, Field, FieldsNamed, Ident, punctuated::Punctuated, token::Comma};
3
4pub fn create_fn_azalea_read(data: &Data) -> proc_macro2::TokenStream {
5    match data {
6        syn::Data::Struct(syn::DataStruct { fields, .. }) => match fields {
7            syn::Fields::Named(FieldsNamed { named, .. }) => {
8                let (read_fields, read_field_names) = read_named_fields(named);
9
10                quote! {
11                    fn azalea_read(buf: &mut std::io::Cursor<&[u8]>) -> std::result::Result<Self, azalea_buf::BufReadError> {
12                        #(#read_fields)*
13                        Ok(Self {
14                            #(#read_field_names: #read_field_names),*
15                        })
16                    }
17                }
18            }
19            syn::Fields::Unit => {
20                quote! {
21                    fn azalea_read(buf: &mut std::io::Cursor<&[u8]>) -> std::result::Result<Self, azalea_buf::BufReadError> {
22                        Ok(Self)
23                    }
24                }
25            }
26            syn::Fields::Unnamed(fields) => {
27                let read_fields = read_unnamed_fields(&fields.unnamed);
28
29                quote! {
30                    fn azalea_read(buf: &mut std::io::Cursor<&[u8]>) -> std::result::Result<Self, azalea_buf::BufReadError> {
31                        Ok(Self(
32                            #(#read_fields),*
33                        ))
34                    }
35                }
36            }
37        },
38        syn::Data::Enum(syn::DataEnum { variants, .. }) => {
39            let mut match_contents = quote!();
40            let mut variant_discrim: u32 = 0;
41            let mut first = true;
42            let mut first_reader = None;
43            for variant in variants {
44                let variant_name = &variant.ident;
45                match &variant.discriminant.as_ref() {
46                    Some(d) => {
47                        variant_discrim = match &d.1 {
48                            syn::Expr::Lit(e) => match &e.lit {
49                                syn::Lit::Int(i) => i.base10_parse().unwrap(),
50                                _ => panic!("Error parsing enum discriminant as int (is {e:?})"),
51                            },
52                            syn::Expr::Unary(_) => {
53                                panic!("Negative enum discriminants are not supported")
54                            }
55                            _ => {
56                                panic!("Error parsing enum discriminant as literal (is {:?})", d.1)
57                            }
58                        }
59                    }
60                    None => {
61                        if !first {
62                            variant_discrim += 1;
63                        }
64                    }
65                }
66                let reader = match &variant.fields {
67                    syn::Fields::Named(f) => {
68                        let (read_fields, read_field_names) = read_named_fields(&f.named);
69
70                        quote! {
71                            #(#read_fields)*
72                            Ok(Self::#variant_name {
73                                #(#read_field_names: #read_field_names),*
74                            })
75                        }
76                    }
77                    syn::Fields::Unnamed(fields) => {
78                        let mut reader_code = quote! {};
79                        for f in &fields.unnamed {
80                            let is_variable_length =
81                                f.attrs.iter().any(|a| a.path().is_ident("var"));
82                            let limit =
83                                f.attrs
84                                    .iter()
85                                    .find(|a| a.path().is_ident("limit"))
86                                    .map(|a| {
87                                        a.parse_args::<syn::LitInt>()
88                                            .unwrap()
89                                            .base10_parse::<u32>()
90                                            .unwrap()
91                                    });
92
93                            if is_variable_length && limit.is_some() {
94                                panic!("Fields cannot have both var and limit attributes");
95                            }
96
97                            if is_variable_length {
98                                reader_code.extend(quote! {
99                                    Self::#variant_name(azalea_buf::AzBufVar::azalea_read_var(buf)?),
100                                });
101                            } else if let Some(limit) = limit {
102                                reader_code.extend(quote! {
103                                    Self::#variant_name(azalea_buf::AzBufLimited::azalea_read_limited(buf, #limit)?),
104                                });
105                            } else {
106                                reader_code.extend(quote! {
107                                    Self::#variant_name(azalea_buf::AzBuf::azalea_read(buf)?),
108                                });
109                            }
110                        }
111                        quote! { Ok(#reader_code) }
112                    }
113                    syn::Fields::Unit => quote! {
114                        Ok(Self::#variant_name)
115                    },
116                };
117                if first {
118                    first_reader = Some(reader.clone());
119                    first = false;
120                };
121
122                match_contents.extend(quote! {
123                    #variant_discrim => {
124                        #reader
125                    },
126                });
127            }
128
129            let first_reader = first_reader.expect("There should be at least one variant");
130
131            quote! {
132                fn azalea_read(buf: &mut std::io::Cursor<&[u8]>) -> std::result::Result<Self, azalea_buf::BufReadError> {
133                    let id = azalea_buf::AzBufVar::azalea_read_var(buf)?;
134
135                    match id {
136                        #match_contents
137                        // you'd THINK this throws an error, but mojang decided to make it default for some reason
138                        _ => {#first_reader}
139                    }
140                }
141            }
142        }
143        _ => panic!("#[derive(AzBuf)] can only be used on structs"),
144    }
145}
146
147fn read_named_fields(
148    named: &Punctuated<Field, Comma>,
149) -> (Vec<proc_macro2::TokenStream>, Vec<&Option<Ident>>) {
150    let read_fields = named
151        .iter()
152        .map(|f| {
153            let field_name = &f.ident;
154
155            let reader_call = get_reader_call(f);
156            quote! { let #field_name = #reader_call; }
157        })
158        .collect::<Vec<_>>();
159    let read_field_names = named.iter().map(|f| &f.ident).collect::<Vec<_>>();
160
161    (read_fields, read_field_names)
162}
163
164fn read_unnamed_fields(unnamed: &Punctuated<Field, Comma>) -> Vec<proc_macro2::TokenStream> {
165    unnamed
166        .iter()
167        .map(|f| {
168            let reader_call = get_reader_call(f);
169            quote! { #reader_call }
170        })
171        .collect::<Vec<_>>()
172}
173
174fn get_reader_call(f: &Field) -> proc_macro2::TokenStream {
175    let is_variable_length = f
176        .attrs
177        .iter()
178        .any(|a: &syn::Attribute| a.path().is_ident("var"));
179    let limit = f
180        .attrs
181        .iter()
182        .find(|a| a.path().is_ident("limit"))
183        .map(|a| {
184            a.parse_args::<syn::LitInt>()
185                .unwrap()
186                .base10_parse::<u32>()
187                .unwrap()
188        });
189
190    if is_variable_length && limit.is_some() {
191        panic!("Fields cannot have both var and limit attributes");
192    }
193
194    let field_type = &f.ty;
195
196    // do a different buf.write_* for each field depending on the type
197    // if it's a string, use buf.write_string
198    match field_type {
199        syn::Type::Path(_) | syn::Type::Array(_) => {
200            if is_variable_length {
201                quote! {
202                    azalea_buf::AzBufVar::azalea_read_var(buf)?
203                }
204            } else if let Some(limit) = limit {
205                quote! {
206                    azalea_buf::AzBufLimited::azalea_read_limited(buf, #limit)?
207                }
208            } else {
209                quote! {
210                    azalea_buf::AzBuf::azalea_read(buf)?
211                }
212            }
213        }
214        _ => panic!(
215            "Error reading field {:?}: {}",
216            f.ident.clone(),
217            field_type.to_token_stream()
218        ),
219    }
220}