1use proc_macro2::{Ident, Literal, TokenStream};
2use quote::{format_ident, quote};
3use std::iter;
4use syn::spanned::Spanned;
5use syn::token::Paren;
6use syn::{
7 Attribute, Data, DataEnum, DataStruct, DeriveInput, Error, Fields, LitInt, parenthesized,
8 parse_macro_input,
9};
10
11#[proc_macro_derive(TrivialType)]
12pub fn trivial_type_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
13 let input = parse_macro_input!(input as DeriveInput);
14
15 if !input.generics.params.is_empty() {
16 return Error::new(
17 input.ident.span(),
18 "`TrivialType` can't be derived on generic types",
19 )
20 .into_compile_error()
21 .into();
22 }
23
24 let maybe_repr_attr = input.attrs.iter().find(|attr| attr.path().is_ident("repr"));
25
26 let Some(repr_attr) = maybe_repr_attr else {
27 return Error::new(input.ident.span(), "`TrivialType` requires `#[repr(..)]`")
28 .to_compile_error()
29 .into();
30 };
31
32 let (repr_c, repr_transparent, repr_numeric, repr_align, repr_packed) =
33 match parse_repr(repr_attr) {
34 Ok(result) => result,
35 Err(error) => {
36 return error.to_compile_error().into();
37 }
38 };
39
40 if repr_align.is_some() || repr_packed.is_some() {
41 return Error::new(
42 input.ident.span(),
43 "`TrivialType` doesn't allow `#[repr(align(N))]` or `#[repr(packed(N))]",
44 )
45 .to_compile_error()
46 .into();
47 }
48
49 let type_name = &input.ident;
50
51 let output = match &input.data {
52 Data::Struct(data_struct) => {
53 if !(repr_c || repr_transparent) {
54 return Error::new(
55 input.ident.span(),
56 "`TrivialType` on structs requires `#[repr(C)]` or `#[repr(transparent)]",
57 )
58 .into_compile_error()
59 .into();
60 }
61 let field_types = data_struct
62 .fields
63 .iter()
64 .map(|field| &field.ty)
65 .collect::<Vec<_>>();
66
67 let struct_metadata = match generate_struct_metadata(type_name, data_struct) {
68 Ok(struct_metadata) => struct_metadata,
69 Err(error) => {
70 return error.to_compile_error().into();
71 }
72 };
73
74 quote! {
75 const _: () = {
76 assert!(
79 0 == (
80 ::core::mem::size_of::<#type_name>()
81 #(- ::core::mem::size_of::<#field_types>() )*
82 ),
83 "Struct must not have implicit padding. Consider reordering fields, adding \
84 `padding: [u8; N]` field where necessary or use `Unaligned<T>` wrapper for \
85 types with larger alignment to reduce it to one byte."
86 );
87
88 assert!(
90 u32::MAX as ::core::primitive::usize >= ::core::mem::size_of::<#type_name>(),
91 "Type size must be smaller than 2^32"
92 );
93
94 let (type_details, _metadata) =
96 ::ab_io_type::metadata::IoTypeMetadataKind::type_details(
97 <#type_name as ::ab_io_type::trivial_type::TrivialType>::METADATA,
98 )
99 .expect("Statically correct metadata; qed");
100 assert!(size_of::<#type_name>() == type_details.recommended_capacity as ::core::primitive::usize);
101 assert!(align_of::<#type_name>() == type_details.alignment.get() as ::core::primitive::usize);
102 };
103
104 #[automatically_derived]
105 unsafe impl ::ab_io_type::trivial_type::TrivialType for #type_name
106 where
107 #( #field_types: ::ab_io_type::trivial_type::TrivialType, )*
108 {
109 const METADATA: &[::core::primitive::u8] = #struct_metadata;
110 }
111 }
112 }
113 Data::Enum(data_enum) => {
114 if repr_numeric != Some(8) {
116 return Error::new(
117 input.generics.span(),
118 "`TrivialType` derive for enums only supports `#[repr(u8)]`, ambiguous \
119 or larger discriminant size is not allowed",
120 )
121 .to_compile_error()
122 .into();
123 }
124
125 let repr_numeric = format_ident!("u8");
126
127 let field_types = data_enum
128 .variants
129 .iter()
130 .flat_map(|variant| &variant.fields)
131 .map(|field| &field.ty)
132 .collect::<Vec<_>>();
133
134 let padding_assertions = data_enum.variants.iter().map(|variant| {
135 let field_types = variant.fields.iter().map(|field| &field.ty);
136
137 quote! {
138 assert!(
141 0 == (
142 ::core::mem::size_of::<#type_name>()
143 - ::core::mem::size_of::<::core::primitive::#repr_numeric>()
144 #(- ::core::mem::size_of::<#field_types>() )*
145 ),
146 "Enum must not have implicit padding. Consider reordering fields, adding \
147 `padding: [u8; N]` field where necessary or use `Unaligned<T>` wrapper for \
148 types with larger alignment to reduce it to one byte."
149 );
150 }
151 });
152
153 let enum_metadata = match generate_enum_metadata(type_name, data_enum) {
154 Ok(struct_metadata) => struct_metadata,
155 Err(error) => {
156 return error.to_compile_error().into();
157 }
158 };
159
160 quote! {
161 const _: () = {
162 assert!(
164 u32::MAX as ::core::primitive::usize >= ::core::mem::size_of::<#type_name>(),
165 "Type size must be smaller than 2^32"
166 );
167
168 let (type_details, _metadata) =
170 ::ab_io_type::metadata::IoTypeMetadataKind::type_details(
171 <#type_name as ::ab_io_type::trivial_type::TrivialType>::METADATA,
172 )
173 .expect("Statically correct metadata; qed");
174 assert!(size_of::<#type_name>() == type_details.recommended_capacity as ::core::primitive::usize);
175 assert!(align_of::<#type_name>() == type_details.alignment.get() as ::core::primitive::usize);
176
177 #( #padding_assertions )*;
178 };
179
180 #[automatically_derived]
181 unsafe impl ::ab_io_type::trivial_type::TrivialType for #type_name
182 where
183 #( #field_types: ::ab_io_type::trivial_type::TrivialType, )*
184 {
185 const METADATA: &[::core::primitive::u8] = #enum_metadata;
186 }
187 }
188 }
189 Data::Union(data_union) => {
190 return Error::new(
191 data_union.union_token.span(),
192 "`TrivialType` can be derived for structs and enums, but not unions",
193 )
194 .to_compile_error()
195 .into();
196 }
197 };
198
199 output.into()
200}
201
202#[expect(clippy::type_complexity, reason = "Private one-off function")]
203fn parse_repr(
204 repr_attr: &Attribute,
205) -> Result<(bool, bool, Option<u8>, Option<usize>, Option<usize>), Error> {
206 let mut repr_c = false;
207 let mut repr_transparent = false;
208 let mut repr_numeric = None::<u8>;
209 let mut repr_align = None::<usize>;
210 let mut repr_packed = None::<usize>;
211
212 repr_attr.parse_nested_meta(|meta| {
214 if meta.path.is_ident("C") {
215 repr_c = true;
216 return Ok(());
217 }
218 if meta.path.is_ident("u8") {
219 repr_numeric.replace(8);
220 return Ok(());
221 }
222 if meta.path.is_ident("u16") {
223 repr_numeric.replace(16);
224 return Ok(());
225 }
226 if meta.path.is_ident("u32") {
227 repr_numeric.replace(32);
228 return Ok(());
229 }
230 if meta.path.is_ident("u64") {
231 repr_numeric.replace(64);
232 return Ok(());
233 }
234 if meta.path.is_ident("u128") {
235 repr_numeric.replace(128);
236 return Ok(());
237 }
238 if meta.path.is_ident("transparent") {
239 repr_transparent = true;
240 return Ok(());
241 }
242
243 if meta.path.is_ident("align") {
245 let content;
246 parenthesized!(content in meta.input);
247 let lit = content.parse::<LitInt>()?;
248 let n = lit.base10_parse::<usize>()?;
249 repr_align = Some(n);
250 return Ok(());
251 }
252
253 if meta.path.is_ident("packed") {
255 if meta.input.peek(Paren) {
256 let content;
257 parenthesized!(content in meta.input);
258 let lit = content.parse::<LitInt>()?;
259 let n = lit.base10_parse::<usize>()?;
260 repr_packed = Some(n);
261 } else {
262 repr_packed = Some(1);
263 }
264 return Ok(());
265 }
266
267 Err(meta.error("Unsupported `#[repr(..)]`"))
268 })?;
269
270 Ok((
271 repr_c,
272 repr_transparent,
273 repr_numeric,
274 repr_align,
275 repr_packed,
276 ))
277}
278
279fn generate_struct_metadata(ident: &Ident, data_struct: &DataStruct) -> Result<TokenStream, Error> {
280 let num_fields = data_struct.fields.len();
281 let struct_with_fields = data_struct
282 .fields
283 .iter()
284 .next()
285 .is_some_and(|field| field.ident.is_some());
286 let (io_type_metadata, with_num_fields) = if struct_with_fields {
287 match num_fields {
288 0..=16 => (format_ident!("Struct{num_fields}"), false),
289 _ => (format_ident!("Struct"), true),
290 }
291 } else {
292 match num_fields {
293 1..=16 => (format_ident!("TupleStruct{num_fields}"), false),
294 _ => (format_ident!("TupleStruct"), true),
295 }
296 };
297 let inner_struct_metadata =
298 generate_inner_struct_metadata(ident, &data_struct.fields, with_num_fields)
299 .collect::<Result<Vec<_>, _>>()?;
300
301 Ok(quote! {{
305 #[inline(always)]
306 const fn metadata() -> (
307 [::core::primitive::u8; ::ab_io_type::metadata::MAX_METADATA_CAPACITY],
308 usize,
309 )
310 {
311 ::ab_io_type::metadata::concat_metadata_sources(&[
312 &[::ab_io_type::metadata::IoTypeMetadataKind::#io_type_metadata as ::core::primitive::u8],
313 #( #inner_struct_metadata )*
314 ])
315 }
316
317 metadata()
319 .0
320 .split_at(metadata().1)
321 .0
322 }})
323}
324
325fn generate_enum_metadata(ident: &Ident, data_enum: &DataEnum) -> Result<TokenStream, Error> {
326 let type_name_string = ident.to_string();
327 let type_name_bytes = type_name_string.as_bytes();
328
329 let type_name_bytes_len = u8::try_from(type_name_bytes.len()).map_err(|_error| {
330 Error::new(
331 ident.span(),
332 format!(
333 "Name of the enum must not be more than {} bytes in length",
334 u8::MAX
335 ),
336 )
337 })?;
338 let num_variants = u8::try_from(data_enum.variants.len()).map_err(|_error| {
339 Error::new(
340 ident.span(),
341 format!("Enum must not have more than {} variants", u8::MAX),
342 )
343 })?;
344 let variant_has_fields = data_enum
345 .variants
346 .iter()
347 .next()
348 .is_some_and(|variant| !variant.fields.is_empty());
349 let enum_type = if variant_has_fields {
350 "Enum"
351 } else {
352 "EnumNoFields"
353 };
354 let (io_type_metadata, with_num_variants) = match num_variants {
355 1..=16 => (format_ident!("{enum_type}{num_variants}"), false),
356 _ => (format_ident!("{enum_type}"), true),
357 };
358
359 let enum_metadata_header = {
365 let enum_metadata_header = [Literal::u8_unsuffixed(type_name_bytes_len)]
366 .into_iter()
367 .chain(
368 type_name_bytes
369 .iter()
370 .map(|&char| Literal::byte_character(char)),
371 )
372 .chain(with_num_variants.then_some(Literal::u8_unsuffixed(num_variants)));
373
374 quote! {
375 &[
376 ::ab_io_type::metadata::IoTypeMetadataKind::#io_type_metadata as ::core::primitive::u8,
377 #( #enum_metadata_header, )*
378 ]
379 }
380 };
381
382 let inner = data_enum
384 .variants
385 .iter()
386 .flat_map(|variant| {
387 variant
388 .fields
389 .iter()
390 .find_map(|field| {
391 if field.ident.is_none() {
392 Some(Err(Error::new(
393 field.span(),
394 "Variant must have named fields",
395 )))
396 } else {
397 None
398 }
399 })
400 .into_iter()
401 .chain(generate_inner_struct_metadata(
402 &variant.ident,
403 &variant.fields,
404 variant_has_fields,
405 ))
406 })
407 .collect::<Result<Vec<TokenStream>, Error>>()?;
408
409 Ok(quote! {{
410 #[inline(always)]
411 const fn metadata() -> (
412 [::core::primitive::u8; ::ab_io_type::metadata::MAX_METADATA_CAPACITY],
413 usize,
414 )
415 {
416 ::ab_io_type::metadata::concat_metadata_sources(&[
417 #enum_metadata_header,
418 #( #inner )*
419 ])
420 }
421
422 metadata()
424 .0
425 .split_at(metadata().1)
426 .0
427 }})
428}
429
430fn generate_inner_struct_metadata<'a>(
431 ident: &'a Ident,
432 fields: &'a Fields,
433 with_num_fields: bool,
434) -> impl Iterator<Item = Result<TokenStream, Error>> + 'a {
435 iter::once_with(move || generate_inner_struct_metadata_header(ident, fields, with_num_fields))
436 .chain(generate_fields_metadata(fields))
437}
438
439fn generate_inner_struct_metadata_header(
440 ident: &Ident,
441 fields: &Fields,
442 with_num_fields: bool,
443) -> Result<TokenStream, Error> {
444 let ident_string = ident.to_string();
445 let ident_bytes = ident_string.as_bytes();
446
447 let ident_bytes_len = u8::try_from(ident_bytes.len()).map_err(|_error| {
448 Error::new(
449 ident.span(),
450 format!(
451 "Identifier must not be more than {} bytes in length",
452 u8::MAX
453 ),
454 )
455 })?;
456 let num_fields = u8::try_from(fields.len()).map_err(|_error| {
457 Error::new(
458 fields.span(),
459 format!("Must not have more than {} field", u8::MAX),
460 )
461 })?;
462
463 Ok({
468 let struct_metadata_header = [Literal::u8_unsuffixed(ident_bytes_len)]
469 .into_iter()
470 .chain(
471 ident_bytes
472 .iter()
473 .map(|&char| Literal::byte_character(char)),
474 )
475 .chain(with_num_fields.then_some(Literal::u8_unsuffixed(num_fields)));
476
477 quote! {
478 &[#( #struct_metadata_header, )*],
479 }
480 })
481}
482
483fn generate_fields_metadata(
484 fields: &Fields,
485) -> impl Iterator<Item = Result<TokenStream, Error>> + '_ {
486 fields.iter().map(move |field| {
491 let field_metadata = if let Some(field_name) = &field.ident {
492 let field_name_string = field_name.to_string();
493 let field_name_bytes = field_name_string.as_bytes();
494 let field_name_len = u8::try_from(field_name_bytes.len()).map_err(|_error| {
495 Error::new(
496 field.span(),
497 format!(
498 "Name of the field must not be more than {} bytes in length",
499 u8::MAX
500 ),
501 )
502 })?;
503
504 let field_metadata = [Literal::u8_unsuffixed(field_name_len)].into_iter().chain(
505 field_name_bytes
506 .iter()
507 .map(|&char| Literal::byte_character(char)),
508 );
509
510 Some(quote! { #( #field_metadata, )* })
511 } else {
512 None
513 };
514 let field_type = &field.ty;
515
516 Ok(quote! {
517 &[ #field_metadata ],
518 <#field_type as ::ab_io_type::trivial_type::TrivialType>::METADATA,
519 })
520 })
521}