| //! Extractor for getting connection information from a client. |
| //! |
| //! See [`Router::into_make_service_with_connect_info`] for more details. |
| //! |
| //! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info |
| |
| use super::{Extension, FromRequestParts}; |
| use crate::middleware::AddExtension; |
| use async_trait::async_trait; |
| use http::request::Parts; |
| use hyper::server::conn::AddrStream; |
| use std::{ |
| convert::Infallible, |
| fmt, |
| future::ready, |
| marker::PhantomData, |
| net::SocketAddr, |
| task::{Context, Poll}, |
| }; |
| use tower_layer::Layer; |
| use tower_service::Service; |
| |
| /// A [`MakeService`] created from a router. |
| /// |
| /// See [`Router::into_make_service_with_connect_info`] for more details. |
| /// |
| /// [`MakeService`]: tower::make::MakeService |
| /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info |
| pub struct IntoMakeServiceWithConnectInfo<S, C> { |
| svc: S, |
| _connect_info: PhantomData<fn() -> C>, |
| } |
| |
| #[test] |
| fn traits() { |
| use crate::test_helpers::*; |
| assert_send::<IntoMakeServiceWithConnectInfo<(), NotSendSync>>(); |
| } |
| |
| impl<S, C> IntoMakeServiceWithConnectInfo<S, C> { |
| pub(crate) fn new(svc: S) -> Self { |
| Self { |
| svc, |
| _connect_info: PhantomData, |
| } |
| } |
| } |
| |
| impl<S, C> fmt::Debug for IntoMakeServiceWithConnectInfo<S, C> |
| where |
| S: fmt::Debug, |
| { |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| f.debug_struct("IntoMakeServiceWithConnectInfo") |
| .field("svc", &self.svc) |
| .finish() |
| } |
| } |
| |
| impl<S, C> Clone for IntoMakeServiceWithConnectInfo<S, C> |
| where |
| S: Clone, |
| { |
| fn clone(&self) -> Self { |
| Self { |
| svc: self.svc.clone(), |
| _connect_info: PhantomData, |
| } |
| } |
| } |
| |
| /// Trait that connected IO resources implement and use to produce information |
| /// about the connection. |
| /// |
| /// The goal for this trait is to allow users to implement custom IO types that |
| /// can still provide the same connection metadata. |
| /// |
| /// See [`Router::into_make_service_with_connect_info`] for more details. |
| /// |
| /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info |
| pub trait Connected<T>: Clone + Send + Sync + 'static { |
| /// Create type holding information about the connection. |
| fn connect_info(target: T) -> Self; |
| } |
| |
| impl Connected<&AddrStream> for SocketAddr { |
| fn connect_info(target: &AddrStream) -> Self { |
| target.remote_addr() |
| } |
| } |
| |
| impl<S, C, T> Service<T> for IntoMakeServiceWithConnectInfo<S, C> |
| where |
| S: Clone, |
| C: Connected<T>, |
| { |
| type Response = AddExtension<S, ConnectInfo<C>>; |
| type Error = Infallible; |
| type Future = ResponseFuture<S, C>; |
| |
| #[inline] |
| fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| Poll::Ready(Ok(())) |
| } |
| |
| fn call(&mut self, target: T) -> Self::Future { |
| let connect_info = ConnectInfo(C::connect_info(target)); |
| let svc = Extension(connect_info).layer(self.svc.clone()); |
| ResponseFuture::new(ready(Ok(svc))) |
| } |
| } |
| |
| opaque_future! { |
| /// Response future for [`IntoMakeServiceWithConnectInfo`]. |
| pub type ResponseFuture<S, C> = |
| std::future::Ready<Result<AddExtension<S, ConnectInfo<C>>, Infallible>>; |
| } |
| |
| /// Extractor for getting connection information produced by a [`Connected`]. |
| /// |
| /// Note this extractor requires you to use |
| /// [`Router::into_make_service_with_connect_info`] to run your app |
| /// otherwise it will fail at runtime. |
| /// |
| /// See [`Router::into_make_service_with_connect_info`] for more details. |
| /// |
| /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info |
| #[derive(Clone, Copy, Debug)] |
| pub struct ConnectInfo<T>(pub T); |
| |
| #[async_trait] |
| impl<S, T> FromRequestParts<S> for ConnectInfo<T> |
| where |
| S: Send + Sync, |
| T: Clone + Send + Sync + 'static, |
| { |
| type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection; |
| |
| async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { |
| match Extension::<Self>::from_request_parts(parts, state).await { |
| Ok(Extension(connect_info)) => Ok(connect_info), |
| Err(err) => match parts.extensions.get::<MockConnectInfo<T>>() { |
| Some(MockConnectInfo(connect_info)) => Ok(Self(connect_info.clone())), |
| None => Err(err), |
| }, |
| } |
| } |
| } |
| |
| axum_core::__impl_deref!(ConnectInfo); |
| |
| /// Middleware used to mock [`ConnectInfo`] during tests. |
| /// |
| /// If you're accidentally using [`MockConnectInfo`] and |
| /// [`Router::into_make_service_with_connect_info`] at the same time then |
| /// [`Router::into_make_service_with_connect_info`] takes precedence. |
| /// |
| /// # Example |
| /// |
| /// ``` |
| /// use axum::{ |
| /// Router, |
| /// extract::connect_info::{MockConnectInfo, ConnectInfo}, |
| /// body::Body, |
| /// routing::get, |
| /// http::{Request, StatusCode}, |
| /// }; |
| /// use std::net::SocketAddr; |
| /// use tower::ServiceExt; |
| /// |
| /// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) {} |
| /// |
| /// // this router you can run with `app.into_make_service_with_connect_info::<SocketAddr>()` |
| /// fn app() -> Router { |
| /// Router::new().route("/", get(handler)) |
| /// } |
| /// |
| /// // use this router for tests |
| /// fn test_app() -> Router { |
| /// app().layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337)))) |
| /// } |
| /// |
| /// // #[tokio::test] |
| /// async fn some_test() { |
| /// let app = test_app(); |
| /// |
| /// let request = Request::new(Body::empty()); |
| /// let response = app.oneshot(request).await.unwrap(); |
| /// assert_eq!(response.status(), StatusCode::OK); |
| /// } |
| /// # |
| /// # #[tokio::main] |
| /// # async fn main() { |
| /// # some_test().await; |
| /// # } |
| /// ``` |
| /// |
| /// [`Router::into_make_service_with_connect_info`]: crate::Router::into_make_service_with_connect_info |
| #[derive(Clone, Copy, Debug)] |
| pub struct MockConnectInfo<T>(pub T); |
| |
| impl<S, T> Layer<S> for MockConnectInfo<T> |
| where |
| T: Clone + Send + Sync + 'static, |
| { |
| type Service = <Extension<Self> as Layer<S>>::Service; |
| |
| fn layer(&self, inner: S) -> Self::Service { |
| Extension(self.clone()).layer(inner) |
| } |
| } |
| |
| #[cfg(test)] |
| mod tests { |
| use super::*; |
| use crate::{routing::get, test_helpers::TestClient, Router, Server}; |
| use std::net::{SocketAddr, TcpListener}; |
| |
| #[crate::test] |
| async fn socket_addr() { |
| async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String { |
| format!("{addr}") |
| } |
| |
| let listener = TcpListener::bind("127.0.0.1:0").unwrap(); |
| let addr = listener.local_addr().unwrap(); |
| |
| let (tx, rx) = tokio::sync::oneshot::channel(); |
| tokio::spawn(async move { |
| let app = Router::new().route("/", get(handler)); |
| let server = Server::from_tcp(listener) |
| .unwrap() |
| .serve(app.into_make_service_with_connect_info::<SocketAddr>()); |
| tx.send(()).unwrap(); |
| server.await.expect("server error"); |
| }); |
| rx.await.unwrap(); |
| |
| let client = reqwest::Client::new(); |
| |
| let res = client.get(format!("http://{addr}")).send().await.unwrap(); |
| let body = res.text().await.unwrap(); |
| assert!(body.starts_with("127.0.0.1:")); |
| } |
| |
| #[crate::test] |
| async fn custom() { |
| #[derive(Clone, Debug)] |
| struct MyConnectInfo { |
| value: &'static str, |
| } |
| |
| impl Connected<&AddrStream> for MyConnectInfo { |
| fn connect_info(_target: &AddrStream) -> Self { |
| Self { |
| value: "it worked!", |
| } |
| } |
| } |
| |
| async fn handler(ConnectInfo(addr): ConnectInfo<MyConnectInfo>) -> &'static str { |
| addr.value |
| } |
| |
| let listener = TcpListener::bind("127.0.0.1:0").unwrap(); |
| let addr = listener.local_addr().unwrap(); |
| |
| let (tx, rx) = tokio::sync::oneshot::channel(); |
| tokio::spawn(async move { |
| let app = Router::new().route("/", get(handler)); |
| let server = Server::from_tcp(listener) |
| .unwrap() |
| .serve(app.into_make_service_with_connect_info::<MyConnectInfo>()); |
| tx.send(()).unwrap(); |
| server.await.expect("server error"); |
| }); |
| rx.await.unwrap(); |
| |
| let client = reqwest::Client::new(); |
| |
| let res = client.get(format!("http://{addr}")).send().await.unwrap(); |
| let body = res.text().await.unwrap(); |
| assert_eq!(body, "it worked!"); |
| } |
| |
| #[crate::test] |
| async fn mock_connect_info() { |
| async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String { |
| format!("{addr}") |
| } |
| |
| let app = Router::new() |
| .route("/", get(handler)) |
| .layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337)))); |
| |
| let client = TestClient::new(app); |
| |
| let res = client.get("/").send().await; |
| let body = res.text().await; |
| assert!(body.starts_with("0.0.0.0:1337")); |
| } |
| |
| #[crate::test] |
| async fn both_mock_and_real_connect_info() { |
| async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String { |
| format!("{addr}") |
| } |
| |
| let listener = TcpListener::bind("127.0.0.1:0").unwrap(); |
| let addr = listener.local_addr().unwrap(); |
| |
| tokio::spawn(async move { |
| let app = Router::new() |
| .route("/", get(handler)) |
| .layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337)))); |
| |
| let server = Server::from_tcp(listener) |
| .unwrap() |
| .serve(app.into_make_service_with_connect_info::<SocketAddr>()); |
| server.await.expect("server error"); |
| }); |
| |
| let client = reqwest::Client::new(); |
| |
| let res = client.get(format!("http://{addr}")).send().await.unwrap(); |
| let body = res.text().await.unwrap(); |
| assert!(body.starts_with("127.0.0.1:")); |
| } |
| } |