ab_contracts_trivial_type_derive/
lib.rs

1use proc_macro2::{Ident, Literal, TokenStream};
2use quote::{format_ident, quote};
3use std::iter;
4use syn::spanned::Spanned;
5use syn::token::Paren;
6use syn::{
7    Attribute, Data, DataEnum, DataStruct, DeriveInput, Error, Fields, LitInt, parenthesized,
8    parse_macro_input,
9};
10
11#[proc_macro_derive(TrivialType)]
12pub fn trivial_type_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
13    let input = parse_macro_input!(input as DeriveInput);
14
15    if !input.generics.params.is_empty() {
16        return Error::new(
17            input.ident.span(),
18            "`TrivialType` can't be derived on generic types",
19        )
20        .into_compile_error()
21        .into();
22    }
23
24    let maybe_repr_attr = input.attrs.iter().find(|attr| attr.path().is_ident("repr"));
25
26    let Some(repr_attr) = maybe_repr_attr else {
27        return Error::new(input.ident.span(), "`TrivialType` requires `#[repr(..)]`")
28            .to_compile_error()
29            .into();
30    };
31
32    let (repr_c, repr_transparent, repr_numeric, repr_align, repr_packed) =
33        match parse_repr(repr_attr) {
34            Ok(result) => result,
35            Err(error) => {
36                return error.to_compile_error().into();
37            }
38        };
39
40    if repr_align.is_some() || repr_packed.is_some() {
41        return Error::new(
42            input.ident.span(),
43            "`TrivialType` doesn't allow `#[repr(align(N))]` or `#[repr(packed(N))]",
44        )
45        .to_compile_error()
46        .into();
47    }
48
49    let type_name = &input.ident;
50
51    let output = match &input.data {
52        Data::Struct(data_struct) => {
53            if !(repr_c || repr_transparent) {
54                return Error::new(
55                    input.ident.span(),
56                    "`TrivialType` on structs requires `#[repr(C)]` or `#[repr(transparent)]",
57                )
58                .into_compile_error()
59                .into();
60            }
61            let field_types = data_struct
62                .fields
63                .iter()
64                .map(|field| &field.ty)
65                .collect::<Vec<_>>();
66
67            let struct_metadata = match generate_struct_metadata(type_name, data_struct) {
68                Ok(struct_metadata) => struct_metadata,
69                Err(error) => {
70                    return error.to_compile_error().into();
71                }
72            };
73
74            quote! {
75                const _: () = {
76                    // Assert statically that there is no unexpected padding that would be left
77                    // uninitialized and unsound to access
78                    assert!(
79                        0 == (
80                            ::core::mem::size_of::<#type_name>()
81                            #(- ::core::mem::size_of::<#field_types>() )*
82                        ),
83                        "Struct must not have implicit padding"
84                    );
85
86                    // Assert that type doesn't exceed 32-bit size limit
87                    assert!(
88                        u32::MAX as ::core::primitive::usize >= ::core::mem::size_of::<#type_name>(),
89                        "Type size must be smaller than 2^32"
90                    );
91
92                    // Ensure capacity and alignment are correctly decoded from metadata
93                    let (type_details, _metadata) =
94                        ::ab_contracts_io_type::metadata::IoTypeMetadataKind::type_details(
95                            <#type_name as ::ab_contracts_io_type::trivial_type::TrivialType>::METADATA,
96                        )
97                            .expect("Statically correct metadata; qed");
98                    assert!(size_of::<#type_name>() == type_details.recommended_capacity as ::core::primitive::usize);
99                    assert!(align_of::<#type_name>() == type_details.alignment.get() as ::core::primitive::usize);
100                };
101
102                #[automatically_derived]
103                unsafe impl ::ab_contracts_io_type::trivial_type::TrivialType for #type_name
104                where
105                    #( #field_types: ::ab_contracts_io_type::trivial_type::TrivialType, )*
106                {
107                    const METADATA: &[::core::primitive::u8] = #struct_metadata;
108                }
109            }
110        }
111        Data::Enum(data_enum) => {
112            // Require defined size of the discriminant instead of allowing compiler to guess
113            if repr_numeric != Some(8) {
114                return Error::new(
115                    input.generics.span(),
116                    "`TrivialType` derive for enums only supports `#[repr(u8)]`, ambiguous \
117                    or larger discriminant size is not allowed",
118                )
119                .to_compile_error()
120                .into();
121            }
122
123            let repr_numeric = format_ident!("u8");
124
125            let field_types = data_enum
126                .variants
127                .iter()
128                .flat_map(|variant| &variant.fields)
129                .map(|field| &field.ty)
130                .collect::<Vec<_>>();
131
132            let padding_assertions = data_enum.variants.iter().map(|variant| {
133                let field_types = variant.fields.iter().map(|field| &field.ty);
134
135                quote! {
136                    // Assert statically that there is no unexpected padding that would be left
137                    // uninitialized and unsound to access
138                    assert!(
139                        0 == (
140                            ::core::mem::size_of::<#type_name>()
141                            - ::core::mem::size_of::<::core::primitive::#repr_numeric>()
142                            #(- ::core::mem::size_of::<#field_types>() )*
143                        ),
144                        "Enum must not have implicit padding"
145                    );
146                }
147            });
148
149            let enum_metadata = match generate_enum_metadata(type_name, data_enum) {
150                Ok(struct_metadata) => struct_metadata,
151                Err(error) => {
152                    return error.to_compile_error().into();
153                }
154            };
155
156            quote! {
157                const _: () = {
158                    // Assert that type doesn't exceed 32-bit size limit
159                    assert!(
160                        u32::MAX as ::core::primitive::usize >= ::core::mem::size_of::<#type_name>(),
161                        "Type size must be smaller than 2^32"
162                    );
163
164                    // Ensure capacity and alignment are correctly decoded from metadata
165                    let (type_details, _metadata) =
166                        ::ab_contracts_io_type::metadata::IoTypeMetadataKind::type_details(
167                            <#type_name as ::ab_contracts_io_type::trivial_type::TrivialType>::METADATA,
168                        )
169                            .expect("Statically correct metadata; qed");
170                    assert!(size_of::<#type_name>() == type_details.recommended_capacity as ::core::primitive::usize);
171                    assert!(align_of::<#type_name>() == type_details.alignment.get() as ::core::primitive::usize);
172
173                    #( #padding_assertions )*;
174                };
175
176                #[automatically_derived]
177                unsafe impl ::ab_contracts_io_type::trivial_type::TrivialType for #type_name
178                where
179                    #( #field_types: ::ab_contracts_io_type::trivial_type::TrivialType, )*
180                {
181                    const METADATA: &[::core::primitive::u8] = #enum_metadata;
182                }
183            }
184        }
185        Data::Union(data_union) => {
186            return Error::new(
187                data_union.union_token.span(),
188                "`TrivialType` can be derived for structs and enums, but not unions",
189            )
190            .to_compile_error()
191            .into();
192        }
193    };
194
195    output.into()
196}
197
198#[allow(clippy::type_complexity, reason = "Private one-off function")]
199fn parse_repr(
200    repr_attr: &Attribute,
201) -> Result<(bool, bool, Option<u8>, Option<usize>, Option<usize>), Error> {
202    let mut repr_c = false;
203    let mut repr_transparent = false;
204    let mut repr_numeric = None::<u8>;
205    let mut repr_align = None::<usize>;
206    let mut repr_packed = None::<usize>;
207
208    // Based on https://docs.rs/syn/2.0.93/syn/struct.Attribute.html#method.parse_nested_meta
209    repr_attr.parse_nested_meta(|meta| {
210        if meta.path.is_ident("C") {
211            repr_c = true;
212            return Ok(());
213        }
214        if meta.path.is_ident("u8") {
215            repr_numeric.replace(8);
216            return Ok(());
217        }
218        if meta.path.is_ident("u16") {
219            repr_numeric.replace(16);
220            return Ok(());
221        }
222        if meta.path.is_ident("u32") {
223            repr_numeric.replace(32);
224            return Ok(());
225        }
226        if meta.path.is_ident("u64") {
227            repr_numeric.replace(64);
228            return Ok(());
229        }
230        if meta.path.is_ident("u128") {
231            repr_numeric.replace(128);
232            return Ok(());
233        }
234        if meta.path.is_ident("transparent") {
235            repr_transparent = true;
236            return Ok(());
237        }
238
239        // #[repr(align(N))]
240        if meta.path.is_ident("align") {
241            let content;
242            parenthesized!(content in meta.input);
243            let lit = content.parse::<LitInt>()?;
244            let n = lit.base10_parse::<usize>()?;
245            repr_align = Some(n);
246            return Ok(());
247        }
248
249        // #[repr(packed)] or #[repr(packed(N))], omitted N means 1
250        if meta.path.is_ident("packed") {
251            if meta.input.peek(Paren) {
252                let content;
253                parenthesized!(content in meta.input);
254                let lit = content.parse::<LitInt>()?;
255                let n = lit.base10_parse::<usize>()?;
256                repr_packed = Some(n);
257            } else {
258                repr_packed = Some(1);
259            }
260            return Ok(());
261        }
262
263        Err(meta.error("Unsupported `#[repr(..)]`"))
264    })?;
265
266    Ok((
267        repr_c,
268        repr_transparent,
269        repr_numeric,
270        repr_align,
271        repr_packed,
272    ))
273}
274
275fn generate_struct_metadata(ident: &Ident, data_struct: &DataStruct) -> Result<TokenStream, Error> {
276    let num_fields = data_struct.fields.len();
277    let struct_with_fields = data_struct
278        .fields
279        .iter()
280        .next()
281        .is_some_and(|field| field.ident.is_some());
282    let (io_type_metadata, with_num_fields) = if struct_with_fields {
283        match num_fields {
284            0..=16 => (format_ident!("Struct{num_fields}"), false),
285            _ => (format_ident!("Struct"), true),
286        }
287    } else {
288        match num_fields {
289            1..=16 => (format_ident!("TupleStruct{num_fields}"), false),
290            _ => (format_ident!("TupleStruct"), true),
291        }
292    };
293    let inner_struct_metadata =
294        generate_inner_struct_metadata(ident, &data_struct.fields, with_num_fields)
295            .collect::<Result<Vec<_>, _>>()?;
296
297    // Encodes the following:
298    // * Type: struct
299    // * The rest as inner struct metadata
300    Ok(quote! {{
301        #[inline(always)]
302        const fn metadata() -> (
303            [::core::primitive::u8; ::ab_contracts_io_type::metadata::MAX_METADATA_CAPACITY],
304            usize,
305        )
306        {
307            ::ab_contracts_io_type::metadata::concat_metadata_sources(&[
308                &[::ab_contracts_io_type::metadata::IoTypeMetadataKind::#io_type_metadata as ::core::primitive::u8],
309                #( #inner_struct_metadata )*
310            ])
311        }
312
313        // Strange syntax to allow Rust to extend the lifetime of metadata scratch automatically
314        metadata()
315            .0
316            .split_at(metadata().1)
317            .0
318    }})
319}
320
321fn generate_enum_metadata(ident: &Ident, data_enum: &DataEnum) -> Result<TokenStream, Error> {
322    let type_name_string = ident.to_string();
323    let type_name_bytes = type_name_string.as_bytes();
324
325    let type_name_bytes_len = u8::try_from(type_name_bytes.len()).map_err(|_error| {
326        Error::new(
327            ident.span(),
328            format!(
329                "Name of the enum must not be more than {} bytes in length",
330                u8::MAX
331            ),
332        )
333    })?;
334    let num_variants = u8::try_from(data_enum.variants.len()).map_err(|_error| {
335        Error::new(
336            ident.span(),
337            format!("Enum must not have more than {} variants", u8::MAX),
338        )
339    })?;
340    let variant_has_fields = data_enum
341        .variants
342        .iter()
343        .next()
344        .is_some_and(|variant| !variant.fields.is_empty());
345    let enum_type = if variant_has_fields {
346        "Enum"
347    } else {
348        "EnumNoFields"
349    };
350    let (io_type_metadata, with_num_variants) = match num_variants {
351        1..=16 => (format_ident!("{enum_type}{num_variants}"), false),
352        _ => (format_ident!("{enum_type}"), true),
353    };
354
355    // Encodes the following:
356    // * Type: enum
357    // * Length of enum name in bytes (u8)
358    // * Enum name as UTF-8 bytes
359    // * Number of variants (u8, if requested)
360    let enum_metadata_header = {
361        let enum_metadata_header = [Literal::u8_unsuffixed(type_name_bytes_len)]
362            .into_iter()
363            .chain(
364                type_name_bytes
365                    .iter()
366                    .map(|&char| Literal::byte_character(char)),
367            )
368            .chain(with_num_variants.then_some(Literal::u8_unsuffixed(num_variants)));
369
370        quote! {
371            &[
372                ::ab_contracts_io_type::metadata::IoTypeMetadataKind::#io_type_metadata as ::core::primitive::u8,
373                #( #enum_metadata_header, )*
374            ]
375        }
376    };
377
378    // Encodes each variant as inner struct
379    let inner = data_enum
380        .variants
381        .iter()
382        .flat_map(|variant| {
383            variant
384                .fields
385                .iter()
386                .find_map(|field| {
387                    if field.ident.is_none() {
388                        Some(Err(Error::new(
389                            field.span(),
390                            "Variant must have named fields",
391                        )))
392                    } else {
393                        None
394                    }
395                })
396                .into_iter()
397                .chain(generate_inner_struct_metadata(
398                    &variant.ident,
399                    &variant.fields,
400                    variant_has_fields,
401                ))
402        })
403        .collect::<Result<Vec<TokenStream>, Error>>()?;
404
405    Ok(quote! {{
406        #[inline(always)]
407        const fn metadata() -> (
408            [::core::primitive::u8; ::ab_contracts_io_type::metadata::MAX_METADATA_CAPACITY],
409            usize,
410        )
411        {
412            ::ab_contracts_io_type::metadata::concat_metadata_sources(&[
413                #enum_metadata_header,
414                #( #inner )*
415            ])
416        }
417
418        // Strange syntax to allow Rust to extend the lifetime of metadata scratch automatically
419        metadata()
420            .0
421            .split_at(metadata().1)
422            .0
423    }})
424}
425
426fn generate_inner_struct_metadata<'a>(
427    ident: &'a Ident,
428    fields: &'a Fields,
429    with_num_fields: bool,
430) -> impl Iterator<Item = Result<TokenStream, Error>> + 'a {
431    iter::once_with(move || generate_inner_struct_metadata_header(ident, fields, with_num_fields))
432        .chain(generate_fields_metadata(fields))
433}
434
435fn generate_inner_struct_metadata_header(
436    ident: &Ident,
437    fields: &Fields,
438    with_num_fields: bool,
439) -> Result<TokenStream, Error> {
440    let ident_string = ident.to_string();
441    let ident_bytes = ident_string.as_bytes();
442
443    let ident_bytes_len = u8::try_from(ident_bytes.len()).map_err(|_error| {
444        Error::new(
445            ident.span(),
446            format!(
447                "Identifier must not be more than {} bytes in length",
448                u8::MAX
449            ),
450        )
451    })?;
452    let num_fields = u8::try_from(fields.len()).map_err(|_error| {
453        Error::new(
454            fields.span(),
455            format!("Must not have more than {} field", u8::MAX),
456        )
457    })?;
458
459    // Encodes the following:
460    // * Length of identifier in bytes (u8)
461    // * Identifier as UTF-8 bytes
462    // * Number of fields (u8, if requested)
463    Ok({
464        let struct_metadata_header = [Literal::u8_unsuffixed(ident_bytes_len)]
465            .into_iter()
466            .chain(
467                ident_bytes
468                    .iter()
469                    .map(|&char| Literal::byte_character(char)),
470            )
471            .chain(with_num_fields.then_some(Literal::u8_unsuffixed(num_fields)));
472
473        quote! {
474            &[#( #struct_metadata_header, )*],
475        }
476    })
477}
478
479fn generate_fields_metadata(
480    fields: &Fields,
481) -> impl Iterator<Item = Result<TokenStream, Error>> + '_ {
482    // Encodes the following for each field:
483    // * Length of the field name in bytes (u8, if not tuple)
484    // * Field name as UTF-8 bytes (if not tuple)
485    // * Recursive metadata of the field's type
486    fields.iter().map(move |field| {
487        let field_metadata = if let Some(field_name) = &field.ident {
488            let field_name_string = field_name.to_string();
489            let field_name_bytes = field_name_string.as_bytes();
490            let field_name_len = u8::try_from(field_name_bytes.len()).map_err(|_error| {
491                Error::new(
492                    field.span(),
493                    format!(
494                        "Name of the field must not be more than {} bytes in length",
495                        u8::MAX
496                    ),
497                )
498            })?;
499
500            let field_metadata = [Literal::u8_unsuffixed(field_name_len)].into_iter().chain(
501                field_name_bytes
502                    .iter()
503                    .map(|&char| Literal::byte_character(char)),
504            );
505
506            Some(quote! { #( #field_metadata, )* })
507        } else {
508            None
509        };
510        let field_type = &field.ty;
511
512        Ok(quote! {
513            &[ #field_metadata ],
514            <#field_type as ::ab_contracts_io_type::trivial_type::TrivialType>::METADATA,
515        })
516    })
517}