| use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings}; |
| use crate::{ |
| body::BoxBody, |
| client::GrpcService, |
| codec::{encode_client, Codec, Decoder, Streaming}, |
| request::SanitizeHeaders, |
| Code, Request, Response, Status, |
| }; |
| use http::{ |
| header::{HeaderValue, CONTENT_TYPE, TE}, |
| uri::{PathAndQuery, Uri}, |
| }; |
| use http_body::Body; |
| use std::{fmt, future}; |
| use tokio_stream::{Stream, StreamExt}; |
| |
| /// A gRPC client dispatcher. |
| /// |
| /// This will wrap some inner [`GrpcService`] and will encode/decode |
| /// messages via the provided codec. |
| /// |
| /// Each request method takes a [`Request`], a [`PathAndQuery`], and a |
| /// [`Codec`]. The request contains the message to send via the |
| /// [`Codec::encoder`]. The path determines the fully qualified path |
| /// that will be append to the outgoing uri. The path must follow |
| /// the conventions explained in the [gRPC protocol definition] under `Path →`. An |
| /// example of this path could look like `/greeter.Greeter/SayHello`. |
| /// |
| /// [gRPC protocol definition]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests |
| pub struct Grpc<T> { |
| inner: T, |
| config: GrpcConfig, |
| } |
| |
| struct GrpcConfig { |
| origin: Uri, |
| /// Which compression encodings does the client accept? |
| accept_compression_encodings: EnabledCompressionEncodings, |
| /// The compression encoding that will be applied to requests. |
| send_compression_encodings: Option<CompressionEncoding>, |
| /// Limits the maximum size of a decoded message. |
| max_decoding_message_size: Option<usize>, |
| /// Limits the maximum size of an encoded message. |
| max_encoding_message_size: Option<usize>, |
| } |
| |
| impl<T> Grpc<T> { |
| /// Creates a new gRPC client with the provided [`GrpcService`]. |
| pub fn new(inner: T) -> Self { |
| Self::with_origin(inner, Uri::default()) |
| } |
| |
| /// Creates a new gRPC client with the provided [`GrpcService`] and `Uri`. |
| /// |
| /// The provided Uri will use only the scheme and authority parts as the |
| /// path_and_query portion will be set for each method. |
| pub fn with_origin(inner: T, origin: Uri) -> Self { |
| Self { |
| inner, |
| config: GrpcConfig { |
| origin, |
| send_compression_encodings: None, |
| accept_compression_encodings: EnabledCompressionEncodings::default(), |
| max_decoding_message_size: None, |
| max_encoding_message_size: None, |
| }, |
| } |
| } |
| |
| /// Compress requests with the provided encoding. |
| /// |
| /// Requires the server to accept the specified encoding, otherwise it might return an error. |
| /// |
| /// # Example |
| /// |
| /// The most common way of using this is through a client generated by tonic-build: |
| /// |
| /// ```rust |
| /// use tonic::transport::Channel; |
| /// # enum CompressionEncoding { Gzip } |
| /// # struct TestClient<T>(T); |
| /// # impl<T> TestClient<T> { |
| /// # fn new(channel: T) -> Self { Self(channel) } |
| /// # fn send_compressed(self, _: CompressionEncoding) -> Self { self } |
| /// # } |
| /// |
| /// # async { |
| /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap()) |
| /// .connect() |
| /// .await |
| /// .unwrap(); |
| /// |
| /// let client = TestClient::new(channel).send_compressed(CompressionEncoding::Gzip); |
| /// # }; |
| /// ``` |
| pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { |
| self.config.send_compression_encodings = Some(encoding); |
| self |
| } |
| |
| /// Enable accepting compressed responses. |
| /// |
| /// Requires the server to also support sending compressed responses. |
| /// |
| /// # Example |
| /// |
| /// The most common way of using this is through a client generated by tonic-build: |
| /// |
| /// ```rust |
| /// use tonic::transport::Channel; |
| /// # enum CompressionEncoding { Gzip } |
| /// # struct TestClient<T>(T); |
| /// # impl<T> TestClient<T> { |
| /// # fn new(channel: T) -> Self { Self(channel) } |
| /// # fn accept_compressed(self, _: CompressionEncoding) -> Self { self } |
| /// # } |
| /// |
| /// # async { |
| /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap()) |
| /// .connect() |
| /// .await |
| /// .unwrap(); |
| /// |
| /// let client = TestClient::new(channel).accept_compressed(CompressionEncoding::Gzip); |
| /// # }; |
| /// ``` |
| pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { |
| self.config.accept_compression_encodings.enable(encoding); |
| self |
| } |
| |
| /// Limits the maximum size of a decoded message. |
| /// |
| /// # Example |
| /// |
| /// The most common way of using this is through a client generated by tonic-build: |
| /// |
| /// ```rust |
| /// use tonic::transport::Channel; |
| /// # struct TestClient<T>(T); |
| /// # impl<T> TestClient<T> { |
| /// # fn new(channel: T) -> Self { Self(channel) } |
| /// # fn max_decoding_message_size(self, _: usize) -> Self { self } |
| /// # } |
| /// |
| /// # async { |
| /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap()) |
| /// .connect() |
| /// .await |
| /// .unwrap(); |
| /// |
| /// // Set the limit to 2MB, Defaults to 4MB. |
| /// let limit = 2 * 1024 * 1024; |
| /// let client = TestClient::new(channel).max_decoding_message_size(limit); |
| /// # }; |
| /// ``` |
| pub fn max_decoding_message_size(mut self, limit: usize) -> Self { |
| self.config.max_decoding_message_size = Some(limit); |
| self |
| } |
| |
| /// Limits the maximum size of an ecoded message. |
| /// |
| /// # Example |
| /// |
| /// The most common way of using this is through a client generated by tonic-build: |
| /// |
| /// ```rust |
| /// use tonic::transport::Channel; |
| /// # struct TestClient<T>(T); |
| /// # impl<T> TestClient<T> { |
| /// # fn new(channel: T) -> Self { Self(channel) } |
| /// # fn max_encoding_message_size(self, _: usize) -> Self { self } |
| /// # } |
| /// |
| /// # async { |
| /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap()) |
| /// .connect() |
| /// .await |
| /// .unwrap(); |
| /// |
| /// // Set the limit to 2MB, Defaults to 4MB. |
| /// let limit = 2 * 1024 * 1024; |
| /// let client = TestClient::new(channel).max_encoding_message_size(limit); |
| /// # }; |
| /// ``` |
| pub fn max_encoding_message_size(mut self, limit: usize) -> Self { |
| self.config.max_encoding_message_size = Some(limit); |
| self |
| } |
| |
| /// Check if the inner [`GrpcService`] is able to accept a new request. |
| /// |
| /// This will call [`GrpcService::poll_ready`] until it returns ready or |
| /// an error. If this returns ready the inner [`GrpcService`] is ready to |
| /// accept one more request. |
| pub async fn ready(&mut self) -> Result<(), T::Error> |
| where |
| T: GrpcService<BoxBody>, |
| { |
| future::poll_fn(|cx| self.inner.poll_ready(cx)).await |
| } |
| |
| /// Send a single unary gRPC request. |
| pub async fn unary<M1, M2, C>( |
| &mut self, |
| request: Request<M1>, |
| path: PathAndQuery, |
| codec: C, |
| ) -> Result<Response<M2>, Status> |
| where |
| T: GrpcService<BoxBody>, |
| T::ResponseBody: Body + Send + 'static, |
| <T::ResponseBody as Body>::Error: Into<crate::Error>, |
| C: Codec<Encode = M1, Decode = M2>, |
| M1: Send + Sync + 'static, |
| M2: Send + Sync + 'static, |
| { |
| let request = request.map(|m| tokio_stream::once(m)); |
| self.client_streaming(request, path, codec).await |
| } |
| |
| /// Send a client side streaming gRPC request. |
| pub async fn client_streaming<S, M1, M2, C>( |
| &mut self, |
| request: Request<S>, |
| path: PathAndQuery, |
| codec: C, |
| ) -> Result<Response<M2>, Status> |
| where |
| T: GrpcService<BoxBody>, |
| T::ResponseBody: Body + Send + 'static, |
| <T::ResponseBody as Body>::Error: Into<crate::Error>, |
| S: Stream<Item = M1> + Send + 'static, |
| C: Codec<Encode = M1, Decode = M2>, |
| M1: Send + Sync + 'static, |
| M2: Send + Sync + 'static, |
| { |
| let (mut parts, body, extensions) = |
| self.streaming(request, path, codec).await?.into_parts(); |
| |
| tokio::pin!(body); |
| |
| let message = body |
| .try_next() |
| .await |
| .map_err(|mut status| { |
| status.metadata_mut().merge(parts.clone()); |
| status |
| })? |
| .ok_or_else(|| Status::new(Code::Internal, "Missing response message."))?; |
| |
| if let Some(trailers) = body.trailers().await? { |
| parts.merge(trailers); |
| } |
| |
| Ok(Response::from_parts(parts, message, extensions)) |
| } |
| |
| /// Send a server side streaming gRPC request. |
| pub async fn server_streaming<M1, M2, C>( |
| &mut self, |
| request: Request<M1>, |
| path: PathAndQuery, |
| codec: C, |
| ) -> Result<Response<Streaming<M2>>, Status> |
| where |
| T: GrpcService<BoxBody>, |
| T::ResponseBody: Body + Send + 'static, |
| <T::ResponseBody as Body>::Error: Into<crate::Error>, |
| C: Codec<Encode = M1, Decode = M2>, |
| M1: Send + Sync + 'static, |
| M2: Send + Sync + 'static, |
| { |
| let request = request.map(|m| tokio_stream::once(m)); |
| self.streaming(request, path, codec).await |
| } |
| |
| /// Send a bi-directional streaming gRPC request. |
| pub async fn streaming<S, M1, M2, C>( |
| &mut self, |
| request: Request<S>, |
| path: PathAndQuery, |
| mut codec: C, |
| ) -> Result<Response<Streaming<M2>>, Status> |
| where |
| T: GrpcService<BoxBody>, |
| T::ResponseBody: Body + Send + 'static, |
| <T::ResponseBody as Body>::Error: Into<crate::Error>, |
| S: Stream<Item = M1> + Send + 'static, |
| C: Codec<Encode = M1, Decode = M2>, |
| M1: Send + Sync + 'static, |
| M2: Send + Sync + 'static, |
| { |
| let request = request |
| .map(|s| { |
| encode_client( |
| codec.encoder(), |
| s, |
| self.config.send_compression_encodings, |
| self.config.max_encoding_message_size, |
| ) |
| }) |
| .map(BoxBody::new); |
| |
| let request = self.config.prepare_request(request, path); |
| |
| let response = self |
| .inner |
| .call(request) |
| .await |
| .map_err(Status::from_error_generic)?; |
| |
| let decoder = codec.decoder(); |
| |
| self.create_response(decoder, response) |
| } |
| |
| // Keeping this code in a separate function from Self::streaming lets functions that return the |
| // same output share the generated binary code |
| fn create_response<M2>( |
| &self, |
| decoder: impl Decoder<Item = M2, Error = Status> + Send + 'static, |
| response: http::Response<T::ResponseBody>, |
| ) -> Result<Response<Streaming<M2>>, Status> |
| where |
| T: GrpcService<BoxBody>, |
| T::ResponseBody: Body + Send + 'static, |
| <T::ResponseBody as Body>::Error: Into<crate::Error>, |
| { |
| let encoding = CompressionEncoding::from_encoding_header( |
| response.headers(), |
| self.config.accept_compression_encodings, |
| )?; |
| |
| let status_code = response.status(); |
| let trailers_only_status = Status::from_header_map(response.headers()); |
| |
| // We do not need to check for trailers if the `grpc-status` header is present |
| // with a valid code. |
| let expect_additional_trailers = if let Some(status) = trailers_only_status { |
| if status.code() != Code::Ok { |
| return Err(status); |
| } |
| |
| false |
| } else { |
| true |
| }; |
| |
| let response = response.map(|body| { |
| if expect_additional_trailers { |
| Streaming::new_response( |
| decoder, |
| body, |
| status_code, |
| encoding, |
| self.config.max_decoding_message_size, |
| ) |
| } else { |
| Streaming::new_empty(decoder, body) |
| } |
| }); |
| |
| Ok(Response::from_http(response)) |
| } |
| } |
| |
| impl GrpcConfig { |
| fn prepare_request( |
| &self, |
| request: Request<BoxBody>, |
| path: PathAndQuery, |
| ) -> http::Request<BoxBody> { |
| let mut parts = self.origin.clone().into_parts(); |
| |
| match &parts.path_and_query { |
| Some(pnq) if pnq != "/" => { |
| parts.path_and_query = Some( |
| format!("{}{}", pnq.path(), path) |
| .parse() |
| .expect("must form valid path_and_query"), |
| ) |
| } |
| _ => { |
| parts.path_and_query = Some(path); |
| } |
| } |
| |
| let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri"); |
| |
| let mut request = request.into_http( |
| uri, |
| http::Method::POST, |
| http::Version::HTTP_2, |
| SanitizeHeaders::Yes, |
| ); |
| |
| // Add the gRPC related HTTP headers |
| request |
| .headers_mut() |
| .insert(TE, HeaderValue::from_static("trailers")); |
| |
| // Set the content type |
| request |
| .headers_mut() |
| .insert(CONTENT_TYPE, HeaderValue::from_static("application/grpc")); |
| |
| #[cfg(any(feature = "gzip", feature = "zstd"))] |
| if let Some(encoding) = self.send_compression_encodings { |
| request.headers_mut().insert( |
| crate::codec::compression::ENCODING_HEADER, |
| encoding.into_header_value(), |
| ); |
| } |
| |
| if let Some(header_value) = self |
| .accept_compression_encodings |
| .into_accept_encoding_header_value() |
| { |
| request.headers_mut().insert( |
| crate::codec::compression::ACCEPT_ENCODING_HEADER, |
| header_value, |
| ); |
| } |
| |
| request |
| } |
| } |
| |
| impl<T: Clone> Clone for Grpc<T> { |
| fn clone(&self) -> Self { |
| Self { |
| inner: self.inner.clone(), |
| config: GrpcConfig { |
| origin: self.config.origin.clone(), |
| send_compression_encodings: self.config.send_compression_encodings, |
| accept_compression_encodings: self.config.accept_compression_encodings, |
| max_encoding_message_size: self.config.max_encoding_message_size, |
| max_decoding_message_size: self.config.max_decoding_message_size, |
| }, |
| } |
| } |
| } |
| |
| impl<T: fmt::Debug> fmt::Debug for Grpc<T> { |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| let mut f = f.debug_struct("Grpc"); |
| |
| f.field("inner", &self.inner); |
| |
| f.field("origin", &self.config.origin); |
| |
| f.field( |
| "compression_encoding", |
| &self.config.send_compression_encodings, |
| ); |
| |
| f.field( |
| "accept_compression_encodings", |
| &self.config.accept_compression_encodings, |
| ); |
| |
| f.field( |
| "max_decoding_message_size", |
| &self.config.max_decoding_message_size, |
| ); |
| |
| f.field( |
| "max_encoding_message_size", |
| &self.config.max_encoding_message_size, |
| ); |
| |
| f.finish() |
| } |
| } |