ab_contracts_trivial_type_derive/
lib.rs1use proc_macro2::{Ident, Literal, TokenStream};
2use quote::{format_ident, quote};
3use std::iter;
4use syn::spanned::Spanned;
5use syn::token::Paren;
6use syn::{
7 Attribute, Data, DataEnum, DataStruct, DeriveInput, Error, Fields, LitInt, parenthesized,
8 parse_macro_input,
9};
10
11#[proc_macro_derive(TrivialType)]
12pub fn trivial_type_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
13 let input = parse_macro_input!(input as DeriveInput);
14
15 if !input.generics.params.is_empty() {
16 return Error::new(
17 input.ident.span(),
18 "`TrivialType` can't be derived on generic types",
19 )
20 .into_compile_error()
21 .into();
22 }
23
24 let maybe_repr_attr = input.attrs.iter().find(|attr| attr.path().is_ident("repr"));
25
26 let Some(repr_attr) = maybe_repr_attr else {
27 return Error::new(input.ident.span(), "`TrivialType` requires `#[repr(..)]`")
28 .to_compile_error()
29 .into();
30 };
31
32 let (repr_c, repr_transparent, repr_numeric, repr_align, repr_packed) =
33 match parse_repr(repr_attr) {
34 Ok(result) => result,
35 Err(error) => {
36 return error.to_compile_error().into();
37 }
38 };
39
40 if repr_align.is_some() || repr_packed.is_some() {
41 return Error::new(
42 input.ident.span(),
43 "`TrivialType` doesn't allow `#[repr(align(N))]` or `#[repr(packed(N))]",
44 )
45 .to_compile_error()
46 .into();
47 }
48
49 let type_name = &input.ident;
50
51 let output = match &input.data {
52 Data::Struct(data_struct) => {
53 if !(repr_c || repr_transparent) {
54 return Error::new(
55 input.ident.span(),
56 "`TrivialType` on structs requires `#[repr(C)]` or `#[repr(transparent)]",
57 )
58 .into_compile_error()
59 .into();
60 }
61 let field_types = data_struct
62 .fields
63 .iter()
64 .map(|field| &field.ty)
65 .collect::<Vec<_>>();
66
67 let struct_metadata = match generate_struct_metadata(type_name, data_struct) {
68 Ok(struct_metadata) => struct_metadata,
69 Err(error) => {
70 return error.to_compile_error().into();
71 }
72 };
73
74 quote! {
75 const _: () = {
76 assert!(
79 0 == (
80 ::core::mem::size_of::<#type_name>()
81 #(- ::core::mem::size_of::<#field_types>() )*
82 ),
83 "Struct must not have implicit padding"
84 );
85
86 assert!(
88 u32::MAX as ::core::primitive::usize >= ::core::mem::size_of::<#type_name>(),
89 "Type size must be smaller than 2^32"
90 );
91
92 let (type_details, _metadata) =
94 ::ab_contracts_io_type::metadata::IoTypeMetadataKind::type_details(
95 <#type_name as ::ab_contracts_io_type::trivial_type::TrivialType>::METADATA,
96 )
97 .expect("Statically correct metadata; qed");
98 assert!(size_of::<#type_name>() == type_details.recommended_capacity as ::core::primitive::usize);
99 assert!(align_of::<#type_name>() == type_details.alignment.get() as ::core::primitive::usize);
100 };
101
102 #[automatically_derived]
103 unsafe impl ::ab_contracts_io_type::trivial_type::TrivialType for #type_name
104 where
105 #( #field_types: ::ab_contracts_io_type::trivial_type::TrivialType, )*
106 {
107 const METADATA: &[::core::primitive::u8] = #struct_metadata;
108 }
109 }
110 }
111 Data::Enum(data_enum) => {
112 if repr_numeric != Some(8) {
114 return Error::new(
115 input.generics.span(),
116 "`TrivialType` derive for enums only supports `#[repr(u8)]`, ambiguous \
117 or larger discriminant size is not allowed",
118 )
119 .to_compile_error()
120 .into();
121 }
122
123 let repr_numeric = format_ident!("u8");
124
125 let field_types = data_enum
126 .variants
127 .iter()
128 .flat_map(|variant| &variant.fields)
129 .map(|field| &field.ty)
130 .collect::<Vec<_>>();
131
132 let padding_assertions = data_enum.variants.iter().map(|variant| {
133 let field_types = variant.fields.iter().map(|field| &field.ty);
134
135 quote! {
136 assert!(
139 0 == (
140 ::core::mem::size_of::<#type_name>()
141 - ::core::mem::size_of::<::core::primitive::#repr_numeric>()
142 #(- ::core::mem::size_of::<#field_types>() )*
143 ),
144 "Enum must not have implicit padding"
145 );
146 }
147 });
148
149 let enum_metadata = match generate_enum_metadata(type_name, data_enum) {
150 Ok(struct_metadata) => struct_metadata,
151 Err(error) => {
152 return error.to_compile_error().into();
153 }
154 };
155
156 quote! {
157 const _: () = {
158 assert!(
160 u32::MAX as ::core::primitive::usize >= ::core::mem::size_of::<#type_name>(),
161 "Type size must be smaller than 2^32"
162 );
163
164 let (type_details, _metadata) =
166 ::ab_contracts_io_type::metadata::IoTypeMetadataKind::type_details(
167 <#type_name as ::ab_contracts_io_type::trivial_type::TrivialType>::METADATA,
168 )
169 .expect("Statically correct metadata; qed");
170 assert!(size_of::<#type_name>() == type_details.recommended_capacity as ::core::primitive::usize);
171 assert!(align_of::<#type_name>() == type_details.alignment.get() as ::core::primitive::usize);
172
173 #( #padding_assertions )*;
174 };
175
176 #[automatically_derived]
177 unsafe impl ::ab_contracts_io_type::trivial_type::TrivialType for #type_name
178 where
179 #( #field_types: ::ab_contracts_io_type::trivial_type::TrivialType, )*
180 {
181 const METADATA: &[::core::primitive::u8] = #enum_metadata;
182 }
183 }
184 }
185 Data::Union(data_union) => {
186 return Error::new(
187 data_union.union_token.span(),
188 "`TrivialType` can be derived for structs and enums, but not unions",
189 )
190 .to_compile_error()
191 .into();
192 }
193 };
194
195 output.into()
196}
197
198#[allow(clippy::type_complexity, reason = "Private one-off function")]
199fn parse_repr(
200 repr_attr: &Attribute,
201) -> Result<(bool, bool, Option<u8>, Option<usize>, Option<usize>), Error> {
202 let mut repr_c = false;
203 let mut repr_transparent = false;
204 let mut repr_numeric = None::<u8>;
205 let mut repr_align = None::<usize>;
206 let mut repr_packed = None::<usize>;
207
208 repr_attr.parse_nested_meta(|meta| {
210 if meta.path.is_ident("C") {
211 repr_c = true;
212 return Ok(());
213 }
214 if meta.path.is_ident("u8") {
215 repr_numeric.replace(8);
216 return Ok(());
217 }
218 if meta.path.is_ident("u16") {
219 repr_numeric.replace(16);
220 return Ok(());
221 }
222 if meta.path.is_ident("u32") {
223 repr_numeric.replace(32);
224 return Ok(());
225 }
226 if meta.path.is_ident("u64") {
227 repr_numeric.replace(64);
228 return Ok(());
229 }
230 if meta.path.is_ident("u128") {
231 repr_numeric.replace(128);
232 return Ok(());
233 }
234 if meta.path.is_ident("transparent") {
235 repr_transparent = true;
236 return Ok(());
237 }
238
239 if meta.path.is_ident("align") {
241 let content;
242 parenthesized!(content in meta.input);
243 let lit = content.parse::<LitInt>()?;
244 let n = lit.base10_parse::<usize>()?;
245 repr_align = Some(n);
246 return Ok(());
247 }
248
249 if meta.path.is_ident("packed") {
251 if meta.input.peek(Paren) {
252 let content;
253 parenthesized!(content in meta.input);
254 let lit = content.parse::<LitInt>()?;
255 let n = lit.base10_parse::<usize>()?;
256 repr_packed = Some(n);
257 } else {
258 repr_packed = Some(1);
259 }
260 return Ok(());
261 }
262
263 Err(meta.error("Unsupported `#[repr(..)]`"))
264 })?;
265
266 Ok((
267 repr_c,
268 repr_transparent,
269 repr_numeric,
270 repr_align,
271 repr_packed,
272 ))
273}
274
275fn generate_struct_metadata(ident: &Ident, data_struct: &DataStruct) -> Result<TokenStream, Error> {
276 let num_fields = data_struct.fields.len();
277 let struct_with_fields = data_struct
278 .fields
279 .iter()
280 .next()
281 .is_some_and(|field| field.ident.is_some());
282 let (io_type_metadata, with_num_fields) = if struct_with_fields {
283 match num_fields {
284 0..=16 => (format_ident!("Struct{num_fields}"), false),
285 _ => (format_ident!("Struct"), true),
286 }
287 } else {
288 match num_fields {
289 1..=16 => (format_ident!("TupleStruct{num_fields}"), false),
290 _ => (format_ident!("TupleStruct"), true),
291 }
292 };
293 let inner_struct_metadata =
294 generate_inner_struct_metadata(ident, &data_struct.fields, with_num_fields)
295 .collect::<Result<Vec<_>, _>>()?;
296
297 Ok(quote! {{
301 #[inline(always)]
302 const fn metadata() -> (
303 [::core::primitive::u8; ::ab_contracts_io_type::metadata::MAX_METADATA_CAPACITY],
304 usize,
305 )
306 {
307 ::ab_contracts_io_type::metadata::concat_metadata_sources(&[
308 &[::ab_contracts_io_type::metadata::IoTypeMetadataKind::#io_type_metadata as ::core::primitive::u8],
309 #( #inner_struct_metadata )*
310 ])
311 }
312
313 metadata()
315 .0
316 .split_at(metadata().1)
317 .0
318 }})
319}
320
321fn generate_enum_metadata(ident: &Ident, data_enum: &DataEnum) -> Result<TokenStream, Error> {
322 let type_name_string = ident.to_string();
323 let type_name_bytes = type_name_string.as_bytes();
324
325 let type_name_bytes_len = u8::try_from(type_name_bytes.len()).map_err(|_error| {
326 Error::new(
327 ident.span(),
328 format!(
329 "Name of the enum must not be more than {} bytes in length",
330 u8::MAX
331 ),
332 )
333 })?;
334 let num_variants = u8::try_from(data_enum.variants.len()).map_err(|_error| {
335 Error::new(
336 ident.span(),
337 format!("Enum must not have more than {} variants", u8::MAX),
338 )
339 })?;
340 let variant_has_fields = data_enum
341 .variants
342 .iter()
343 .next()
344 .is_some_and(|variant| !variant.fields.is_empty());
345 let enum_type = if variant_has_fields {
346 "Enum"
347 } else {
348 "EnumNoFields"
349 };
350 let (io_type_metadata, with_num_variants) = match num_variants {
351 1..=16 => (format_ident!("{enum_type}{num_variants}"), false),
352 _ => (format_ident!("{enum_type}"), true),
353 };
354
355 let enum_metadata_header = {
361 let enum_metadata_header = [Literal::u8_unsuffixed(type_name_bytes_len)]
362 .into_iter()
363 .chain(
364 type_name_bytes
365 .iter()
366 .map(|&char| Literal::byte_character(char)),
367 )
368 .chain(with_num_variants.then_some(Literal::u8_unsuffixed(num_variants)));
369
370 quote! {
371 &[
372 ::ab_contracts_io_type::metadata::IoTypeMetadataKind::#io_type_metadata as ::core::primitive::u8,
373 #( #enum_metadata_header, )*
374 ]
375 }
376 };
377
378 let inner = data_enum
380 .variants
381 .iter()
382 .flat_map(|variant| {
383 variant
384 .fields
385 .iter()
386 .find_map(|field| {
387 if field.ident.is_none() {
388 Some(Err(Error::new(
389 field.span(),
390 "Variant must have named fields",
391 )))
392 } else {
393 None
394 }
395 })
396 .into_iter()
397 .chain(generate_inner_struct_metadata(
398 &variant.ident,
399 &variant.fields,
400 variant_has_fields,
401 ))
402 })
403 .collect::<Result<Vec<TokenStream>, Error>>()?;
404
405 Ok(quote! {{
406 #[inline(always)]
407 const fn metadata() -> (
408 [::core::primitive::u8; ::ab_contracts_io_type::metadata::MAX_METADATA_CAPACITY],
409 usize,
410 )
411 {
412 ::ab_contracts_io_type::metadata::concat_metadata_sources(&[
413 #enum_metadata_header,
414 #( #inner )*
415 ])
416 }
417
418 metadata()
420 .0
421 .split_at(metadata().1)
422 .0
423 }})
424}
425
426fn generate_inner_struct_metadata<'a>(
427 ident: &'a Ident,
428 fields: &'a Fields,
429 with_num_fields: bool,
430) -> impl Iterator<Item = Result<TokenStream, Error>> + 'a {
431 iter::once_with(move || generate_inner_struct_metadata_header(ident, fields, with_num_fields))
432 .chain(generate_fields_metadata(fields))
433}
434
435fn generate_inner_struct_metadata_header(
436 ident: &Ident,
437 fields: &Fields,
438 with_num_fields: bool,
439) -> Result<TokenStream, Error> {
440 let ident_string = ident.to_string();
441 let ident_bytes = ident_string.as_bytes();
442
443 let ident_bytes_len = u8::try_from(ident_bytes.len()).map_err(|_error| {
444 Error::new(
445 ident.span(),
446 format!(
447 "Identifier must not be more than {} bytes in length",
448 u8::MAX
449 ),
450 )
451 })?;
452 let num_fields = u8::try_from(fields.len()).map_err(|_error| {
453 Error::new(
454 fields.span(),
455 format!("Must not have more than {} field", u8::MAX),
456 )
457 })?;
458
459 Ok({
464 let struct_metadata_header = [Literal::u8_unsuffixed(ident_bytes_len)]
465 .into_iter()
466 .chain(
467 ident_bytes
468 .iter()
469 .map(|&char| Literal::byte_character(char)),
470 )
471 .chain(with_num_fields.then_some(Literal::u8_unsuffixed(num_fields)));
472
473 quote! {
474 &[#( #struct_metadata_header, )*],
475 }
476 })
477}
478
479fn generate_fields_metadata(
480 fields: &Fields,
481) -> impl Iterator<Item = Result<TokenStream, Error>> + '_ {
482 fields.iter().map(move |field| {
487 let field_metadata = if let Some(field_name) = &field.ident {
488 let field_name_string = field_name.to_string();
489 let field_name_bytes = field_name_string.as_bytes();
490 let field_name_len = u8::try_from(field_name_bytes.len()).map_err(|_error| {
491 Error::new(
492 field.span(),
493 format!(
494 "Name of the field must not be more than {} bytes in length",
495 u8::MAX
496 ),
497 )
498 })?;
499
500 let field_metadata = [Literal::u8_unsuffixed(field_name_len)].into_iter().chain(
501 field_name_bytes
502 .iter()
503 .map(|&char| Literal::byte_character(char)),
504 );
505
506 Some(quote! { #( #field_metadata, )* })
507 } else {
508 None
509 };
510 let field_type = &field.ty;
511
512 Ok(quote! {
513 &[ #field_metadata ],
514 <#field_type as ::ab_contracts_io_type::trivial_type::TrivialType>::METADATA,
515 })
516 })
517}