// Source blob: 6f8cb89dc19cf3dc7494684f764da2b2f357a829
use super::{rejection::*, FromRequestParts};
use async_trait::async_trait;
use http::{request::Parts, Uri};
use serde::de::DeserializeOwned;
/// Extractor that deserializes query strings into some type.
///
/// `T` is expected to implement [`serde::Deserialize`].
///
/// A request whose URI has no query component is deserialized as if the query
/// string were empty, so a `T` made up solely of optional fields still
/// extracts successfully.
///
/// # Example
///
/// ```rust,no_run
/// use axum::{
///     extract::Query,
///     routing::get,
///     Router,
/// };
/// use serde::Deserialize;
///
/// #[derive(Deserialize)]
/// struct Pagination {
///     page: usize,
///     per_page: usize,
/// }
///
/// // This will parse query strings like `?page=2&per_page=30` into `Pagination`
/// // structs.
/// async fn list_things(pagination: Query<Pagination>) {
///     let pagination: Pagination = pagination.0;
///
///     // ...
/// }
///
/// let app = Router::new().route("/list_things", get(list_things));
/// # async {
/// # axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// If the query string cannot be parsed it will reject the request with a `400
/// Bad Request` response.
///
/// For handling values being empty vs missing see the [query-params-with-empty-strings][example]
/// example.
///
/// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs
// When building with the `docsrs` cfg, annotate this item as requiring the `query` feature.
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
#[derive(Debug, Clone, Copy, Default)]
pub struct Query<T>(pub T);
// Extractor implementation: `Query<T>` is built solely from the URI in the
// request head. The application state is accepted but never consulted.
#[async_trait]
impl<T, S> FromRequestParts<S> for Query<T>
where
    T: DeserializeOwned,
    S: Send + Sync,
{
    type Rejection = QueryRejection;

    /// Deserializes the query component of `parts.uri` into `T`.
    ///
    /// All of the work is delegated to [`Query::try_from_uri`]; any error it
    /// reports is returned unchanged as the rejection.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let uri = &parts.uri;
        Self::try_from_uri(uri)
    }
}
impl<T> Query<T>
where
    T: DeserializeOwned,
{
    /// Attempts to construct a [`Query`] from a reference to a [`Uri`].
    ///
    /// The query component of `value` is decoded with `serde_urlencoded`. A
    /// URI that has no query component is handled as if its query string were
    /// empty.
    ///
    /// # Errors
    ///
    /// Returns a [`QueryRejection`] (converted from
    /// `FailedToDeserializeQueryString`) when the query string cannot be
    /// deserialized into `T`.
    ///
    /// # Example
    /// ```
    /// use axum::extract::Query;
    /// use http::Uri;
    /// use serde::Deserialize;
    ///
    /// #[derive(Deserialize)]
    /// struct ExampleParams {
    ///     foo: String,
    ///     bar: u32,
    /// }
    ///
    /// let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap();
    /// let result: Query<ExampleParams> = Query::try_from_uri(&uri).unwrap();
    /// assert_eq!(result.foo, String::from("hello"));
    /// assert_eq!(result.bar, 42);
    /// ```
    pub fn try_from_uri(value: &Uri) -> Result<Self, QueryRejection> {
        // No `?...` part at all is treated the same as an empty query string.
        let raw_query = value.query().unwrap_or("");

        match serde_urlencoded::from_str(raw_query) {
            Ok(params) => Ok(Self(params)),
            Err(err) => Err(FailedToDeserializeQueryString::from_err(err).into()),
        }
    }
}
// NOTE(review): presumably generates the `Deref`/`DerefMut` forwarding to the
// inner `T` — the macro body lives in `axum_core` and is not visible here.
// The doc example and tests in this file rely on it to write `query.foo`
// directly on a `Query<_>` value.
axum_core::__impl_deref!(Query);
#[cfg(test)]
mod tests {
    use crate::{routing::get, test_helpers::TestClient, Router};

    use super::*;
    use axum_core::extract::FromRequest;
    use http::{Request, StatusCode};
    use serde::Deserialize;
    use std::fmt::Debug;

    /// Builds a request for `uri` with a `()` body, runs the `Query<T>`
    /// extractor on it and asserts that the extracted value equals `value`.
    ///
    /// Panics (failing the test) if the request cannot be built or the
    /// extractor rejects it.
    async fn check<T>(uri: impl AsRef<str>, value: T)
    where
        T: DeserializeOwned + PartialEq + Debug,
    {
        let req = Request::builder().uri(uri.as_ref()).body(()).unwrap();
        assert_eq!(Query::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    /// Optional fields are `None` when absent (including when the URI has no
    /// query at all) and populated when present.
    #[crate::test]
    async fn test_query() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct Pagination {
            size: Option<u64>,
            page: Option<u64>,
        }

        check(
            "http://example.com/test",
            Pagination {
                size: None,
                page: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10",
            Pagination {
                size: Some(10),
                page: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10&page=20",
            Pagination {
                size: Some(10),
                page: Some(20),
            },
        )
        .await;
    }

    /// A query value that cannot be parsed into the target type yields
    /// `400 Bad Request` when the extractor is used in a handler.
    #[crate::test]
    async fn correct_rejection_status_code() {
        #[derive(Deserialize)]
        #[allow(dead_code)] // the field is only populated by serde, never read
        struct Params {
            n: i32,
        }

        async fn handler(_: Query<Params>) {}

        let app = Router::new().route("/", get(handler));
        let client = TestClient::new(app);

        let res = client.get("/?n=hi").send().await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    }

    /// `try_from_uri` deserializes a well-formed query and exposes the fields
    /// through the wrapper.
    #[test]
    fn test_try_from_uri() {
        #[derive(Deserialize)]
        struct TestQueryParams {
            foo: String,
            bar: u32,
        }

        let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap();
        let result: Query<TestQueryParams> = Query::try_from_uri(&uri).unwrap();
        assert_eq!(result.foo, String::from("hello"));
        assert_eq!(result.bar, 42);
    }

    /// `try_from_uri` reports an error when a value has the wrong type.
    #[test]
    fn test_try_from_uri_with_invalid_query() {
        // The field names must match the query keys (`foo`, `bar`). With
        // mismatched names the deserializer would fail on a missing field and
        // the test would pass without ever looking at the invalid `bar` value.
        #[derive(Deserialize)]
        #[allow(dead_code)] // the fields are only populated by serde, never read
        struct TestQueryParams {
            foo: String,
            bar: u32,
        }

        let uri: Uri = "http://example.com/path?foo=hello&bar=invalid"
            .parse()
            .unwrap();
        let result: Result<Query<TestQueryParams>, _> = Query::try_from_uri(&uri);

        assert!(result.is_err());
    }
}