| //! HTTP Upgrades |
| //! |
| //! This module deals with managing [HTTP Upgrades][mdn] in hyper. Since |
| //! several concepts in HTTP allow for first talking HTTP, and then converting |
| //! to a different protocol, this module conflates them into a single API. |
| //! Those include: |
| //! |
| //! - HTTP/1.1 Upgrades |
| //! - HTTP `CONNECT` |
| //! |
| //! You are responsible for any other pre-requisites to establish an upgrade, |
| //! such as sending the appropriate headers, methods, and status codes. You can |
| //! then use [`on`][] to grab a `Future` which will resolve to the upgraded |
| //! connection object, or an error if the upgrade fails. |
| //! |
| //! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism |
| //! |
| //! # Client |
| //! |
| //! Sending an HTTP upgrade from the [`client`](super::client) involves setting |
| //! either the appropriate method, if wanting to `CONNECT`, or headers such as |
| //! `Upgrade` and `Connection`, on the `http::Request`. Once receiving the |
| //! `http::Response` back, you must check for the specific information that the |
| //! upgrade is agreed upon by the server (such as a `101` status code), and then |
| //! get the `Future` from the `Response`. |
| //! |
| //! # Server |
| //! |
| //! Receiving upgrade requests in a server requires you to check the relevant |
| //! headers in a `Request`, and if an upgrade should be done, you then send the |
| //! corresponding headers in a response. To then wait for hyper to finish the |
| //! upgrade, you call `on()` with the `Request`, and then can spawn a task |
| //! awaiting it. |
| //! |
| //! # Example |
| //! |
| //! See [this example][example] showing how upgrades work with both |
| //! Clients and Servers. |
| //! |
| //! [example]: https://github.com/hyperium/hyper/blob/master/examples/upgrades.rs |
| |
| use std::any::TypeId; |
| use std::error::Error as StdError; |
| use std::fmt; |
| use std::io; |
| use std::marker::Unpin; |
| |
| use bytes::Bytes; |
| use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
| use tokio::sync::oneshot; |
| #[cfg(any(feature = "http1", feature = "http2"))] |
| use tracing::trace; |
| |
| use crate::common::io::Rewind; |
| use crate::common::{task, Future, Pin, Poll}; |
| |
| /// An upgraded HTTP connection. |
| /// |
| /// This type holds a trait object internally of the original IO that |
| /// was used to speak HTTP before the upgrade. It can be used directly |
| /// as a `Read` or `Write` for convenience. |
| /// |
| /// Alternatively, if the exact type is known, this can be deconstructed |
| /// into its parts. |
| pub struct Upgraded { |
| io: Rewind<Box<dyn Io + Send>>, |
| } |
| |
| /// A future for a possible HTTP upgrade. |
| /// |
| /// If no upgrade was available, or it doesn't succeed, yields an `Error`. |
| pub struct OnUpgrade { |
| rx: Option<oneshot::Receiver<crate::Result<Upgraded>>>, |
| } |
| |
| /// The deconstructed parts of an [`Upgraded`](Upgraded) type. |
| /// |
| /// Includes the original IO type, and a read buffer of bytes that the |
| /// HTTP state machine may have already read before completing an upgrade. |
| #[derive(Debug)] |
| pub struct Parts<T> { |
| /// The original IO object used before the upgrade. |
| pub io: T, |
| /// A buffer of bytes that have been read but not processed as HTTP. |
| /// |
| /// For instance, if the `Connection` is used for an HTTP upgrade request, |
| /// it is possible the server sent back the first bytes of the new protocol |
| /// along with the response upgrade. |
| /// |
| /// You will want to check for any existing bytes if you plan to continue |
| /// communicating on the IO object. |
| pub read_buf: Bytes, |
| _inner: (), |
| } |
| |
| /// Gets a pending HTTP upgrade from this message. |
| /// |
| /// This can be called on the following types: |
| /// |
| /// - `http::Request<B>` |
| /// - `http::Response<B>` |
| /// - `&mut http::Request<B>` |
| /// - `&mut http::Response<B>` |
| pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade { |
| msg.on_upgrade() |
| } |
| |
| #[cfg(any(feature = "http1", feature = "http2"))] |
| pub(super) struct Pending { |
| tx: oneshot::Sender<crate::Result<Upgraded>>, |
| } |
| |
| #[cfg(any(feature = "http1", feature = "http2"))] |
| pub(super) fn pending() -> (Pending, OnUpgrade) { |
| let (tx, rx) = oneshot::channel(); |
| (Pending { tx }, OnUpgrade { rx: Some(rx) }) |
| } |
| |
| // ===== impl Upgraded ===== |
| |
| impl Upgraded { |
| #[cfg(any(feature = "http1", feature = "http2", test))] |
| pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self |
| where |
| T: AsyncRead + AsyncWrite + Unpin + Send + 'static, |
| { |
| Upgraded { |
| io: Rewind::new_buffered(Box::new(io), read_buf), |
| } |
| } |
| |
| /// Tries to downcast the internal trait object to the type passed. |
| /// |
| /// On success, returns the downcasted parts. On error, returns the |
| /// `Upgraded` back. |
| pub fn downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self> { |
| let (io, buf) = self.io.into_inner(); |
| match io.__hyper_downcast() { |
| Ok(t) => Ok(Parts { |
| io: *t, |
| read_buf: buf, |
| _inner: (), |
| }), |
| Err(io) => Err(Upgraded { |
| io: Rewind::new_buffered(io, buf), |
| }), |
| } |
| } |
| } |
| |
| impl AsyncRead for Upgraded { |
| fn poll_read( |
| mut self: Pin<&mut Self>, |
| cx: &mut task::Context<'_>, |
| buf: &mut ReadBuf<'_>, |
| ) -> Poll<io::Result<()>> { |
| Pin::new(&mut self.io).poll_read(cx, buf) |
| } |
| } |
| |
| impl AsyncWrite for Upgraded { |
| fn poll_write( |
| mut self: Pin<&mut Self>, |
| cx: &mut task::Context<'_>, |
| buf: &[u8], |
| ) -> Poll<io::Result<usize>> { |
| Pin::new(&mut self.io).poll_write(cx, buf) |
| } |
| |
| fn poll_write_vectored( |
| mut self: Pin<&mut Self>, |
| cx: &mut task::Context<'_>, |
| bufs: &[io::IoSlice<'_>], |
| ) -> Poll<io::Result<usize>> { |
| Pin::new(&mut self.io).poll_write_vectored(cx, bufs) |
| } |
| |
| fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { |
| Pin::new(&mut self.io).poll_flush(cx) |
| } |
| |
| fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { |
| Pin::new(&mut self.io).poll_shutdown(cx) |
| } |
| |
| fn is_write_vectored(&self) -> bool { |
| self.io.is_write_vectored() |
| } |
| } |
| |
| impl fmt::Debug for Upgraded { |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| f.debug_struct("Upgraded").finish() |
| } |
| } |
| |
| // ===== impl OnUpgrade ===== |
| |
| impl OnUpgrade { |
| pub(super) fn none() -> Self { |
| OnUpgrade { rx: None } |
| } |
| |
| #[cfg(feature = "http1")] |
| pub(super) fn is_none(&self) -> bool { |
| self.rx.is_none() |
| } |
| } |
| |
| impl Future for OnUpgrade { |
| type Output = Result<Upgraded, crate::Error>; |
| |
| fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { |
| match self.rx { |
| Some(ref mut rx) => Pin::new(rx).poll(cx).map(|res| match res { |
| Ok(Ok(upgraded)) => Ok(upgraded), |
| Ok(Err(err)) => Err(err), |
| Err(_oneshot_canceled) => Err(crate::Error::new_canceled().with(UpgradeExpected)), |
| }), |
| None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())), |
| } |
| } |
| } |
| |
| impl fmt::Debug for OnUpgrade { |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| f.debug_struct("OnUpgrade").finish() |
| } |
| } |
| |
| // ===== impl Pending ===== |
| |
| #[cfg(any(feature = "http1", feature = "http2"))] |
| impl Pending { |
| pub(super) fn fulfill(self, upgraded: Upgraded) { |
| trace!("pending upgrade fulfill"); |
| let _ = self.tx.send(Ok(upgraded)); |
| } |
| |
| #[cfg(feature = "http1")] |
| /// Don't fulfill the pending Upgrade, but instead signal that |
| /// upgrades are handled manually. |
| pub(super) fn manual(self) { |
| trace!("pending upgrade handled manually"); |
| let _ = self.tx.send(Err(crate::Error::new_user_manual_upgrade())); |
| } |
| } |
| |
| // ===== impl UpgradeExpected ===== |
| |
| /// Error cause returned when an upgrade was expected but canceled |
| /// for whatever reason. |
| /// |
| /// This likely means the actual `Conn` future wasn't polled and upgraded. |
| #[derive(Debug)] |
| struct UpgradeExpected; |
| |
| impl fmt::Display for UpgradeExpected { |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| f.write_str("upgrade expected but not completed") |
| } |
| } |
| |
| impl StdError for UpgradeExpected {} |
| |
| // ===== impl Io ===== |
| |
| pub(super) trait Io: AsyncRead + AsyncWrite + Unpin + 'static { |
| fn __hyper_type_id(&self) -> TypeId { |
| TypeId::of::<Self>() |
| } |
| } |
| |
| impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for T {} |
| |
| impl dyn Io + Send { |
| fn __hyper_is<T: Io>(&self) -> bool { |
| let t = TypeId::of::<T>(); |
| self.__hyper_type_id() == t |
| } |
| |
| fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> { |
| if self.__hyper_is::<T>() { |
| // Taken from `std::error::Error::downcast()`. |
| unsafe { |
| let raw: *mut dyn Io = Box::into_raw(self); |
| Ok(Box::from_raw(raw as *mut T)) |
| } |
| } else { |
| Err(self) |
| } |
| } |
| } |
| |
| mod sealed { |
| use super::OnUpgrade; |
| |
| pub trait CanUpgrade { |
| fn on_upgrade(self) -> OnUpgrade; |
| } |
| |
| impl<B> CanUpgrade for http::Request<B> { |
| fn on_upgrade(mut self) -> OnUpgrade { |
| self.extensions_mut() |
| .remove::<OnUpgrade>() |
| .unwrap_or_else(OnUpgrade::none) |
| } |
| } |
| |
| impl<B> CanUpgrade for &'_ mut http::Request<B> { |
| fn on_upgrade(self) -> OnUpgrade { |
| self.extensions_mut() |
| .remove::<OnUpgrade>() |
| .unwrap_or_else(OnUpgrade::none) |
| } |
| } |
| |
| impl<B> CanUpgrade for http::Response<B> { |
| fn on_upgrade(mut self) -> OnUpgrade { |
| self.extensions_mut() |
| .remove::<OnUpgrade>() |
| .unwrap_or_else(OnUpgrade::none) |
| } |
| } |
| |
| impl<B> CanUpgrade for &'_ mut http::Response<B> { |
| fn on_upgrade(self) -> OnUpgrade { |
| self.extensions_mut() |
| .remove::<OnUpgrade>() |
| .unwrap_or_else(OnUpgrade::none) |
| } |
| } |
| } |
| |
| #[cfg(test)] |
| mod tests { |
| use super::*; |
| |
| #[test] |
| fn upgraded_downcast() { |
| let upgraded = Upgraded::new(Mock, Bytes::new()); |
| |
| let upgraded = upgraded.downcast::<std::io::Cursor<Vec<u8>>>().unwrap_err(); |
| |
| upgraded.downcast::<Mock>().unwrap(); |
| } |
| |
| // TODO: replace with tokio_test::io when it can test write_buf |
| struct Mock; |
| |
| impl AsyncRead for Mock { |
| fn poll_read( |
| self: Pin<&mut Self>, |
| _cx: &mut task::Context<'_>, |
| _buf: &mut ReadBuf<'_>, |
| ) -> Poll<io::Result<()>> { |
| unreachable!("Mock::poll_read") |
| } |
| } |
| |
| impl AsyncWrite for Mock { |
| fn poll_write( |
| self: Pin<&mut Self>, |
| _: &mut task::Context<'_>, |
| buf: &[u8], |
| ) -> Poll<io::Result<usize>> { |
| // panic!("poll_write shouldn't be called"); |
| Poll::Ready(Ok(buf.len())) |
| } |
| |
| fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { |
| unreachable!("Mock::poll_flush") |
| } |
| |
| fn poll_shutdown( |
| self: Pin<&mut Self>, |
| _cx: &mut task::Context<'_>, |
| ) -> Poll<io::Result<()>> { |
| unreachable!("Mock::poll_shutdown") |
| } |
| } |
| } |