blob: e2307d391c05d97527f64c06835f18e8edea6bc9 [file] [log] [blame] [edit]
use async_trait::async_trait;
use axum_core::extract::{FromRef, FromRequestParts};
use http::request::Parts;
use std::{
convert::Infallible,
ops::{Deref, DerefMut},
};
/// Extractor for state.
///
/// See ["Accessing state in middleware"][state-from-middleware] for how to
/// access state in middleware.
///
/// [state-from-middleware]: crate::middleware#accessing-state-in-middleware
///
/// # With `Router`
///
/// ```
/// use axum::{Router, routing::get, extract::State};
///
/// // the application state
/// //
/// // here you can put configuration, database connection pools, or whatever
/// // state you need
/// //
/// // see "When states need to implement `Clone`" for more details on why we need
/// // `#[derive(Clone)]` here.
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// // create a `Router` that holds our state
/// let app = Router::new()
/// .route("/", get(handler))
/// // provide the state so the router can access it
/// .with_state(state);
///
/// async fn handler(
/// // access the state via the `State` extractor
/// // extracting a state of the wrong type results in a compile error
/// State(state): State<AppState>,
/// ) {
/// // use `state`...
/// }
/// # let _: axum::Router = app;
/// ```
///
/// Note that `State` is an extractor, so be sure to put it before any body
/// extractors, see ["the order of extractors"][order-of-extractors].
///
/// [order-of-extractors]: crate::extract#the-order-of-extractors
///
/// ## Combining stateful routers
///
/// Multiple [`Router`]s can be combined with [`Router::nest`] or [`Router::merge`].
/// When combining [`Router`]s with one of these methods, the [`Router`]s must have
/// the same state type. Generally, this can be inferred automatically:
///
/// ```
/// use axum::{Router, routing::get, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// // create a `Router` that will be nested within another
/// let api = Router::new()
/// .route("/posts", get(posts_handler));
///
/// let app = Router::new()
/// .nest("/api", api)
/// .with_state(state);
///
/// async fn posts_handler(State(state): State<AppState>) {
/// // use `state`...
/// }
/// # let _: axum::Router = app;
/// ```
///
/// However, if you are composing [`Router`]s that are defined in separate scopes,
/// you may need to annotate the [`Router`]'s state type explicitly (e.g. `Router<AppState>`):
///
/// ```
/// use axum::{Router, routing::get, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// fn make_app() -> Router {
/// let state = AppState {};
///
/// Router::new()
/// .nest("/api", make_api())
/// .with_state(state) // the outer Router's state is inferred
/// }
///
/// // the inner Router must specify its state type to compose with the
/// // outer router
/// fn make_api() -> Router<AppState> {
/// Router::new()
/// .route("/posts", get(posts_handler))
/// }
///
/// async fn posts_handler(State(state): State<AppState>) {
/// // use `state`...
/// }
/// # let _: axum::Router = make_app();
/// ```
///
/// In short, a [`Router`]'s generic state type defaults to `()`
/// (no state) unless [`Router::with_state`] is called or the value
/// of the generic type is given explicitly.
///
/// [`Router`]: crate::Router
/// [`Router::merge`]: crate::Router::merge
/// [`Router::nest`]: crate::Router::nest
/// [`Router::with_state`]: crate::Router::with_state
///
/// # With `MethodRouter`
///
/// ```
/// use axum::{routing::get, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// let method_router_with_state = get(handler)
/// // provide the state so the handler can access it
/// .with_state(state);
///
/// async fn handler(State(state): State<AppState>) {
/// // use `state`...
/// }
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(method_router_with_state.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// # With `Handler`
///
/// ```
/// use axum::{routing::get, handler::Handler, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// async fn handler(State(state): State<AppState>) {
/// // use `state`...
/// }
///
/// // provide the state so the handler can access it
/// let handler_with_state = handler.with_state(state);
///
/// # async {
/// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
/// .serve(handler_with_state.into_make_service())
/// .await
/// .expect("server failed");
/// # };
/// ```
///
/// # Substates
///
/// [`State`] only allows a single state type, but you can use [`FromRef`] to extract "substates":
///
/// ```
/// use axum::{Router, routing::get, extract::{State, FromRef}};
///
/// // the application state
/// #[derive(Clone)]
/// struct AppState {
/// // that holds some api specific state
/// api_state: ApiState,
/// }
///
/// // the api specific state
/// #[derive(Clone)]
/// struct ApiState {}
///
/// // support converting an `AppState` into an `ApiState`
/// impl FromRef<AppState> for ApiState {
/// fn from_ref(app_state: &AppState) -> ApiState {
/// app_state.api_state.clone()
/// }
/// }
///
/// let state = AppState {
/// api_state: ApiState {},
/// };
///
/// let app = Router::new()
/// .route("/", get(handler))
/// .route("/api/users", get(api_users))
/// .with_state(state);
///
/// async fn api_users(
/// // access the api specific state
/// State(api_state): State<ApiState>,
/// ) {
/// }
///
/// async fn handler(
/// // we can still access the top level state
/// State(state): State<AppState>,
/// ) {
/// }
/// # let _: axum::Router = app;
/// ```
///
/// For convenience `FromRef` can also be derived using `#[derive(FromRef)]`.
///
/// # For library authors
///
/// If you're writing a library that has an extractor that needs state, this is the recommended way
/// to do it:
///
/// ```rust
/// use axum_core::extract::{FromRequestParts, FromRef};
/// use http::request::Parts;
/// use async_trait::async_trait;
/// use std::convert::Infallible;
///
/// // the extractor your library provides
/// struct MyLibraryExtractor;
///
/// #[async_trait]
/// impl<S> FromRequestParts<S> for MyLibraryExtractor
/// where
/// // keep `S` generic but require that it can produce a `MyLibraryState`
/// // this means users will have to implement `FromRef<UserState> for MyLibraryState`
/// MyLibraryState: FromRef<S>,
/// S: Send + Sync,
/// {
/// type Rejection = Infallible;
///
/// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
/// // get a `MyLibraryState` from a reference to the state
/// let state = MyLibraryState::from_ref(state);
///
/// // ...
/// # todo!()
/// }
/// }
///
/// // the state your library needs
/// struct MyLibraryState {
/// // ...
/// }
/// ```
///
/// # When states need to implement `Clone`
///
/// Your top level state type must implement `Clone` to be extractable with `State`:
///
/// ```
/// use axum::extract::State;
///
/// // no substates, so to extract `State<AppState>` we must implement `Clone` for `AppState`
/// #[derive(Clone)]
/// struct AppState {}
///
/// async fn handler(State(state): State<AppState>) {
/// // ...
/// }
/// ```
///
/// This works because of [`impl<S> FromRef<S> for S where S: Clone`][`FromRef`].
///
/// This is also true if you're extracting substates, unless you _never_ extract the top level
/// state itself:
///
/// ```
/// use axum::extract::{State, FromRef};
///
/// // we never extract `State<AppState>`, just `State<InnerState>`. So `AppState` doesn't need to
/// // implement `Clone`
/// struct AppState {
/// inner: InnerState,
/// }
///
/// #[derive(Clone)]
/// struct InnerState {}
///
/// impl FromRef<AppState> for InnerState {
/// fn from_ref(app_state: &AppState) -> InnerState {
/// app_state.inner.clone()
/// }
/// }
///
/// async fn api_users(State(inner): State<InnerState>) {
/// // ...
/// }
/// ```
///
/// In general, however, we recommend you implement `Clone` for all your state types to avoid
/// potential type errors.
///
/// # Shared mutable state
///
/// [As state is global within a `Router`][global], you can't directly get a mutable reference
/// to the state.
///
/// The most basic solution is to use an `Arc<Mutex<_>>`. Which kind of mutex you need depends on
/// your use case. See [the tokio docs] for more details.
///
/// Note that holding a locked `std::sync::Mutex` across `.await` points will result in `!Send`
/// futures which are incompatible with axum. If you need to hold a mutex across `.await` points,
/// consider using a `tokio::sync::Mutex` instead.
///
/// ## Example
///
/// ```
/// use axum::{Router, routing::get, extract::State};
/// use std::sync::{Arc, Mutex};
///
/// #[derive(Clone)]
/// struct AppState {
/// data: Arc<Mutex<String>>,
/// }
///
/// async fn handler(State(state): State<AppState>) {
/// let mut data = state.data.lock().expect("mutex was poisoned");
/// *data = "updated foo".to_owned();
///
/// // ...
/// }
///
/// let state = AppState {
/// data: Arc::new(Mutex::new("foo".to_owned())),
/// };
///
/// let app = Router::new()
/// .route("/", get(handler))
/// .with_state(state);
/// # let _: Router = app;
/// ```
///
/// [global]: crate::Router::with_state
/// [the tokio docs]: https://docs.rs/tokio/1.25.0/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
// `State` is a transparent newtype: every derived trait simply forwards to the
// corresponding impl on the wrapped `S`.
#[derive(Clone, Copy, Debug, Default)]
pub struct State<S>(pub S);
#[async_trait]
impl<OuterState, InnerState> FromRequestParts<OuterState> for State<InnerState>
where
InnerState: FromRef<OuterState>,
OuterState: Send + Sync,
{
type Rejection = Infallible;
async fn from_request_parts(
_parts: &mut Parts,
state: &OuterState,
) -> Result<Self, Self::Rejection> {
let inner_state = InnerState::from_ref(state);
Ok(Self(inner_state))
}
}
impl<S> Deref for State<S> {
type Target = S;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<S> DerefMut for State<S> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}