| use proc_macro2::TokenStream; |
| use quote::quote; |
| use syn::{ |
| visit_mut::{self, visit_item_mut, visit_path_segment_mut, VisitMut}, |
| Expr, ExprBlock, File, GenericArgument, GenericParam, Item, PathArguments, PathSegment, Stmt, |
| Type, TypeParamBound, WherePredicate, |
| }; |
| |
| pub struct ReplaceGenericType<'a> { |
| generic_type: &'a str, |
| arg_type: &'a PathSegment, |
| } |
| |
| impl<'a> ReplaceGenericType<'a> { |
| pub fn new(generic_type: &'a str, arg_type: &'a PathSegment) -> Self { |
| Self { |
| generic_type, |
| arg_type, |
| } |
| } |
| |
| pub fn replace_generic_type(item: &mut Item, generic_type: &'a str, arg_type: &'a PathSegment) { |
| let mut s = Self::new(generic_type, arg_type); |
| s.visit_item_mut(item); |
| } |
| } |
| |
| impl<'a> VisitMut for ReplaceGenericType<'a> { |
| fn visit_item_mut(&mut self, i: &mut Item) { |
| if let Item::Fn(item_fn) = i { |
| // remove generic type from generics <T, F> |
| let args = item_fn |
| .sig |
| .generics |
| .params |
| .iter() |
| .filter_map(|param| { |
| if let GenericParam::Type(type_param) = ¶m { |
| if type_param.ident.to_string().eq(self.generic_type) { |
| None |
| } else { |
| Some(param) |
| } |
| } else { |
| Some(param) |
| } |
| }) |
| .collect::<Vec<_>>(); |
| item_fn.sig.generics.params = args.into_iter().cloned().collect(); |
| |
| // remove generic type from where clause |
| if let Some(where_clause) = &mut item_fn.sig.generics.where_clause { |
| let new_where_clause = where_clause |
| .predicates |
| .iter() |
| .filter_map(|predicate| { |
| if let WherePredicate::Type(predicate_type) = predicate { |
| if let Type::Path(p) = &predicate_type.bounded_ty { |
| if p.path.segments[0].ident.to_string().eq(self.generic_type) { |
| None |
| } else { |
| Some(predicate) |
| } |
| } else { |
| Some(predicate) |
| } |
| } else { |
| Some(predicate) |
| } |
| }) |
| .collect::<Vec<_>>(); |
| |
| where_clause.predicates = new_where_clause.into_iter().cloned().collect(); |
| }; |
| } |
| visit_item_mut(self, i) |
| } |
| fn visit_path_segment_mut(&mut self, i: &mut PathSegment) { |
| // replace generic type with target type |
| if i.ident.to_string().eq(&self.generic_type) { |
| *i = self.arg_type.clone(); |
| } |
| visit_path_segment_mut(self, i); |
| } |
| } |
| |
| pub struct AsyncAwaitRemoval; |
| |
| impl AsyncAwaitRemoval { |
| pub fn remove_async_await(&mut self, item: TokenStream) -> TokenStream { |
| let mut syntax_tree: File = syn::parse(item.into()).unwrap(); |
| self.visit_file_mut(&mut syntax_tree); |
| quote!(#syntax_tree) |
| } |
| } |
| |
| impl VisitMut for AsyncAwaitRemoval { |
| fn visit_expr_mut(&mut self, node: &mut Expr) { |
| // Delegate to the default impl to visit nested expressions. |
| visit_mut::visit_expr_mut(self, node); |
| |
| match node { |
| Expr::Await(expr) => *node = (*expr.base).clone(), |
| |
| Expr::Async(expr) => { |
| let inner = &expr.block; |
| let sync_expr = if let [Stmt::Expr(expr, None)] = inner.stmts.as_slice() { |
| // remove useless braces when there is only one statement |
| expr.clone() |
| } else { |
| Expr::Block(ExprBlock { |
| attrs: expr.attrs.clone(), |
| block: inner.clone(), |
| label: None, |
| }) |
| }; |
| *node = sync_expr; |
| } |
| _ => {} |
| } |
| } |
| |
| fn visit_item_mut(&mut self, i: &mut Item) { |
| // find generic parameter of Future and replace it with its Output type |
| if let Item::Fn(item_fn) = i { |
| let mut inputs: Vec<(String, PathSegment)> = vec![]; |
| |
| // generic params: <T:Future<Output=()>, F> |
| for param in &item_fn.sig.generics.params { |
| // generic param: T:Future<Output=()> |
| if let GenericParam::Type(type_param) = param { |
| let generic_type_name = type_param.ident.to_string(); |
| |
| // bound: Future<Output=()> |
| for bound in &type_param.bounds { |
| inputs.extend(search_trait_bound(&generic_type_name, bound)); |
| } |
| } |
| } |
| |
| if let Some(where_clause) = &item_fn.sig.generics.where_clause { |
| for predicate in &where_clause.predicates { |
| if let WherePredicate::Type(predicate_type) = predicate { |
| let generic_type_name = if let Type::Path(p) = &predicate_type.bounded_ty { |
| p.path.segments[0].ident.to_string() |
| } else { |
| panic!("Please submit an issue"); |
| }; |
| |
| for bound in &predicate_type.bounds { |
| inputs.extend(search_trait_bound(&generic_type_name, bound)); |
| } |
| } |
| } |
| } |
| |
| for (generic_type_name, path_seg) in &inputs { |
| ReplaceGenericType::replace_generic_type(i, generic_type_name, path_seg); |
| } |
| } |
| visit_item_mut(self, i); |
| } |
| } |
| |
| fn search_trait_bound( |
| generic_type_name: &str, |
| bound: &TypeParamBound, |
| ) -> Vec<(String, PathSegment)> { |
| let mut inputs = vec![]; |
| |
| if let TypeParamBound::Trait(trait_bound) = bound { |
| let segment = &trait_bound.path.segments[trait_bound.path.segments.len() - 1]; |
| let name = segment.ident.to_string(); |
| if name.eq("Future") { |
| // match Future<Output=Type> |
| if let PathArguments::AngleBracketed(args) = &segment.arguments { |
| // binding: Output=Type |
| if let GenericArgument::AssocType(binding) = &args.args[0] { |
| if let Type::Path(p) = &binding.ty { |
| inputs.push((generic_type_name.to_owned(), p.path.segments[0].clone())); |
| } |
| } |
| } |
| } |
| } |
| inputs |
| } |