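//! Connection helpers: wrap a stream in TLS (via `native-tls` or `rustls`) when the request URI requires it, then perform the WebSocket client handshake.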
| use std::io::{Read, Write}; |
| |
| use crate::{ |
| client::{client_with_config, uri_mode, IntoClientRequest}, |
| error::UrlError, |
| handshake::client::Response, |
| protocol::WebSocketConfig, |
| stream::MaybeTlsStream, |
| ClientHandshake, Error, HandshakeError, Result, WebSocket, |
| }; |
| |
/// A connector used when establishing connections. It controls whether
/// `native-tls` or `rustls` creates the TLS connection, or whether TLS is
/// disabled entirely via the `Plain` variant.
///
/// Only the variants whose backend feature is enabled at compile time exist.
#[non_exhaustive]
// NOTE(review): `Debug` is not implemented, presumably because the wrapped
// connector types do not provide a `Debug` impl on every supported backend
// version. Confirm before deriving it.
#[allow(missing_debug_implementations)]
pub enum Connector {
    /// Plain (non-TLS) connector. A URI that requires TLS is rejected when
    /// this connector is used.
    Plain,
    /// `native-tls` TLS connector.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls_crate::TlsConnector),
    /// `rustls` TLS connector, given as a shared client configuration.
    #[cfg(feature = "__rustls-tls")]
    Rustls(std::sync::Arc<rustls::ClientConfig>),
}
| |
| mod encryption { |
| #[cfg(feature = "native-tls")] |
| pub mod native_tls { |
| use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector}; |
| |
| use std::io::{Read, Write}; |
| |
| use crate::{ |
| error::TlsError, |
| stream::{MaybeTlsStream, Mode}, |
| Error, Result, |
| }; |
| |
| pub fn wrap_stream<S>( |
| socket: S, |
| domain: &str, |
| mode: Mode, |
| tls_connector: Option<TlsConnector>, |
| ) -> Result<MaybeTlsStream<S>> |
| where |
| S: Read + Write, |
| { |
| match mode { |
| Mode::Plain => Ok(MaybeTlsStream::Plain(socket)), |
| Mode::Tls => { |
| let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok); |
| let connector = try_connector.map_err(TlsError::Native)?; |
| let connected = connector.connect(domain, socket); |
| match connected { |
| Err(e) => match e { |
| TlsHandshakeError::Failure(f) => Err(Error::Tls(f.into())), |
| TlsHandshakeError::WouldBlock(_) => { |
| panic!("Bug: TLS handshake not blocked") |
| } |
| }, |
| Ok(s) => Ok(MaybeTlsStream::NativeTls(s)), |
| } |
| } |
| } |
| } |
| } |
| |
| #[cfg(feature = "__rustls-tls")] |
| pub mod rustls { |
| use rustls::{ClientConfig, ClientConnection, RootCertStore, ServerName, StreamOwned}; |
| |
| use std::{ |
| convert::TryFrom, |
| io::{Read, Write}, |
| sync::Arc, |
| }; |
| |
| use crate::{ |
| error::TlsError, |
| stream::{MaybeTlsStream, Mode}, |
| Result, |
| }; |
| |
| pub fn wrap_stream<S>( |
| socket: S, |
| domain: &str, |
| mode: Mode, |
| tls_connector: Option<Arc<ClientConfig>>, |
| ) -> Result<MaybeTlsStream<S>> |
| where |
| S: Read + Write, |
| { |
| match mode { |
| Mode::Plain => Ok(MaybeTlsStream::Plain(socket)), |
| Mode::Tls => { |
| let config = match tls_connector { |
| Some(config) => config, |
| None => { |
| #[allow(unused_mut)] |
| let mut root_store = RootCertStore::empty(); |
| |
| #[cfg(feature = "rustls-tls-native-roots")] |
| { |
| for cert in rustls_native_certs::load_native_certs()? { |
| root_store |
| .add(&rustls::Certificate(cert.0)) |
| .map_err(TlsError::Rustls)?; |
| } |
| } |
| #[cfg(feature = "rustls-tls-webpki-roots")] |
| { |
| root_store.add_server_trust_anchors( |
| webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { |
| rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( |
| ta.subject, |
| ta.spki, |
| ta.name_constraints, |
| ) |
| }) |
| ); |
| } |
| |
| Arc::new( |
| ClientConfig::builder() |
| .with_safe_defaults() |
| .with_root_certificates(root_store) |
| .with_no_client_auth(), |
| ) |
| } |
| }; |
| let domain = |
| ServerName::try_from(domain).map_err(|_| TlsError::InvalidDnsName)?; |
| let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?; |
| let stream = StreamOwned::new(client, socket); |
| |
| Ok(MaybeTlsStream::Rustls(stream)) |
| } |
| } |
| } |
| } |
| |
| pub mod plain { |
| use std::io::{Read, Write}; |
| |
| use crate::{ |
| error::UrlError, |
| stream::{MaybeTlsStream, Mode}, |
| Error, Result, |
| }; |
| |
| pub fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>> |
| where |
| S: Read + Write, |
| { |
| match mode { |
| Mode::Plain => Ok(MaybeTlsStream::Plain(socket)), |
| Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)), |
| } |
| } |
| } |
| } |
| |
| type TlsHandshakeError<S> = HandshakeError<ClientHandshake<MaybeTlsStream<S>>>; |
| |
| /// Creates a WebSocket handshake from a request and a stream, |
| /// upgrading the stream to TLS if required. |
| pub fn client_tls<R, S>( |
| request: R, |
| stream: S, |
| ) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>> |
| where |
| R: IntoClientRequest, |
| S: Read + Write, |
| { |
| client_tls_with_config(request, stream, None, None) |
| } |
| |
| /// The same as [`client_tls()`] but one can specify a websocket configuration, |
| /// and an optional connector. If no connector is specified, a default one will |
| /// be created. |
| /// |
| /// Please refer to [`client_tls()`] for more details. |
| pub fn client_tls_with_config<R, S>( |
| request: R, |
| stream: S, |
| config: Option<WebSocketConfig>, |
| connector: Option<Connector>, |
| ) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>> |
| where |
| R: IntoClientRequest, |
| S: Read + Write, |
| { |
| let request = request.into_client_request()?; |
| |
| #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))] |
| let domain = match request.uri().host() { |
| Some(d) => Ok(d.to_string()), |
| None => Err(Error::Url(UrlError::NoHostName)), |
| }?; |
| |
| let mode = uri_mode(request.uri())?; |
| |
| let stream = match connector { |
| Some(conn) => match conn { |
| #[cfg(feature = "native-tls")] |
| Connector::NativeTls(conn) => { |
| self::encryption::native_tls::wrap_stream(stream, &domain, mode, Some(conn)) |
| } |
| #[cfg(feature = "__rustls-tls")] |
| Connector::Rustls(conn) => { |
| self::encryption::rustls::wrap_stream(stream, &domain, mode, Some(conn)) |
| } |
| Connector::Plain => self::encryption::plain::wrap_stream(stream, mode), |
| }, |
| None => { |
| #[cfg(feature = "native-tls")] |
| { |
| self::encryption::native_tls::wrap_stream(stream, &domain, mode, None) |
| } |
| #[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))] |
| { |
| self::encryption::rustls::wrap_stream(stream, &domain, mode, None) |
| } |
| #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))] |
| { |
| self::encryption::plain::wrap_stream(stream, mode) |
| } |
| } |
| }?; |
| |
| client_with_config(request, stream, config) |
| } |