| //! Async TLS streams backed by OpenSSL. |
| //! |
| //! This crate provides a wrapper around the [`openssl`] crate's [`SslStream`](ssl::SslStream) type |
//! that works with [`tokio`]'s [`AsyncRead`] and [`AsyncWrite`] traits rather than std's
| //! blocking [`Read`] and [`Write`] traits. |
| #![warn(missing_docs)] |
| |
| use openssl::error::ErrorStack; |
| use openssl::ssl::{self, ErrorCode, ShutdownResult, Ssl, SslRef}; |
| use std::fmt; |
| use std::future; |
| use std::io::{self, Read, Write}; |
| use std::pin::Pin; |
| use std::task::{Context, Poll}; |
| use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
| |
| #[cfg(test)] |
| mod test; |
| |
/// Adapter that exposes an async stream through std's blocking `Read`/`Write` traits so that
/// OpenSSL can drive it.
struct StreamWrapper<S> {
    /// The wrapped asynchronous stream.
    stream: S,
    /// Address of the task `Context` currently installed for an I/O call, or `0` when none is.
    context: usize,
}
| |
| impl<S> fmt::Debug for StreamWrapper<S> |
| where |
| S: fmt::Debug, |
| { |
| fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
| fmt::Debug::fmt(&self.stream, fmt) |
| } |
| } |
| |
| impl<S> StreamWrapper<S> { |
| /// # Safety |
| /// |
| /// Must be called with `context` set to a valid pointer to a live `Context` object, and the |
| /// wrapper must be pinned in memory. |
| unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) { |
| debug_assert_ne!(self.context, 0); |
| let stream = Pin::new_unchecked(&mut self.stream); |
| let context = &mut *(self.context as *mut _); |
| (stream, context) |
| } |
| } |
| |
| impl<S> Read for StreamWrapper<S> |
| where |
| S: AsyncRead, |
| { |
| fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { |
| let (stream, cx) = unsafe { self.parts() }; |
| let mut buf = ReadBuf::new(buf); |
| match stream.poll_read(cx, &mut buf)? { |
| Poll::Ready(()) => Ok(buf.filled().len()), |
| Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), |
| } |
| } |
| } |
| |
| impl<S> Write for StreamWrapper<S> |
| where |
| S: AsyncWrite, |
| { |
| fn write(&mut self, buf: &[u8]) -> io::Result<usize> { |
| let (stream, cx) = unsafe { self.parts() }; |
| match stream.poll_write(cx, buf) { |
| Poll::Ready(r) => r, |
| Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), |
| } |
| } |
| |
| fn flush(&mut self) -> io::Result<()> { |
| let (stream, cx) = unsafe { self.parts() }; |
| match stream.poll_flush(cx) { |
| Poll::Ready(r) => r, |
| Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), |
| } |
| } |
| } |
| |
/// Maps a blocking-style I/O result back into the async world: a `WouldBlock` error means the
/// operation must be retried later (`Poll::Pending`); anything else is ready.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    if let Err(err) = &r {
        if err.kind() == io::ErrorKind::WouldBlock {
            return Poll::Pending;
        }
    }
    Poll::Ready(r)
}
| |
| fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> { |
| match r { |
| Ok(v) => Poll::Ready(Ok(v)), |
| Err(e) => match e.code() { |
| ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending, |
| _ => Poll::Ready(Err(e)), |
| }, |
| } |
| } |
| |
/// An asynchronous version of [`openssl::ssl::SslStream`].
///
/// The underlying stream `S` is accessed through tokio's [`AsyncRead`] and [`AsyncWrite`]
/// traits; an internal wrapper bridges OpenSSL's blocking-style `Read`/`Write` calls to them.
#[derive(Debug)]
pub struct SslStream<S>(ssl::SslStream<StreamWrapper<S>>);
| |
| impl<S> SslStream<S> |
| where |
| S: AsyncRead + AsyncWrite, |
| { |
| /// Like [`SslStream::new`](ssl::SslStream::new). |
| pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> { |
| ssl::SslStream::new(ssl, StreamWrapper { stream, context: 0 }).map(SslStream) |
| } |
| |
| /// Like [`SslStream::connect`](ssl::SslStream::connect). |
| pub fn poll_connect( |
| self: Pin<&mut Self>, |
| cx: &mut Context<'_>, |
| ) -> Poll<Result<(), ssl::Error>> { |
| self.with_context(cx, |s| cvt_ossl(s.connect())) |
| } |
| |
| /// A convenience method wrapping [`poll_connect`](Self::poll_connect). |
| pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> { |
| future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await |
| } |
| |
| /// Like [`SslStream::accept`](ssl::SslStream::accept). |
| pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> { |
| self.with_context(cx, |s| cvt_ossl(s.accept())) |
| } |
| |
| /// A convenience method wrapping [`poll_accept`](Self::poll_accept). |
| pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> { |
| future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await |
| } |
| |
| /// Like [`SslStream::do_handshake`](ssl::SslStream::do_handshake). |
| pub fn poll_do_handshake( |
| self: Pin<&mut Self>, |
| cx: &mut Context<'_>, |
| ) -> Poll<Result<(), ssl::Error>> { |
| self.with_context(cx, |s| cvt_ossl(s.do_handshake())) |
| } |
| |
| /// A convenience method wrapping [`poll_do_handshake`](Self::poll_do_handshake). |
| pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> { |
| future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await |
| } |
| |
| /// Like [`SslStream::ssl_peek`](ssl::SslStream::ssl_peek). |
| pub fn poll_peek( |
| self: Pin<&mut Self>, |
| cx: &mut Context<'_>, |
| buf: &mut [u8], |
| ) -> Poll<Result<usize, ssl::Error>> { |
| self.with_context(cx, |s| cvt_ossl(s.ssl_peek(buf))) |
| } |
| |
| /// A convenience method wrapping [`poll_peek`](Self::poll_peek). |
| pub async fn peek(mut self: Pin<&mut Self>, buf: &mut [u8]) -> Result<usize, ssl::Error> { |
| future::poll_fn(|cx| self.as_mut().poll_peek(cx, buf)).await |
| } |
| |
| /// Like [`SslStream::read_early_data`](ssl::SslStream::read_early_data). |
| #[cfg(ossl111)] |
| pub fn poll_read_early_data( |
| self: Pin<&mut Self>, |
| cx: &mut Context<'_>, |
| buf: &mut [u8], |
| ) -> Poll<Result<usize, ssl::Error>> { |
| self.with_context(cx, |s| cvt_ossl(s.read_early_data(buf))) |
| } |
| |
| /// A convenience method wrapping [`poll_read_early_data`](Self::poll_read_early_data). |
| #[cfg(ossl111)] |
| pub async fn read_early_data( |
| mut self: Pin<&mut Self>, |
| buf: &mut [u8], |
| ) -> Result<usize, ssl::Error> { |
| future::poll_fn(|cx| self.as_mut().poll_read_early_data(cx, buf)).await |
| } |
| |
| /// Like [`SslStream::write_early_data`](ssl::SslStream::write_early_data). |
| #[cfg(ossl111)] |
| pub fn poll_write_early_data( |
| self: Pin<&mut Self>, |
| cx: &mut Context<'_>, |
| buf: &[u8], |
| ) -> Poll<Result<usize, ssl::Error>> { |
| self.with_context(cx, |s| cvt_ossl(s.write_early_data(buf))) |
| } |
| |
| /// A convenience method wrapping [`poll_write_early_data`](Self::poll_write_early_data). |
| #[cfg(ossl111)] |
| pub async fn write_early_data( |
| mut self: Pin<&mut Self>, |
| buf: &[u8], |
| ) -> Result<usize, ssl::Error> { |
| future::poll_fn(|cx| self.as_mut().poll_write_early_data(cx, buf)).await |
| } |
| } |
| |
| impl<S> SslStream<S> { |
| /// Returns a shared reference to the `Ssl` object associated with this stream. |
| pub fn ssl(&self) -> &SslRef { |
| self.0.ssl() |
| } |
| |
| /// Returns a shared reference to the underlying stream. |
| pub fn get_ref(&self) -> &S { |
| &self.0.get_ref().stream |
| } |
| |
| /// Returns a mutable reference to the underlying stream. |
| pub fn get_mut(&mut self) -> &mut S { |
| &mut self.0.get_mut().stream |
| } |
| |
| /// Returns a pinned mutable reference to the underlying stream. |
| pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> { |
| unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0.get_mut().stream) } |
| } |
| |
| fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R |
| where |
| F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R, |
| { |
| let this = unsafe { self.get_unchecked_mut() }; |
| this.0.get_mut().context = ctx as *mut _ as usize; |
| let r = f(&mut this.0); |
| this.0.get_mut().context = 0; |
| r |
| } |
| } |
| |
| impl<S> AsyncRead for SslStream<S> |
| where |
| S: AsyncRead + AsyncWrite, |
| { |
| fn poll_read( |
| self: Pin<&mut Self>, |
| ctx: &mut Context<'_>, |
| buf: &mut ReadBuf<'_>, |
| ) -> Poll<io::Result<()>> { |
| self.with_context(ctx, |s| { |
| // SAFETY: read_uninit does not de-initialize the buffer. |
| match cvt(s.read_uninit(unsafe { buf.unfilled_mut() }))? { |
| Poll::Ready(nread) => { |
| // SAFETY: read_uninit guarantees that nread bytes have been initialized. |
| unsafe { buf.assume_init(nread) }; |
| buf.advance(nread); |
| Poll::Ready(Ok(())) |
| } |
| Poll::Pending => Poll::Pending, |
| } |
| }) |
| } |
| } |
| |
| impl<S> AsyncWrite for SslStream<S> |
| where |
| S: AsyncRead + AsyncWrite, |
| { |
| fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> { |
| self.with_context(ctx, |s| cvt(s.write(buf))) |
| } |
| |
| fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> { |
| self.with_context(ctx, |s| cvt(s.flush())) |
| } |
| |
| fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> { |
| match self.as_mut().with_context(ctx, |s| s.shutdown()) { |
| Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {} |
| Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {} |
| Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => { |
| return Poll::Pending; |
| } |
| Err(e) => { |
| return Poll::Ready(Err(e |
| .into_io_error() |
| .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)))); |
| } |
| } |
| |
| self.get_pin_mut().poll_shutdown(ctx) |
| } |
| } |