| use std::collections::HashSet; |
| |
| use crate::params_style::ParamStyle; |
| use crate::rpc_attr::RpcMethodAttribute; |
| use quote::quote; |
| use syn::{ |
| parse_quote, |
| punctuated::Punctuated, |
| visit::{self, Visit}, |
| Result, Token, |
| }; |
| |
/// How a single trait method (or pub-sub group) is wired into the generated
/// `to_delegate` body; each variant emits its own `del.*` registration calls.
pub enum MethodRegistration {
    /// A regular RPC method, registered with `add_method`, or with
    /// `add_method_with_meta` when `has_metadata` is set.
    Standard {
        method: RpcMethod,
        has_metadata: bool,
    },
    /// A pub-sub group sharing one subscription `name`: every method in
    /// `subscribes` is registered through `add_subscription`, each paired with
    /// the single `unsubscribe` method.
    PubSub {
        name: String,
        subscribes: Vec<RpcMethod>,
        unsubscribe: RpcMethod,
    },
    /// A notification, registered with `add_notification`, or with
    /// `add_notification_with_meta` when `has_metadata` is set.
    Notification {
        method: RpcMethod,
        has_metadata: bool,
    },
}
| |
| impl MethodRegistration { |
| fn generate(&self) -> Result<proc_macro2::TokenStream> { |
| match self { |
| MethodRegistration::Standard { method, has_metadata } => { |
| let rpc_name = &method.name(); |
| let add_method = if *has_metadata { |
| quote!(add_method_with_meta) |
| } else { |
| quote!(add_method) |
| }; |
| let closure = method.generate_delegate_closure(false)?; |
| let add_aliases = method.generate_add_aliases(); |
| |
| Ok(quote! { |
| del.#add_method(#rpc_name, #closure); |
| #add_aliases |
| }) |
| } |
| MethodRegistration::PubSub { |
| name, |
| subscribes, |
| unsubscribe, |
| } => { |
| let unsub_name = unsubscribe.name(); |
| let unsub_method_ident = unsubscribe.ident(); |
| let unsub_closure = quote! { |
| move |base, id, meta| { |
| use self::_futures::{FutureExt, TryFutureExt}; |
| self::_jsonrpc_core::WrapFuture::into_future( |
| Self::#unsub_method_ident(base, meta, id) |
| ) |
| .map_ok(|value| _jsonrpc_core::to_value(value) |
| .expect("Expected always-serializable type; qed")) |
| .map_err(Into::into) |
| } |
| }; |
| |
| let mut add_subscriptions = proc_macro2::TokenStream::new(); |
| |
| for subscribe in subscribes.iter() { |
| let sub_name = subscribe.name(); |
| let sub_closure = subscribe.generate_delegate_closure(true)?; |
| let sub_aliases = subscribe.generate_add_aliases(); |
| |
| add_subscriptions = quote! { |
| #add_subscriptions |
| del.add_subscription( |
| #name, |
| (#sub_name, #sub_closure), |
| (#unsub_name, #unsub_closure), |
| ); |
| #sub_aliases |
| }; |
| } |
| |
| let unsub_aliases = unsubscribe.generate_add_aliases(); |
| |
| Ok(quote! { |
| #add_subscriptions |
| #unsub_aliases |
| }) |
| } |
| MethodRegistration::Notification { method, has_metadata } => { |
| let name = &method.name(); |
| let add_notification = if *has_metadata { |
| quote!(add_notification_with_meta) |
| } else { |
| quote!(add_notification) |
| }; |
| let closure = method.generate_delegate_closure(false)?; |
| let add_aliases = method.generate_add_aliases(); |
| |
| Ok(quote! { |
| del.#add_notification(#name, #closure); |
| #add_aliases |
| }) |
| } |
| } |
| } |
| } |
| |
/// Name that, appearing in any segment of an argument's type path, marks that
/// argument as a pub-sub subscriber.
const SUBSCRIBER_TYPE_IDENT: &str = "Subscriber";

/// Name of the generated closure argument that carries the request metadata.
const METADATA_CLOSURE_ARG: &str = "meta";

/// Name of the generated closure argument that carries the subscriber.
const SUBSCRIBER_CLOSURE_ARG: &str = "subscriber";

/// Identifiers used to destructure the parsed params tuple, one per param.
///
/// Capped at 16 entries, the largest tuple `serde::Deserialize` supports.
const TUPLE_FIELD_NAMES: [&str; 16] = [
    "a", "b", "c", "d", "e", "f", "g", "h",
    "i", "j", "k", "l", "m", "n", "o", "p",
];
| |
| pub fn generate_trait_item_method( |
| methods: &[MethodRegistration], |
| trait_item: &syn::ItemTrait, |
| has_metadata: bool, |
| has_pubsub_methods: bool, |
| ) -> Result<syn::TraitItemMethod> { |
| let io_delegate_type = if has_pubsub_methods { |
| quote!(_jsonrpc_pubsub::IoDelegate) |
| } else { |
| quote!(_jsonrpc_core::IoDelegate) |
| }; |
| let add_methods = methods |
| .iter() |
| .map(MethodRegistration::generate) |
| .collect::<Result<Vec<_>>>()?; |
| let to_delegate_body = quote! { |
| let mut del = #io_delegate_type::new(self.into()); |
| #(#add_methods)* |
| del |
| }; |
| |
| let method: syn::TraitItemMethod = if has_metadata { |
| parse_quote! { |
| /// Create an `IoDelegate`, wiring rpc calls to the trait methods. |
| fn to_delegate(self) -> #io_delegate_type<Self, Self::Metadata> { |
| #to_delegate_body |
| } |
| } |
| } else { |
| parse_quote! { |
| /// Create an `IoDelegate`, wiring rpc calls to the trait methods. |
| fn to_delegate<M: _jsonrpc_core::Metadata>(self) -> #io_delegate_type<Self, M> { |
| #to_delegate_body |
| } |
| } |
| }; |
| |
| let predicates = generate_where_clause_serialization_predicates(&trait_item, false); |
| let mut method = method.clone(); |
| method.sig.generics.make_where_clause().predicates.extend(predicates); |
| Ok(method) |
| } |
| |
/// A trait method exposed over RPC, paired with its parsed RPC attribute.
#[derive(Clone)]
pub struct RpcMethod {
    /// Parsed RPC attribute (name, aliases, params style, ...).
    pub attr: RpcMethodAttribute,
    /// The trait method item as parsed by `syn`.
    pub trait_item: syn::TraitItemMethod,
}
| |
| impl RpcMethod { |
| pub fn new(attr: RpcMethodAttribute, trait_item: syn::TraitItemMethod) -> RpcMethod { |
| RpcMethod { attr, trait_item } |
| } |
| |
| pub fn attr(&self) -> &RpcMethodAttribute { |
| &self.attr |
| } |
| |
| pub fn name(&self) -> &str { |
| &self.attr.name |
| } |
| |
| pub fn ident(&self) -> &syn::Ident { |
| &self.trait_item.sig.ident |
| } |
| |
| pub fn is_pubsub(&self) -> bool { |
| self.attr.is_pubsub() |
| } |
| |
| pub fn subscriber_arg(&self) -> Option<syn::Type> { |
| self.trait_item |
| .sig |
| .inputs |
| .iter() |
| .filter_map(|arg| match arg { |
| syn::FnArg::Typed(ty) => Some(*ty.ty.clone()), |
| _ => None, |
| }) |
| .find(|ty| { |
| if let syn::Type::Path(path) = ty { |
| if path.path.segments.iter().any(|s| s.ident == SUBSCRIBER_TYPE_IDENT) { |
| return true; |
| } |
| } |
| false |
| }) |
| } |
| |
| fn generate_delegate_closure(&self, is_subscribe: bool) -> Result<proc_macro2::TokenStream> { |
| let mut param_types: Vec<_> = self |
| .trait_item |
| .sig |
| .inputs |
| .iter() |
| .cloned() |
| .filter_map(|arg| match arg { |
| syn::FnArg::Typed(ty) => Some(*ty.ty), |
| _ => None, |
| }) |
| .collect(); |
| |
| // special args are those which are not passed directly via rpc params: metadata, subscriber |
| let special_args = Self::special_args(¶m_types); |
| param_types.retain(|ty| special_args.iter().find(|(_, sty)| sty == ty).is_none()); |
| if param_types.len() > TUPLE_FIELD_NAMES.len() { |
| return Err(syn::Error::new_spanned( |
| &self.trait_item, |
| &format!("Maximum supported number of params is {}", TUPLE_FIELD_NAMES.len()), |
| )); |
| } |
| let tuple_fields: &Vec<_> = &(TUPLE_FIELD_NAMES |
| .iter() |
| .take(param_types.len()) |
| .map(|name| ident(name)) |
| .collect()); |
| let param_types = ¶m_types; |
| let parse_params = { |
| // last arguments that are `Option`-s are optional 'trailing' arguments |
| let trailing_args_num = param_types.iter().rev().take_while(|t| is_option_type(t)).count(); |
| if trailing_args_num != 0 { |
| self.params_with_trailing(trailing_args_num, param_types, tuple_fields) |
| } else if param_types.is_empty() { |
| quote! { let params = params.expect_no_params(); } |
| } else if self.attr.params_style == Some(ParamStyle::Raw) { |
| quote! { let params: _jsonrpc_core::Result<_> = Ok((params,)); } |
| } else if self.attr.params_style == Some(ParamStyle::Positional) { |
| quote! { let params = params.parse::<(#(#param_types, )*)>(); } |
| } else |
| /* if self.attr.params_style == Some(ParamStyle::Named) */ |
| { |
| unimplemented!("Server side named parameters are not implemented"); |
| } |
| }; |
| |
| let method_ident = self.ident(); |
| let result = &self.trait_item.sig.output; |
| let extra_closure_args: &Vec<_> = &special_args.iter().cloned().map(|arg| arg.0).collect(); |
| let extra_method_types: &Vec<_> = &special_args.iter().cloned().map(|arg| arg.1).collect(); |
| |
| let closure_args = quote! { base, params, #(#extra_closure_args), * }; |
| let method_sig = quote! { fn(&Self, #(#extra_method_types, ) * #(#param_types), *) #result }; |
| let method_call = quote! { (base, #(#extra_closure_args, )* #(#tuple_fields), *) }; |
| let match_params = if is_subscribe { |
| quote! { |
| Ok((#(#tuple_fields, )*)) => { |
| let subscriber = _jsonrpc_pubsub::typed::Subscriber::new(subscriber); |
| (method)#method_call |
| }, |
| Err(e) => { |
| let _ = subscriber.reject(e); |
| return |
| } |
| } |
| } else if self.attr.is_notification() { |
| quote! { |
| Ok((#(#tuple_fields, )*)) => { |
| (method)#method_call |
| }, |
| Err(_) => return, |
| } |
| } else { |
| quote! { |
| Ok((#(#tuple_fields, )*)) => { |
| use self::_futures::{FutureExt, TryFutureExt}; |
| let fut = self::_jsonrpc_core::WrapFuture::into_future((method)#method_call) |
| .map_ok(|value| _jsonrpc_core::to_value(value) |
| .expect("Expected always-serializable type; qed")) |
| .map_err(Into::into as fn(_) -> _jsonrpc_core::Error); |
| _futures::future::Either::Left(fut) |
| }, |
| Err(e) => _futures::future::Either::Right(_futures::future::ready(Err(e))), |
| } |
| }; |
| |
| Ok(quote! { |
| move |#closure_args| { |
| let method = &(Self::#method_ident as #method_sig); |
| #parse_params |
| match params { |
| #match_params |
| } |
| } |
| }) |
| } |
| |
| fn special_args(param_types: &[syn::Type]) -> Vec<(syn::Ident, syn::Type)> { |
| let meta_arg = param_types.first().and_then(|ty| { |
| if *ty == parse_quote!(Self::Metadata) { |
| Some(ty.clone()) |
| } else { |
| None |
| } |
| }); |
| let subscriber_arg = param_types.get(1).and_then(|ty| { |
| if let syn::Type::Path(path) = ty { |
| if path.path.segments.iter().any(|s| s.ident == SUBSCRIBER_TYPE_IDENT) { |
| Some(ty.clone()) |
| } else { |
| None |
| } |
| } else { |
| None |
| } |
| }); |
| |
| let mut special_args = Vec::new(); |
| if let Some(meta) = meta_arg { |
| special_args.push((ident(METADATA_CLOSURE_ARG), meta.clone())); |
| } |
| if let Some(subscriber) = subscriber_arg { |
| special_args.push((ident(SUBSCRIBER_CLOSURE_ARG), subscriber.clone())); |
| } |
| special_args |
| } |
| |
| fn params_with_trailing( |
| &self, |
| trailing_args_num: usize, |
| param_types: &[syn::Type], |
| tuple_fields: &[syn::Ident], |
| ) -> proc_macro2::TokenStream { |
| let total_args_num = param_types.len(); |
| let required_args_num = total_args_num - trailing_args_num; |
| |
| let switch_branches = (0..=trailing_args_num) |
| .map(|passed_trailing_args_num| { |
| let passed_args_num = required_args_num + passed_trailing_args_num; |
| let passed_param_types = ¶m_types[..passed_args_num]; |
| let passed_tuple_fields = &tuple_fields[..passed_args_num]; |
| let missed_args_num = total_args_num - passed_args_num; |
| let missed_params_values = ::std::iter::repeat(quote! { None }) |
| .take(missed_args_num) |
| .collect::<Vec<_>>(); |
| |
| if passed_args_num == 0 { |
| quote! { |
| #passed_args_num => params.expect_no_params() |
| .map(|_| (#(#missed_params_values, ) *)) |
| .map_err(Into::into) |
| } |
| } else { |
| quote! { |
| #passed_args_num => params.parse::<(#(#passed_param_types, )*)>() |
| .map(|(#(#passed_tuple_fields,)*)| |
| (#(#passed_tuple_fields, )* #(#missed_params_values, )*)) |
| .map_err(Into::into) |
| } |
| } |
| }) |
| .collect::<Vec<_>>(); |
| |
| quote! { |
| let passed_args_num = match params { |
| _jsonrpc_core::Params::Array(ref v) => Ok(v.len()), |
| _jsonrpc_core::Params::None => Ok(0), |
| _ => Err(_jsonrpc_core::Error::invalid_params("`params` should be an array")) |
| }; |
| |
| let params = passed_args_num.and_then(|passed_args_num| { |
| match passed_args_num { |
| _ if passed_args_num < #required_args_num => Err(_jsonrpc_core::Error::invalid_params( |
| format!("`params` should have at least {} argument(s)", #required_args_num))), |
| #(#switch_branches),*, |
| _ => Err(_jsonrpc_core::Error::invalid_params_with_details( |
| format!("Expected from {} to {} parameters.", #required_args_num, #total_args_num), |
| format!("Got: {}", passed_args_num))), |
| } |
| }); |
| } |
| } |
| |
| fn generate_add_aliases(&self) -> proc_macro2::TokenStream { |
| let name = self.name(); |
| let add_aliases: Vec<_> = self |
| .attr |
| .aliases |
| .iter() |
| .map(|alias| quote! { del.add_alias(#alias, #name); }) |
| .collect(); |
| quote! { #(#add_aliases)* } |
| } |
| } |
| |
| fn ident(s: &str) -> syn::Ident { |
| syn::Ident::new(s, proc_macro2::Span::call_site()) |
| } |
| |
| fn is_option_type(ty: &syn::Type) -> bool { |
| if let syn::Type::Path(path) = ty { |
| path.path.segments.first().map_or(false, |t| t.ident == "Option") |
| } else { |
| false |
| } |
| } |
| |
| pub fn generate_where_clause_serialization_predicates( |
| item_trait: &syn::ItemTrait, |
| client: bool, |
| ) -> Vec<syn::WherePredicate> { |
| #[derive(Default)] |
| struct FindTyParams { |
| trait_generics: HashSet<syn::Ident>, |
| server_to_client_type_params: HashSet<syn::Ident>, |
| client_to_server_type_params: HashSet<syn::Ident>, |
| visiting_return_type: bool, |
| visiting_fn_arg: bool, |
| visiting_subscriber_arg: bool, |
| } |
| impl<'ast> Visit<'ast> for FindTyParams { |
| fn visit_type_param(&mut self, ty_param: &'ast syn::TypeParam) { |
| self.trait_generics.insert(ty_param.ident.clone()); |
| } |
| fn visit_return_type(&mut self, return_type: &'ast syn::ReturnType) { |
| self.visiting_return_type = true; |
| visit::visit_return_type(self, return_type); |
| self.visiting_return_type = false |
| } |
| fn visit_path_segment(&mut self, segment: &'ast syn::PathSegment) { |
| self.visiting_subscriber_arg = |
| self.visiting_subscriber_arg || (self.visiting_fn_arg && segment.ident == SUBSCRIBER_TYPE_IDENT); |
| visit::visit_path_segment(self, segment); |
| self.visiting_subscriber_arg = self.visiting_subscriber_arg && segment.ident != SUBSCRIBER_TYPE_IDENT; |
| } |
| fn visit_ident(&mut self, ident: &'ast syn::Ident) { |
| if self.trait_generics.contains(&ident) { |
| if self.visiting_return_type || self.visiting_subscriber_arg { |
| self.server_to_client_type_params.insert(ident.clone()); |
| } |
| if self.visiting_fn_arg && !self.visiting_subscriber_arg { |
| self.client_to_server_type_params.insert(ident.clone()); |
| } |
| } |
| } |
| fn visit_fn_arg(&mut self, arg: &'ast syn::FnArg) { |
| self.visiting_fn_arg = true; |
| visit::visit_fn_arg(self, arg); |
| self.visiting_fn_arg = false; |
| } |
| } |
| let mut visitor = FindTyParams::default(); |
| visitor.visit_item_trait(item_trait); |
| |
| let additional_where_clause = item_trait.generics.where_clause.clone(); |
| |
| item_trait |
| .generics |
| .type_params() |
| .map(|ty| { |
| let ty_path = syn::TypePath { |
| qself: None, |
| path: ty.ident.clone().into(), |
| }; |
| let mut bounds: Punctuated<syn::TypeParamBound, Token![+]> = parse_quote!(Send + Sync + 'static); |
| // add json serialization trait bounds |
| if client { |
| if visitor.server_to_client_type_params.contains(&ty.ident) { |
| bounds.push(parse_quote!(_jsonrpc_core::serde::de::DeserializeOwned)) |
| } |
| if visitor.client_to_server_type_params.contains(&ty.ident) { |
| bounds.push(parse_quote!(_jsonrpc_core::serde::Serialize)) |
| } |
| } else { |
| if visitor.server_to_client_type_params.contains(&ty.ident) { |
| bounds.push(parse_quote!(_jsonrpc_core::serde::Serialize)) |
| } |
| if visitor.client_to_server_type_params.contains(&ty.ident) { |
| bounds.push(parse_quote!(_jsonrpc_core::serde::de::DeserializeOwned)) |
| } |
| } |
| |
| // add the trait bounds specified by the user in where clause. |
| if let Some(ref where_clause) = additional_where_clause { |
| for predicate in where_clause.predicates.iter() { |
| if let syn::WherePredicate::Type(where_ty) = predicate { |
| if let syn::Type::Path(ref predicate) = where_ty.bounded_ty { |
| if *predicate == ty_path { |
| bounds.extend(where_ty.bounds.clone().into_iter()); |
| } |
| } |
| } |
| } |
| } |
| |
| syn::WherePredicate::Type(syn::PredicateType { |
| lifetimes: None, |
| bounded_ty: syn::Type::Path(ty_path), |
| colon_token: <Token![:]>::default(), |
| bounds, |
| }) |
| }) |
| .collect() |
| } |