| #![doc(html_root_url = "https://docs.rs/tokio-native-tls/0.3.0")] |
| #![warn( |
| missing_debug_implementations, |
| missing_docs, |
| rust_2018_idioms, |
| unreachable_pub |
| )] |
| #![deny(rustdoc::broken_intra_doc_links)] |
| #![doc(test( |
| no_crate_inject, |
| attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) |
| ))] |
| |
| //! Async TLS streams |
| //! |
| //! This library is an implementation of TLS streams using the most appropriate |
| //! system library by default for negotiating the connection. That is, on |
//! Windows this library uses SChannel, on macOS it uses Secure Transport, and on
//! other platforms it uses OpenSSL.
| //! |
//! Each TLS stream implements Tokio's `AsyncRead` and `AsyncWrite` traits so it
//! can interoperate with the rest of the Tokio I/O ecosystem. Client connections
//! initiated from this crate verify hostnames automatically and by default.
| //! |
| //! This crate primarily exports this ability through two newtypes, |
| //! `TlsConnector` and `TlsAcceptor`. These newtypes augment the |
| //! functionality provided by the `native-tls` crate, on which this crate is |
| //! built. Configuration of TLS parameters is still primarily done through the |
| //! `native-tls` crate. |
| |
| use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
| |
| use crate::native_tls::{Error, HandshakeError, MidHandshakeTlsStream}; |
| use std::fmt; |
| use std::future::Future; |
| use std::io::{self, Read, Write}; |
| use std::marker::Unpin; |
| #[cfg(unix)] |
| use std::os::unix::io::{AsRawFd, RawFd}; |
| #[cfg(windows)] |
| use std::os::windows::io::{AsRawSocket, RawSocket}; |
| use std::pin::Pin; |
| use std::ptr::null_mut; |
| use std::task::{Context, Poll}; |
| |
/// An intermediate wrapper for the inner stream `S`.
///
/// It adapts an async stream to the blocking `std::io::Read`/`Write` traits
/// that `native-tls` drives. `context` holds a type-erased pointer to the task
/// `Context` of the poll currently in progress; it is set before a poll and
/// reset to null afterwards.
#[derive(Debug)]
pub struct AllowStd<S> {
    inner: S,
    context: *mut (),
}

impl<S> AllowStd<S> {
    /// Borrows the wrapped stream.
    pub fn get_ref(&self) -> &S {
        let Self { inner, .. } = self;
        inner
    }

    /// Mutably borrows the wrapped stream.
    pub fn get_mut(&mut self) -> &mut S {
        let Self { inner, .. } = self;
        inner
    }
}
| |
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
///
/// A `TlsStream<S>` represents a handshake that has been completed successfully
/// and both the server and the client are ready for receiving and sending
/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written
/// to a `TlsStream` are encrypted when passing through to `S`.
///
/// When `S: AsyncRead + AsyncWrite + Unpin`, this type implements Tokio's
/// `AsyncRead` and `AsyncWrite`. The wrapped `native_tls::TlsStream` is
/// reachable through `get_ref` / `get_mut`.
#[derive(Debug)]
pub struct TlsStream<S>(native_tls::TlsStream<AllowStd<S>>);
| |
/// A wrapper around a `native_tls::TlsConnector`, providing an async `connect`
/// method.
///
/// Construct one from a configured `native_tls::TlsConnector` via `From`/`Into`.
#[derive(Clone)]
pub struct TlsConnector(native_tls::TlsConnector);

/// A wrapper around a `native_tls::TlsAcceptor`, providing an async `accept`
/// method.
///
/// Construct one from a configured `native_tls::TlsAcceptor` via `From`/`Into`.
#[derive(Clone)]
pub struct TlsAcceptor(native_tls::TlsAcceptor);
| |
/// Future that resumes a handshake which previously reported `WouldBlock`.
/// The `Option` is taken on every poll and stored back while the handshake
/// is still pending; it is `None` once the future has completed.
struct MidHandshake<S>(Option<MidHandshakeTlsStream<AllowStd<S>>>);

/// Outcome of the first handshake attempt.
enum StartedHandshake<S> {
    /// The handshake finished on the first attempt.
    Done(TlsStream<S>),
    /// The handshake would block; it must be resumed via `MidHandshake`.
    Mid(MidHandshakeTlsStream<AllowStd<S>>),
}

/// Future that makes the first handshake attempt. It always resolves on its
/// first poll; the `Option` is `None` afterwards.
struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);
/// Inputs consumed by the first poll of `StartedHandshakeFuture`.
struct StartedHandshakeFutureInner<F, S> {
    /// Closure that starts the `native-tls` handshake (`connect` or `accept`).
    f: F,
    /// Raw stream, wrapped in `AllowStd` before being handed to `f`.
    stream: S,
}
| |
| struct Guard<'a, S>(&'a mut TlsStream<S>) |
| where |
| AllowStd<S>: Read + Write; |
| |
| impl<S> Drop for Guard<'_, S> |
| where |
| AllowStd<S>: Read + Write, |
| { |
| fn drop(&mut self) { |
| (self.0).0.get_mut().context = null_mut(); |
| } |
| } |
| |
| // *mut () context is neither Send nor Sync |
| unsafe impl<S: Send> Send for AllowStd<S> {} |
| unsafe impl<S: Sync> Sync for AllowStd<S> {} |
| |
| impl<S> AllowStd<S> |
| where |
| S: Unpin, |
| { |
| fn with_context<F, R>(&mut self, f: F) -> io::Result<R> |
| where |
| F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<io::Result<R>>, |
| { |
| unsafe { |
| assert!(!self.context.is_null()); |
| let waker = &mut *(self.context as *mut _); |
| match f(waker, Pin::new(&mut self.inner)) { |
| Poll::Ready(r) => r, |
| Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), |
| } |
| } |
| } |
| } |
| |
| impl<S> Read for AllowStd<S> |
| where |
| S: AsyncRead + Unpin, |
| { |
| fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { |
| let mut buf = ReadBuf::new(buf); |
| self.with_context(|ctx, stream| stream.poll_read(ctx, &mut buf))?; |
| Ok(buf.filled().len()) |
| } |
| } |
| |
| impl<S> Write for AllowStd<S> |
| where |
| S: AsyncWrite + Unpin, |
| { |
| fn write(&mut self, buf: &[u8]) -> io::Result<usize> { |
| self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) |
| } |
| |
| fn flush(&mut self) -> io::Result<()> { |
| self.with_context(|ctx, stream| stream.poll_flush(ctx)) |
| } |
| } |
| |
| impl<S> TlsStream<S> { |
| fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> Poll<io::Result<R>> |
| where |
| F: FnOnce(&mut native_tls::TlsStream<AllowStd<S>>) -> io::Result<R>, |
| AllowStd<S>: Read + Write, |
| { |
| self.0.get_mut().context = ctx as *mut _ as *mut (); |
| let g = Guard(self); |
| match f(&mut (g.0).0) { |
| Ok(v) => Poll::Ready(Ok(v)), |
| Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, |
| Err(e) => Poll::Ready(Err(e)), |
| } |
| } |
| |
| /// Returns a shared reference to the inner stream. |
| pub fn get_ref(&self) -> &native_tls::TlsStream<AllowStd<S>> { |
| &self.0 |
| } |
| |
| /// Returns a mutable reference to the inner stream. |
| pub fn get_mut(&mut self) -> &mut native_tls::TlsStream<AllowStd<S>> { |
| &mut self.0 |
| } |
| } |
| |
| impl<S> AsyncRead for TlsStream<S> |
| where |
| S: AsyncRead + AsyncWrite + Unpin, |
| { |
| fn poll_read( |
| mut self: Pin<&mut Self>, |
| ctx: &mut Context<'_>, |
| buf: &mut ReadBuf<'_>, |
| ) -> Poll<io::Result<()>> { |
| self.with_context(ctx, |s| { |
| let n = s.read(buf.initialize_unfilled())?; |
| buf.advance(n); |
| Ok(()) |
| }) |
| } |
| } |
| |
| impl<S> AsyncWrite for TlsStream<S> |
| where |
| S: AsyncRead + AsyncWrite + Unpin, |
| { |
| fn poll_write( |
| mut self: Pin<&mut Self>, |
| ctx: &mut Context<'_>, |
| buf: &[u8], |
| ) -> Poll<io::Result<usize>> { |
| self.with_context(ctx, |s| s.write(buf)) |
| } |
| |
| fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| self.with_context(ctx, |s| s.flush()) |
| } |
| |
| fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| self.with_context(ctx, |s| s.shutdown()) |
| } |
| } |
| |
| #[cfg(unix)] |
| impl<S> AsRawFd for TlsStream<S> |
| where |
| S: AsRawFd, |
| { |
| fn as_raw_fd(&self) -> RawFd { |
| self.get_ref().get_ref().get_ref().as_raw_fd() |
| } |
| } |
| |
| #[cfg(windows)] |
| impl<S> AsRawSocket for TlsStream<S> |
| where |
| S: AsRawSocket, |
| { |
| fn as_raw_socket(&self) -> RawSocket { |
| self.get_ref().get_ref().get_ref().as_raw_socket() |
| } |
| } |
| |
| async fn handshake<F, S>(f: F, stream: S) -> Result<TlsStream<S>, Error> |
| where |
| F: FnOnce( |
| AllowStd<S>, |
| ) -> Result<native_tls::TlsStream<AllowStd<S>>, HandshakeError<AllowStd<S>>> |
| + Unpin, |
| S: AsyncRead + AsyncWrite + Unpin, |
| { |
| let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream })); |
| |
| match start.await { |
| Err(e) => Err(e), |
| Ok(StartedHandshake::Done(s)) => Ok(s), |
| Ok(StartedHandshake::Mid(s)) => MidHandshake(Some(s)).await, |
| } |
| } |
| |
| impl<F, S> Future for StartedHandshakeFuture<F, S> |
| where |
| F: FnOnce( |
| AllowStd<S>, |
| ) -> Result<native_tls::TlsStream<AllowStd<S>>, HandshakeError<AllowStd<S>>> |
| + Unpin, |
| S: Unpin, |
| AllowStd<S>: Read + Write, |
| { |
| type Output = Result<StartedHandshake<S>, Error>; |
| |
| fn poll( |
| mut self: Pin<&mut Self>, |
| ctx: &mut Context<'_>, |
| ) -> Poll<Result<StartedHandshake<S>, Error>> { |
| let inner = self.0.take().expect("future polled after completion"); |
| let stream = AllowStd { |
| inner: inner.stream, |
| context: ctx as *mut _ as *mut (), |
| }; |
| |
| match (inner.f)(stream) { |
| Ok(mut s) => { |
| s.get_mut().context = null_mut(); |
| Poll::Ready(Ok(StartedHandshake::Done(TlsStream(s)))) |
| } |
| Err(HandshakeError::WouldBlock(mut s)) => { |
| s.get_mut().context = null_mut(); |
| Poll::Ready(Ok(StartedHandshake::Mid(s))) |
| } |
| Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)), |
| } |
| } |
| } |
| |
| impl TlsConnector { |
| /// Connects the provided stream with this connector, assuming the provided |
| /// domain. |
| /// |
| /// This function will internally call `TlsConnector::connect` to connect |
| /// the stream and returns a future representing the resolution of the |
| /// connection operation. The returned future will resolve to either |
| /// `TlsStream<S>` or `Error` depending if it's successful or not. |
| /// |
| /// This is typically used for clients who have already established, for |
| /// example, a TCP connection to a remote server. That stream is then |
| /// provided here to perform the client half of a connection to a |
| /// TLS-powered server. |
| pub async fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, Error> |
| where |
| S: AsyncRead + AsyncWrite + Unpin, |
| { |
| handshake(move |s| self.0.connect(domain, s), stream).await |
| } |
| } |
| |
| impl fmt::Debug for TlsConnector { |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| f.debug_struct("TlsConnector").finish() |
| } |
| } |
| |
| impl From<native_tls::TlsConnector> for TlsConnector { |
| fn from(inner: native_tls::TlsConnector) -> TlsConnector { |
| TlsConnector(inner) |
| } |
| } |
| |
| impl TlsAcceptor { |
| /// Accepts a new client connection with the provided stream. |
| /// |
| /// This function will internally call `TlsAcceptor::accept` to connect |
| /// the stream and returns a future representing the resolution of the |
| /// connection operation. The returned future will resolve to either |
| /// `TlsStream<S>` or `Error` depending if it's successful or not. |
| /// |
| /// This is typically used after a new socket has been accepted from a |
| /// `TcpListener`. That socket is then passed to this function to perform |
| /// the server half of accepting a client connection. |
| pub async fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, Error> |
| where |
| S: AsyncRead + AsyncWrite + Unpin, |
| { |
| handshake(move |s| self.0.accept(s), stream).await |
| } |
| } |
| |
| impl fmt::Debug for TlsAcceptor { |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| f.debug_struct("TlsAcceptor").finish() |
| } |
| } |
| |
| impl From<native_tls::TlsAcceptor> for TlsAcceptor { |
| fn from(inner: native_tls::TlsAcceptor) -> TlsAcceptor { |
| TlsAcceptor(inner) |
| } |
| } |
| |
| impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshake<S> { |
| type Output = Result<TlsStream<S>, Error>; |
| |
| fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| let mut_self = self.get_mut(); |
| let mut s = mut_self.0.take().expect("future polled after completion"); |
| |
| s.get_mut().context = cx as *mut _ as *mut (); |
| match s.handshake() { |
| Ok(mut s) => { |
| s.get_mut().context = null_mut(); |
| Poll::Ready(Ok(TlsStream(s))) |
| } |
| Err(HandshakeError::WouldBlock(mut s)) => { |
| s.get_mut().context = null_mut(); |
| mut_self.0 = Some(s); |
| Poll::Pending |
| } |
| Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)), |
| } |
| } |
| } |
| |
/// Re-export of the `native-tls` crate that this library is built on, so its
/// types (connectors, acceptors, `Error`, …) are available from this crate.
pub mod native_tls {
    pub use native_tls::*;
}