blob: 81c0274274a1493604f4e228235e4d70d8f82e04 [file] [log] [blame]
extern crate proc_macro;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::*;
mod container_attributes;
mod field_attributes;
mod variant_attributes;
use container_attributes::ContainerAttributes;
use field_attributes::{determine_field_constructor, FieldConstructor};
use variant_attributes::not_skipped;
/// Name of the helper attribute (`#[arbitrary(...)]`) that this derive looks for;
/// also interpolated into the error messages about invalid attributes.
const ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
/// Name of the lifetime that is introduced on the generated `Arbitrary` impl.
const ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
/// Entry point of `#[derive(Arbitrary)]`.
///
/// Parses the annotated item and hands it to [`expand_derive_arbitrary`]. When the
/// expansion fails, the error is turned into a `compile_error!` token stream instead.
#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
    match expand_derive_arbitrary(input) {
        Ok(expanded) => expanded.into(),
        Err(error) => error.into_compile_error().into(),
    }
}
fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
let container_attrs = ContainerAttributes::from_derive_input(&input)?;
let (lifetime_without_bounds, lifetime_with_bounds) =
build_arbitrary_lifetime(input.generics.clone());
let recursive_count = syn::Ident::new(
&format!("RECURSIVE_COUNT_{}", input.ident),
Span::call_site(),
);
let arbitrary_method =
gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
let size_hint_method = gen_size_hint_method(&input)?;
let name = input.ident;
// Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
let generics = apply_trait_bounds(
input.generics,
lifetime_without_bounds.clone(),
&container_attrs,
)?;
// Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90)
let mut generics_with_lifetime = generics.clone();
generics_with_lifetime
.params
.push(GenericParam::Lifetime(lifetime_with_bounds));
let (impl_generics, _, _) = generics_with_lifetime.split_for_impl();
// Build TypeGenerics and WhereClause without a lifetime
let (_, ty_generics, where_clause) = generics.split_for_impl();
Ok(quote! {
const _: () = {
::std::thread_local! {
#[allow(non_upper_case_globals)]
static #recursive_count: ::core::cell::Cell<u32> = ::core::cell::Cell::new(0);
}
#[automatically_derived]
impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
#arbitrary_method
#size_hint_method
}
};
})
}
/// Builds the `'arbitrary` lifetime parameter in two flavours.
///
/// Returns `(lifetime without bounds, lifetime with bounds)`, where the bounded one
/// outlives every lifetime parameter declared in `generics`.
///
/// Example: `('arbitrary, 'arbitrary: 'a + 'b)`
fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeParam, LifetimeParam) {
    let unbounded = LifetimeParam::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site()));

    let mut bounded = unbounded.clone();
    bounded
        .bounds
        .extend(generics.lifetimes().map(|declared| declared.lifetime.clone()));

    (unbounded, bounded)
}
fn apply_trait_bounds(
mut generics: Generics,
lifetime: LifetimeParam,
container_attrs: &ContainerAttributes,
) -> Result<Generics> {
// If user-supplied bounds exist, apply them to their matching type parameters.
if let Some(config_bounds) = &container_attrs.bounds {
let mut config_bounds_applied = 0;
for param in generics.params.iter_mut() {
if let GenericParam::Type(type_param) = param {
if let Some(replacement) = config_bounds
.iter()
.flatten()
.find(|p| p.ident == type_param.ident)
{
*type_param = replacement.clone();
config_bounds_applied += 1;
} else {
// If no user-supplied bounds exist for this type, delete the original bounds.
// This mimics serde.
type_param.bounds = Default::default();
type_param.default = None;
}
}
}
let config_bounds_supplied = config_bounds
.iter()
.map(|bounds| bounds.len())
.sum::<usize>();
if config_bounds_applied != config_bounds_supplied {
return Err(Error::new(
Span::call_site(),
format!(
"invalid `{}` attribute. too many bounds, only {} out of {} are applicable",
ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied,
),
));
}
Ok(generics)
} else {
// Otherwise, inject a `T: Arbitrary` bound for every parameter.
Ok(add_trait_bounds(generics, lifetime))
}
}
// Add a bound `T: Arbitrary` to every type parameter T.
fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics {
for param in generics.params.iter_mut() {
if let GenericParam::Type(type_param) = param {
type_param
.bounds
.push(parse_quote!(arbitrary::Arbitrary<#lifetime>));
}
}
generics
}
/// Wraps `expr` in generated code that guards against recursion once the input is empty.
///
/// The emitted tokens expect a binding named `u` (with an `is_empty()` method) to be in
/// scope and use the thread-local `recursive_count` cell as the counter. They evaluate
/// to the value of `expr`.
fn with_recursive_count_guard(
recursive_count: &syn::Ident,
expr: impl quote::ToTokens,
) -> impl quote::ToTokens {
quote! {
// The guard is only active when `u` reports itself as empty.
let guard_against_recursion = u.is_empty();
if guard_against_recursion {
#recursive_count.with(|count| {
// A non-zero count means a guarded call is already in progress: give up.
if count.get() > 0 {
return Err(arbitrary::Error::NotEnoughData);
}
count.set(count.get() + 1);
Ok(())
})?;
}
// Running `expr` inside a closure means a `?` within it ends the closure rather
// than the enclosing function, so the decrement below still runs.
let result = (|| { #expr })();
if guard_against_recursion {
#recursive_count.with(|count| {
count.set(count.get() - 1);
});
}
result
}
}
fn gen_arbitrary_method(
input: &DeriveInput,
lifetime: LifetimeParam,
recursive_count: &syn::Ident,
) -> Result<TokenStream> {
fn arbitrary_structlike(
fields: &Fields,
ident: &syn::Ident,
lifetime: LifetimeParam,
recursive_count: &syn::Ident,
) -> Result<TokenStream> {
let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?;
let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });
let arbitrary_take_rest = construct_take_rest(fields)?;
let take_rest_body =
with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) });
Ok(quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#body
}
fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#take_rest_body
}
})
}
fn arbitrary_variant(
index: u64,
enum_name: &Ident,
variant_name: &Ident,
ctor: TokenStream,
) -> TokenStream {
quote! { #index => #enum_name::#variant_name #ctor }
}
fn arbitrary_enum_method(
recursive_count: &syn::Ident,
unstructured: TokenStream,
variants: &[TokenStream],
) -> impl quote::ToTokens {
let count = variants.len() as u64;
with_recursive_count_guard(
recursive_count,
quote! {
// Use a multiply + shift to generate a ranged random number
// with slight bias. For details, see:
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(#unstructured)?) * #count) >> 32 {
#(#variants,)*
_ => unreachable!()
})
},
)
}
fn arbitrary_enum(
DataEnum { variants, .. }: &DataEnum,
enum_name: &Ident,
lifetime: LifetimeParam,
recursive_count: &syn::Ident,
) -> Result<TokenStream> {
let filtered_variants = variants.iter().filter(not_skipped);
// Check attributes of all variants:
filtered_variants
.clone()
.try_for_each(check_variant_attrs)?;
// From here on, we can assume that the attributes of all variants were checked.
let enumerated_variants = filtered_variants
.enumerate()
.map(|(index, variant)| (index as u64, variant));
// Construct `match`-arms for the `arbitrary` method.
let variants = enumerated_variants
.clone()
.map(|(index, Variant { fields, ident, .. })| {
construct(fields, |_, field| gen_constructor_for_field(field))
.map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
})
.collect::<Result<Vec<TokenStream>>>()?;
// Construct `match`-arms for the `arbitrary_take_rest` method.
let variants_take_rest = enumerated_variants
.map(|(index, Variant { fields, ident, .. })| {
construct_take_rest(fields)
.map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
})
.collect::<Result<Vec<TokenStream>>>()?;
// Most of the time, `variants` is not empty (the happy path),
// thus `variants_take_rest` will be used,
// so no need to move this check before constructing `variants_take_rest`.
// If `variants` is empty, this will emit a compiler-error.
(!variants.is_empty())
.then(|| {
// TODO: Improve dealing with `u` vs. `&mut u`.
let arbitrary = arbitrary_enum_method(recursive_count, quote! { u }, &variants);
let arbitrary_take_rest = arbitrary_enum_method(recursive_count, quote! { &mut u }, &variants_take_rest);
quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#arbitrary
}
fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#arbitrary_take_rest
}
}
})
.ok_or_else(|| Error::new_spanned(
enum_name,
"Enum must have at least one variant, that is not skipped"
))
}
let ident = &input.ident;
match &input.data {
Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count),
Data::Union(data) => arbitrary_structlike(
&Fields::Named(data.fields.clone()),
ident,
lifetime,
recursive_count,
),
Data::Enum(data) => arbitrary_enum(data, ident, lifetime, recursive_count),
}
}
/// Builds the constructor tokens that follow a type or variant name.
///
/// `ctor` is called with each field's index and definition and supplies the expression
/// for that field. The result is `{ name: expr, ... }` for named fields, `(expr, ...)`
/// for tuple fields, and nothing for unit types.
///
/// # Errors
///
/// Returns the first error produced by `ctor`.
fn construct(
    fields: &Fields,
    ctor: impl Fn(usize, &Field) -> Result<TokenStream>,
) -> Result<TokenStream> {
    match fields {
        Fields::Named(named) => {
            let mut initializers = Vec::with_capacity(named.named.len());
            for (index, field) in named.named.iter().enumerate() {
                // Fields of a `Fields::Named` always carry an identifier.
                let name = field.ident.as_ref().unwrap();
                let value = ctor(index, field)?;
                initializers.push(quote! { #name: #value });
            }
            Ok(quote! { { #(#initializers,)* } })
        }
        Fields::Unnamed(unnamed) => {
            let mut values = Vec::with_capacity(unnamed.unnamed.len());
            for (index, field) in unnamed.unnamed.iter().enumerate() {
                values.push(ctor(index, field)?);
            }
            Ok(quote! { ( #(#values),* ) })
        }
        Fields::Unit => Ok(quote!()),
    }
}
/// Builds the constructor tokens used inside `arbitrary_take_rest`.
///
/// In that method `u` is an owned `Unstructured`, so fields borrow it as `&mut u`.
/// The exception is an `Arbitrary` field in the last position, which consumes `u`
/// through `arbitrary_take_rest`.
fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
    let field_count = fields.len();
    construct(fields, |idx, field| {
        let tokens = match determine_field_constructor(field)? {
            FieldConstructor::Default => quote!(::core::default::Default::default()),
            FieldConstructor::Arbitrary if idx + 1 == field_count => {
                quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
            }
            FieldConstructor::Arbitrary => {
                quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
            }
            FieldConstructor::With(function_or_closure) => {
                quote!((#function_or_closure)(&mut u)?)
            }
            FieldConstructor::Value(value) => quote!(#value),
        };
        Ok(tokens)
    })
}
fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
let size_hint_fields = |fields: &Fields| {
fields
.iter()
.map(|f| {
let ty = &f.ty;
determine_field_constructor(f).map(|field_constructor| {
match field_constructor {
FieldConstructor::Default | FieldConstructor::Value(_) => {
quote!(Ok((0, Some(0))))
}
FieldConstructor::Arbitrary => {
quote! { <#ty as arbitrary::Arbitrary>::try_size_hint(depth) }
}
// Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
// just an educated guess, although it's gonna be inaccurate for dynamically
// allocated types (Vec, HashMap, etc.).
FieldConstructor::With(_) => {
quote! { Ok((::core::mem::size_of::<#ty>(), None)) }
}
}
})
})
.collect::<Result<Vec<TokenStream>>>()
.map(|hints| {
quote! {
Ok(arbitrary::size_hint::and_all(&[
#( #hints? ),*
]))
}
})
};
let size_hint_structlike = |fields: &Fields| {
size_hint_fields(fields).map(|hint| {
quote! {
#[inline]
fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
Self::try_size_hint(depth).unwrap_or_default()
}
#[inline]
fn try_size_hint(depth: usize) -> ::core::result::Result<(usize, ::core::option::Option<usize>), arbitrary::MaxRecursionReached> {
arbitrary::size_hint::try_recursion_guard(depth, |depth| #hint)
}
}
})
};
match &input.data {
Data::Struct(data) => size_hint_structlike(&data.fields),
Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())),
Data::Enum(data) => data
.variants
.iter()
.filter(not_skipped)
.map(|Variant { fields, .. }| {
// The attributes of all variants are checked in `gen_arbitrary_method` above
// and can therefore assume that they are valid.
size_hint_fields(fields)
})
.collect::<Result<Vec<TokenStream>>>()
.map(|variants| {
quote! {
fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
Self::try_size_hint(depth).unwrap_or_default()
}
#[inline]
fn try_size_hint(depth: usize) -> ::core::result::Result<(usize, ::core::option::Option<usize>), arbitrary::MaxRecursionReached> {
Ok(arbitrary::size_hint::and(
<u32 as arbitrary::Arbitrary>::try_size_hint(depth)?,
arbitrary::size_hint::try_recursion_guard(depth, |depth| {
Ok(arbitrary::size_hint::or_all(&[ #( #variants? ),* ]))
})?,
))
}
}
}),
}
}
/// Emits the expression that produces one field inside the `arbitrary` method, where
/// `u` is already a `&mut Unstructured`.
///
/// # Errors
///
/// Propagates the error of `determine_field_constructor` for this field.
fn gen_constructor_for_field(field: &Field) -> Result<TokenStream> {
    determine_field_constructor(field).map(|constructor| match constructor {
        FieldConstructor::Default => quote!(::core::default::Default::default()),
        FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?),
        FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?),
        FieldConstructor::Value(value) => quote!(#value),
    })
}
fn check_variant_attrs(variant: &Variant) -> Result<()> {
for attr in &variant.attrs {
if attr.path().is_ident(ARBITRARY_ATTRIBUTE_NAME) {
return Err(Error::new_spanned(
attr,
format!(
"invalid `{}` attribute. it is unsupported on enum variants. try applying it to a field of the variant instead",
ARBITRARY_ATTRIBUTE_NAME
),
));
}
}
Ok(())
}