ab_trivial_type_derive/
lib.rs

1use proc_macro2::{Ident, Literal, TokenStream};
2use quote::{format_ident, quote};
3use std::iter;
4use syn::spanned::Spanned;
5use syn::token::Paren;
6use syn::{
7    Attribute, Data, DataEnum, DataStruct, DeriveInput, Error, Fields, LitInt, parenthesized,
8    parse_macro_input,
9};
10
11#[proc_macro_derive(TrivialType)]
12pub fn trivial_type_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
13    let input = parse_macro_input!(input as DeriveInput);
14
15    if !input.generics.params.is_empty() {
16        return Error::new(
17            input.ident.span(),
18            "`TrivialType` can't be derived on generic types",
19        )
20        .into_compile_error()
21        .into();
22    }
23
24    let maybe_repr_attr = input.attrs.iter().find(|attr| attr.path().is_ident("repr"));
25
26    let Some(repr_attr) = maybe_repr_attr else {
27        return Error::new(input.ident.span(), "`TrivialType` requires `#[repr(..)]`")
28            .to_compile_error()
29            .into();
30    };
31
32    let (repr_c, repr_transparent, repr_numeric, repr_align, repr_packed) =
33        match parse_repr(repr_attr) {
34            Ok(result) => result,
35            Err(error) => {
36                return error.to_compile_error().into();
37            }
38        };
39
40    if repr_align.is_some() || repr_packed.is_some() {
41        return Error::new(
42            input.ident.span(),
43            "`TrivialType` doesn't allow `#[repr(align(N))]` or `#[repr(packed(N))]",
44        )
45        .to_compile_error()
46        .into();
47    }
48
49    let type_name = &input.ident;
50
51    let output = match &input.data {
52        Data::Struct(data_struct) => {
53            if !(repr_c || repr_transparent) {
54                return Error::new(
55                    input.ident.span(),
56                    "`TrivialType` on structs requires `#[repr(C)]` or `#[repr(transparent)]",
57                )
58                .into_compile_error()
59                .into();
60            }
61            let field_types = data_struct
62                .fields
63                .iter()
64                .map(|field| &field.ty)
65                .collect::<Vec<_>>();
66
67            let struct_metadata = match generate_struct_metadata(type_name, data_struct) {
68                Ok(struct_metadata) => struct_metadata,
69                Err(error) => {
70                    return error.to_compile_error().into();
71                }
72            };
73
74            quote! {
75                const _: () = {
76                    // Assert statically that there is no unexpected padding that would be left
77                    // uninitialized and unsound to access
78                    assert!(
79                        0 == (
80                            ::core::mem::size_of::<#type_name>()
81                            #(- ::core::mem::size_of::<#field_types>() )*
82                        ),
83                        "Struct must not have implicit padding. Consider reordering fields, adding \
84                        `padding: [u8; N]` field where necessary or use `Unaligned<T>` wrapper for \
85                        types with larger alignment to reduce it to one byte."
86                    );
87
88                    // Assert that type doesn't exceed 32-bit size limit
89                    assert!(
90                        u32::MAX as ::core::primitive::usize >= ::core::mem::size_of::<#type_name>(),
91                        "Type size must be smaller than 2^32"
92                    );
93
94                    // Ensure capacity and alignment are correctly decoded from metadata
95                    let (type_details, _metadata) =
96                        ::ab_io_type::metadata::IoTypeMetadataKind::type_details(
97                            <#type_name as ::ab_io_type::trivial_type::TrivialType>::METADATA,
98                        )
99                            .expect("Statically correct metadata; qed");
100                    assert!(size_of::<#type_name>() == type_details.recommended_capacity as ::core::primitive::usize);
101                    assert!(align_of::<#type_name>() == type_details.alignment.get() as ::core::primitive::usize);
102                };
103
104                #[automatically_derived]
105                unsafe impl ::ab_io_type::trivial_type::TrivialType for #type_name
106                where
107                    #( #field_types: ::ab_io_type::trivial_type::TrivialType, )*
108                {
109                    const METADATA: &[::core::primitive::u8] = #struct_metadata;
110                }
111            }
112        }
113        Data::Enum(data_enum) => {
114            // Require defined size of the discriminant instead of allowing compiler to guess
115            if repr_numeric != Some(8) {
116                return Error::new(
117                    input.generics.span(),
118                    "`TrivialType` derive for enums only supports `#[repr(u8)]`, ambiguous \
119                    or larger discriminant size is not allowed",
120                )
121                .to_compile_error()
122                .into();
123            }
124
125            let repr_numeric = format_ident!("u8");
126
127            let field_types = data_enum
128                .variants
129                .iter()
130                .flat_map(|variant| &variant.fields)
131                .map(|field| &field.ty)
132                .collect::<Vec<_>>();
133
134            let padding_assertions = data_enum.variants.iter().map(|variant| {
135                let field_types = variant.fields.iter().map(|field| &field.ty);
136
137                quote! {
138                    // Assert statically that there is no unexpected padding that would be left
139                    // uninitialized and unsound to access
140                    assert!(
141                        0 == (
142                            ::core::mem::size_of::<#type_name>()
143                            - ::core::mem::size_of::<::core::primitive::#repr_numeric>()
144                            #(- ::core::mem::size_of::<#field_types>() )*
145                        ),
146                        "Enum must not have implicit padding. Consider reordering fields, adding \
147                        `padding: [u8; N]` field where necessary or use `Unaligned<T>` wrapper for \
148                        types with larger alignment to reduce it to one byte."
149                    );
150                }
151            });
152
153            let enum_metadata = match generate_enum_metadata(type_name, data_enum) {
154                Ok(struct_metadata) => struct_metadata,
155                Err(error) => {
156                    return error.to_compile_error().into();
157                }
158            };
159
160            quote! {
161                const _: () = {
162                    // Assert that type doesn't exceed 32-bit size limit
163                    assert!(
164                        u32::MAX as ::core::primitive::usize >= ::core::mem::size_of::<#type_name>(),
165                        "Type size must be smaller than 2^32"
166                    );
167
168                    // Ensure capacity and alignment are correctly decoded from metadata
169                    let (type_details, _metadata) =
170                        ::ab_io_type::metadata::IoTypeMetadataKind::type_details(
171                            <#type_name as ::ab_io_type::trivial_type::TrivialType>::METADATA,
172                        )
173                            .expect("Statically correct metadata; qed");
174                    assert!(size_of::<#type_name>() == type_details.recommended_capacity as ::core::primitive::usize);
175                    assert!(align_of::<#type_name>() == type_details.alignment.get() as ::core::primitive::usize);
176
177                    #( #padding_assertions )*;
178                };
179
180                #[automatically_derived]
181                unsafe impl ::ab_io_type::trivial_type::TrivialType for #type_name
182                where
183                    #( #field_types: ::ab_io_type::trivial_type::TrivialType, )*
184                {
185                    const METADATA: &[::core::primitive::u8] = #enum_metadata;
186                }
187            }
188        }
189        Data::Union(data_union) => {
190            return Error::new(
191                data_union.union_token.span(),
192                "`TrivialType` can be derived for structs and enums, but not unions",
193            )
194            .to_compile_error()
195            .into();
196        }
197    };
198
199    output.into()
200}
201
202#[expect(clippy::type_complexity, reason = "Private one-off function")]
203fn parse_repr(
204    repr_attr: &Attribute,
205) -> Result<(bool, bool, Option<u8>, Option<usize>, Option<usize>), Error> {
206    let mut repr_c = false;
207    let mut repr_transparent = false;
208    let mut repr_numeric = None::<u8>;
209    let mut repr_align = None::<usize>;
210    let mut repr_packed = None::<usize>;
211
212    // Based on https://docs.rs/syn/2.0.93/syn/struct.Attribute.html#method.parse_nested_meta
213    repr_attr.parse_nested_meta(|meta| {
214        if meta.path.is_ident("C") {
215            repr_c = true;
216            return Ok(());
217        }
218        if meta.path.is_ident("u8") {
219            repr_numeric.replace(8);
220            return Ok(());
221        }
222        if meta.path.is_ident("u16") {
223            repr_numeric.replace(16);
224            return Ok(());
225        }
226        if meta.path.is_ident("u32") {
227            repr_numeric.replace(32);
228            return Ok(());
229        }
230        if meta.path.is_ident("u64") {
231            repr_numeric.replace(64);
232            return Ok(());
233        }
234        if meta.path.is_ident("u128") {
235            repr_numeric.replace(128);
236            return Ok(());
237        }
238        if meta.path.is_ident("transparent") {
239            repr_transparent = true;
240            return Ok(());
241        }
242
243        // #[repr(align(N))]
244        if meta.path.is_ident("align") {
245            let content;
246            parenthesized!(content in meta.input);
247            let lit = content.parse::<LitInt>()?;
248            let n = lit.base10_parse::<usize>()?;
249            repr_align = Some(n);
250            return Ok(());
251        }
252
253        // #[repr(packed)] or #[repr(packed(N))], omitted N means 1
254        if meta.path.is_ident("packed") {
255            if meta.input.peek(Paren) {
256                let content;
257                parenthesized!(content in meta.input);
258                let lit = content.parse::<LitInt>()?;
259                let n = lit.base10_parse::<usize>()?;
260                repr_packed = Some(n);
261            } else {
262                repr_packed = Some(1);
263            }
264            return Ok(());
265        }
266
267        Err(meta.error("Unsupported `#[repr(..)]`"))
268    })?;
269
270    Ok((
271        repr_c,
272        repr_transparent,
273        repr_numeric,
274        repr_align,
275        repr_packed,
276    ))
277}
278
279fn generate_struct_metadata(ident: &Ident, data_struct: &DataStruct) -> Result<TokenStream, Error> {
280    let num_fields = data_struct.fields.len();
281    let struct_with_fields = data_struct
282        .fields
283        .iter()
284        .next()
285        .is_some_and(|field| field.ident.is_some());
286    let (io_type_metadata, with_num_fields) = if struct_with_fields {
287        match num_fields {
288            0..=16 => (format_ident!("Struct{num_fields}"), false),
289            _ => (format_ident!("Struct"), true),
290        }
291    } else {
292        match num_fields {
293            1..=16 => (format_ident!("TupleStruct{num_fields}"), false),
294            _ => (format_ident!("TupleStruct"), true),
295        }
296    };
297    let inner_struct_metadata =
298        generate_inner_struct_metadata(ident, &data_struct.fields, with_num_fields)
299            .collect::<Result<Vec<_>, _>>()?;
300
301    // Encodes the following:
302    // * Type: struct
303    // * The rest as inner struct metadata
304    Ok(quote! {{
305        #[inline(always)]
306        const fn metadata() -> (
307            [::core::primitive::u8; ::ab_io_type::metadata::MAX_METADATA_CAPACITY],
308            usize,
309        )
310        {
311            ::ab_io_type::metadata::concat_metadata_sources(&[
312                &[::ab_io_type::metadata::IoTypeMetadataKind::#io_type_metadata as ::core::primitive::u8],
313                #( #inner_struct_metadata )*
314            ])
315        }
316
317        // Strange syntax to allow Rust to extend the lifetime of metadata scratch automatically
318        metadata()
319            .0
320            .split_at(metadata().1)
321            .0
322    }})
323}
324
325fn generate_enum_metadata(ident: &Ident, data_enum: &DataEnum) -> Result<TokenStream, Error> {
326    let type_name_string = ident.to_string();
327    let type_name_bytes = type_name_string.as_bytes();
328
329    let type_name_bytes_len = u8::try_from(type_name_bytes.len()).map_err(|_error| {
330        Error::new(
331            ident.span(),
332            format!(
333                "Name of the enum must not be more than {} bytes in length",
334                u8::MAX
335            ),
336        )
337    })?;
338    let num_variants = u8::try_from(data_enum.variants.len()).map_err(|_error| {
339        Error::new(
340            ident.span(),
341            format!("Enum must not have more than {} variants", u8::MAX),
342        )
343    })?;
344    let variant_has_fields = data_enum
345        .variants
346        .iter()
347        .next()
348        .is_some_and(|variant| !variant.fields.is_empty());
349    let enum_type = if variant_has_fields {
350        "Enum"
351    } else {
352        "EnumNoFields"
353    };
354    let (io_type_metadata, with_num_variants) = match num_variants {
355        1..=16 => (format_ident!("{enum_type}{num_variants}"), false),
356        _ => (format_ident!("{enum_type}"), true),
357    };
358
359    // Encodes the following:
360    // * Type: enum
361    // * Length of enum name in bytes (u8)
362    // * Enum name as UTF-8 bytes
363    // * Number of variants (u8, if requested)
364    let enum_metadata_header = {
365        let enum_metadata_header = [Literal::u8_unsuffixed(type_name_bytes_len)]
366            .into_iter()
367            .chain(
368                type_name_bytes
369                    .iter()
370                    .map(|&char| Literal::byte_character(char)),
371            )
372            .chain(with_num_variants.then_some(Literal::u8_unsuffixed(num_variants)));
373
374        quote! {
375            &[
376                ::ab_io_type::metadata::IoTypeMetadataKind::#io_type_metadata as ::core::primitive::u8,
377                #( #enum_metadata_header, )*
378            ]
379        }
380    };
381
382    // Encodes each variant as inner struct
383    let inner = data_enum
384        .variants
385        .iter()
386        .flat_map(|variant| {
387            variant
388                .fields
389                .iter()
390                .find_map(|field| {
391                    if field.ident.is_none() {
392                        Some(Err(Error::new(
393                            field.span(),
394                            "Variant must have named fields",
395                        )))
396                    } else {
397                        None
398                    }
399                })
400                .into_iter()
401                .chain(generate_inner_struct_metadata(
402                    &variant.ident,
403                    &variant.fields,
404                    variant_has_fields,
405                ))
406        })
407        .collect::<Result<Vec<TokenStream>, Error>>()?;
408
409    Ok(quote! {{
410        #[inline(always)]
411        const fn metadata() -> (
412            [::core::primitive::u8; ::ab_io_type::metadata::MAX_METADATA_CAPACITY],
413            usize,
414        )
415        {
416            ::ab_io_type::metadata::concat_metadata_sources(&[
417                #enum_metadata_header,
418                #( #inner )*
419            ])
420        }
421
422        // Strange syntax to allow Rust to extend the lifetime of metadata scratch automatically
423        metadata()
424            .0
425            .split_at(metadata().1)
426            .0
427    }})
428}
429
430fn generate_inner_struct_metadata<'a>(
431    ident: &'a Ident,
432    fields: &'a Fields,
433    with_num_fields: bool,
434) -> impl Iterator<Item = Result<TokenStream, Error>> + 'a {
435    iter::once_with(move || generate_inner_struct_metadata_header(ident, fields, with_num_fields))
436        .chain(generate_fields_metadata(fields))
437}
438
439fn generate_inner_struct_metadata_header(
440    ident: &Ident,
441    fields: &Fields,
442    with_num_fields: bool,
443) -> Result<TokenStream, Error> {
444    let ident_string = ident.to_string();
445    let ident_bytes = ident_string.as_bytes();
446
447    let ident_bytes_len = u8::try_from(ident_bytes.len()).map_err(|_error| {
448        Error::new(
449            ident.span(),
450            format!(
451                "Identifier must not be more than {} bytes in length",
452                u8::MAX
453            ),
454        )
455    })?;
456    let num_fields = u8::try_from(fields.len()).map_err(|_error| {
457        Error::new(
458            fields.span(),
459            format!("Must not have more than {} field", u8::MAX),
460        )
461    })?;
462
463    // Encodes the following:
464    // * Length of identifier in bytes (u8)
465    // * Identifier as UTF-8 bytes
466    // * Number of fields (u8, if requested)
467    Ok({
468        let struct_metadata_header = [Literal::u8_unsuffixed(ident_bytes_len)]
469            .into_iter()
470            .chain(
471                ident_bytes
472                    .iter()
473                    .map(|&char| Literal::byte_character(char)),
474            )
475            .chain(with_num_fields.then_some(Literal::u8_unsuffixed(num_fields)));
476
477        quote! {
478            &[#( #struct_metadata_header, )*],
479        }
480    })
481}
482
483fn generate_fields_metadata(
484    fields: &Fields,
485) -> impl Iterator<Item = Result<TokenStream, Error>> + '_ {
486    // Encodes the following for each field:
487    // * Length of the field name in bytes (u8, if not tuple)
488    // * Field name as UTF-8 bytes (if not tuple)
489    // * Recursive metadata of the field's type
490    fields.iter().map(move |field| {
491        let field_metadata = if let Some(field_name) = &field.ident {
492            let field_name_string = field_name.to_string();
493            let field_name_bytes = field_name_string.as_bytes();
494            let field_name_len = u8::try_from(field_name_bytes.len()).map_err(|_error| {
495                Error::new(
496                    field.span(),
497                    format!(
498                        "Name of the field must not be more than {} bytes in length",
499                        u8::MAX
500                    ),
501                )
502            })?;
503
504            let field_metadata = [Literal::u8_unsuffixed(field_name_len)].into_iter().chain(
505                field_name_bytes
506                    .iter()
507                    .map(|&char| Literal::byte_character(char)),
508            );
509
510            Some(quote! { #( #field_metadata, )* })
511        } else {
512            None
513        };
514        let field_type = &field.ty;
515
516        Ok(quote! {
517            &[ #field_metadata ],
518            <#field_type as ::ab_io_type::trivial_type::TrivialType>::METADATA,
519        })
520    })
521}