| #![allow(dead_code)] |
| |
| use std::io; |
| use std::ops::{Deref, DerefMut}; |
| use std::sync::Arc; |
| |
| use rustls::internal::msgs::codec::Reader; |
| use rustls::internal::msgs::message::{Message, OpaqueMessage, PlainMessage}; |
| use rustls::server::{ |
| AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, UnparsedCertRevocationList, |
| }; |
| use rustls::Connection; |
| use rustls::Error; |
| use rustls::RootCertStore; |
| use rustls::{Certificate, PrivateKey}; |
| use rustls::{ClientConfig, ClientConnection}; |
| use rustls::{ConnectionCommon, ServerConfig, ServerConnection, SideData}; |
| |
/// Embeds test-CA files into the binary and generates a lookup function for them.
///
/// Each `(NAME, keytype, path);` entry becomes a `const NAME` holding the bytes of
/// `../../../test-ca/<keytype>/<path>`. The macro also emits `bytes_for(keytype, path)`,
/// which returns the matching constant and panics for a pair that was not listed.
macro_rules! embed_files {
    ($(($name:ident, $keytype:expr, $path:expr);)+) => {
        $(
            const $name: &[u8] =
                include_bytes!(concat!("../../../test-ca/", $keytype, "/", $path));
        )+

        pub fn bytes_for(keytype: &str, path: &str) -> &'static [u8] {
            // Entries are tested in declaration order; the first exact match wins.
            $(
                if keytype == $keytype && path == $path {
                    return $name;
                }
            )+
            panic!("unknown keytype {} with path {}", keytype, path)
        }
    };
}
| |
| embed_files! { |
| (ECDSA_CA_CERT, "ecdsa", "ca.cert"); |
| (ECDSA_CA_DER, "ecdsa", "ca.der"); |
| (ECDSA_CA_KEY, "ecdsa", "ca.key"); |
| (ECDSA_CLIENT_CERT, "ecdsa", "client.cert"); |
| (ECDSA_CLIENT_CHAIN, "ecdsa", "client.chain"); |
| (ECDSA_CLIENT_FULLCHAIN, "ecdsa", "client.fullchain"); |
| (ECDSA_CLIENT_KEY, "ecdsa", "client.key"); |
| (ECDSA_CLIENT_REQ, "ecdsa", "client.req"); |
| (ECDSA_CLIENT_CRL_PEM, "ecdsa", "client.revoked.crl.pem"); |
| (ECDSA_END_CERT, "ecdsa", "end.cert"); |
| (ECDSA_END_CHAIN, "ecdsa", "end.chain"); |
| (ECDSA_END_FULLCHAIN, "ecdsa", "end.fullchain"); |
| (ECDSA_END_KEY, "ecdsa", "end.key"); |
| (ECDSA_END_REQ, "ecdsa", "end.req"); |
| (ECDSA_INTER_CERT, "ecdsa", "inter.cert"); |
| (ECDSA_INTER_KEY, "ecdsa", "inter.key"); |
| (ECDSA_INTER_REQ, "ecdsa", "inter.req"); |
| (ECDSA_NISTP256_PEM, "ecdsa", "nistp256.pem"); |
| (ECDSA_NISTP384_PEM, "ecdsa", "nistp384.pem"); |
| |
| (EDDSA_CA_CERT, "eddsa", "ca.cert"); |
| (EDDSA_CA_DER, "eddsa", "ca.der"); |
| (EDDSA_CA_KEY, "eddsa", "ca.key"); |
| (EDDSA_CLIENT_CERT, "eddsa", "client.cert"); |
| (EDDSA_CLIENT_CHAIN, "eddsa", "client.chain"); |
| (EDDSA_CLIENT_FULLCHAIN, "eddsa", "client.fullchain"); |
| (EDDSA_CLIENT_KEY, "eddsa", "client.key"); |
| (EDDSA_CLIENT_REQ, "eddsa", "client.req"); |
| (EDDSA_CLIENT_CRL_PEM, "eddsa", "client.revoked.crl.pem"); |
| (EDDSA_END_CERT, "eddsa", "end.cert"); |
| (EDDSA_END_CHAIN, "eddsa", "end.chain"); |
| (EDDSA_END_FULLCHAIN, "eddsa", "end.fullchain"); |
| (EDDSA_END_KEY, "eddsa", "end.key"); |
| (EDDSA_END_REQ, "eddsa", "end.req"); |
| (EDDSA_INTER_CERT, "eddsa", "inter.cert"); |
| (EDDSA_INTER_KEY, "eddsa", "inter.key"); |
| (EDDSA_INTER_REQ, "eddsa", "inter.req"); |
| |
| (RSA_CA_CERT, "rsa", "ca.cert"); |
| (RSA_CA_DER, "rsa", "ca.der"); |
| (RSA_CA_KEY, "rsa", "ca.key"); |
| (RSA_CLIENT_CERT, "rsa", "client.cert"); |
| (RSA_CLIENT_CHAIN, "rsa", "client.chain"); |
| (RSA_CLIENT_FULLCHAIN, "rsa", "client.fullchain"); |
| (RSA_CLIENT_KEY, "rsa", "client.key"); |
| (RSA_CLIENT_REQ, "rsa", "client.req"); |
| (RSA_CLIENT_RSA, "rsa", "client.rsa"); |
| (RSA_CLIENT_CRL_PEM, "rsa", "client.revoked.crl.pem"); |
| (RSA_END_CERT, "rsa", "end.cert"); |
| (RSA_END_CHAIN, "rsa", "end.chain"); |
| (RSA_END_FULLCHAIN, "rsa", "end.fullchain"); |
| (RSA_END_KEY, "rsa", "end.key"); |
| (RSA_END_REQ, "rsa", "end.req"); |
| (RSA_END_RSA, "rsa", "end.rsa"); |
| (RSA_INTER_CERT, "rsa", "inter.cert"); |
| (RSA_INTER_KEY, "rsa", "inter.key"); |
| (RSA_INTER_REQ, "rsa", "inter.req"); |
| } |
| |
| pub fn transfer( |
| left: &mut (impl DerefMut + Deref<Target = ConnectionCommon<impl SideData>>), |
| right: &mut (impl DerefMut + Deref<Target = ConnectionCommon<impl SideData>>), |
| ) -> usize { |
| let mut buf = [0u8; 262144]; |
| let mut total = 0; |
| |
| while left.wants_write() { |
| let sz = { |
| let into_buf: &mut dyn io::Write = &mut &mut buf[..]; |
| left.write_tls(into_buf).unwrap() |
| }; |
| total += sz; |
| if sz == 0 { |
| return total; |
| } |
| |
| let mut offs = 0; |
| loop { |
| let from_buf: &mut dyn io::Read = &mut &buf[offs..sz]; |
| offs += right.read_tls(from_buf).unwrap(); |
| if sz == offs { |
| break; |
| } |
| } |
| } |
| |
| total |
| } |
| |
| pub fn transfer_eof(conn: &mut (impl DerefMut + Deref<Target = ConnectionCommon<impl SideData>>)) { |
| let empty_buf = [0u8; 0]; |
| let empty_cursor: &mut dyn io::Read = &mut &empty_buf[..]; |
| let sz = conn.read_tls(empty_cursor).unwrap(); |
| assert_eq!(sz, 0); |
| } |
| |
/// Outcome of a message filter: what should be forwarded to the peer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Altered {
    /// message has been edited in-place (or is unchanged)
    InPlace,
    /// send these raw bytes instead of the message.
    Raw(Vec<u8>),
}
| |
| pub fn transfer_altered<F>(left: &mut Connection, filter: F, right: &mut Connection) -> usize |
| where |
| F: Fn(&mut Message) -> Altered, |
| { |
| let mut buf = [0u8; 262144]; |
| let mut total = 0; |
| |
| while left.wants_write() { |
| let sz = { |
| let into_buf: &mut dyn io::Write = &mut &mut buf[..]; |
| left.write_tls(into_buf).unwrap() |
| }; |
| total += sz; |
| if sz == 0 { |
| return total; |
| } |
| |
| let mut reader = Reader::init(&buf[..sz]); |
| while reader.any_left() { |
| let message = OpaqueMessage::read(&mut reader).unwrap(); |
| let mut message = Message::try_from(message.into_plain_message()).unwrap(); |
| let message_enc = match filter(&mut message) { |
| Altered::InPlace => PlainMessage::from(message) |
| .into_unencrypted_opaque() |
| .encode(), |
| Altered::Raw(data) => data, |
| }; |
| |
| let message_enc_reader: &mut dyn io::Read = &mut &message_enc[..]; |
| let len = right |
| .read_tls(message_enc_reader) |
| .unwrap(); |
| assert_eq!(len, message_enc.len()); |
| } |
| } |
| |
| total |
| } |
| |
/// The key algorithm families for which test-CA material is embedded.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum KeyType {
    /// Material stored under the `rsa` directory.
    Rsa,
    /// Material stored under the `ecdsa` directory.
    Ecdsa,
    /// Material stored under the `eddsa` directory.
    Ed25519,
}
| |
| pub static ALL_KEY_TYPES: [KeyType; 3] = [KeyType::Rsa, KeyType::Ecdsa, KeyType::Ed25519]; |
| |
| impl KeyType { |
| fn bytes_for(&self, part: &str) -> &'static [u8] { |
| match self { |
| Self::Rsa => bytes_for("rsa", part), |
| Self::Ecdsa => bytes_for("ecdsa", part), |
| Self::Ed25519 => bytes_for("eddsa", part), |
| } |
| } |
| |
| pub fn get_chain(&self) -> Vec<Certificate> { |
| rustls_pemfile::certs(&mut io::BufReader::new(self.bytes_for("end.fullchain"))) |
| .unwrap() |
| .iter() |
| .map(|v| Certificate(v.clone())) |
| .collect() |
| } |
| |
| pub fn get_key(&self) -> PrivateKey { |
| PrivateKey( |
| rustls_pemfile::pkcs8_private_keys(&mut io::BufReader::new(self.bytes_for("end.key"))) |
| .unwrap()[0] |
| .clone(), |
| ) |
| } |
| |
| pub fn get_client_chain(&self) -> Vec<Certificate> { |
| rustls_pemfile::certs(&mut io::BufReader::new(self.bytes_for("client.fullchain"))) |
| .unwrap() |
| .iter() |
| .map(|v| Certificate(v.clone())) |
| .collect() |
| } |
| |
| pub fn client_crl(&self) -> UnparsedCertRevocationList { |
| UnparsedCertRevocationList( |
| rustls_pemfile::crls(&mut io::BufReader::new( |
| self.bytes_for("client.revoked.crl.pem"), |
| )) |
| .unwrap() |
| .into_iter() |
| .next() // We only expect one CRL. |
| .unwrap(), |
| ) |
| } |
| |
| fn get_client_key(&self) -> PrivateKey { |
| PrivateKey( |
| rustls_pemfile::pkcs8_private_keys(&mut io::BufReader::new( |
| self.bytes_for("client.key"), |
| )) |
| .unwrap()[0] |
| .clone(), |
| ) |
| } |
| } |
| |
| pub fn finish_server_config( |
| kt: KeyType, |
| conf: rustls::ConfigBuilder<ServerConfig, rustls::WantsVerifier>, |
| ) -> ServerConfig { |
| conf.with_no_client_auth() |
| .with_single_cert(kt.get_chain(), kt.get_key()) |
| .unwrap() |
| } |
| |
| pub fn make_server_config(kt: KeyType) -> ServerConfig { |
| finish_server_config(kt, ServerConfig::builder().with_safe_defaults()) |
| } |
| |
| pub fn make_server_config_with_versions( |
| kt: KeyType, |
| versions: &[&'static rustls::SupportedProtocolVersion], |
| ) -> ServerConfig { |
| finish_server_config( |
| kt, |
| ServerConfig::builder() |
| .with_safe_default_cipher_suites() |
| .with_safe_default_kx_groups() |
| .with_protocol_versions(versions) |
| .unwrap(), |
| ) |
| } |
| |
| pub fn make_server_config_with_kx_groups( |
| kt: KeyType, |
| kx_groups: &[&'static rustls::SupportedKxGroup], |
| ) -> ServerConfig { |
| finish_server_config( |
| kt, |
| ServerConfig::builder() |
| .with_safe_default_cipher_suites() |
| .with_kx_groups(kx_groups) |
| .with_safe_default_protocol_versions() |
| .unwrap(), |
| ) |
| } |
| |
| pub fn get_client_root_store(kt: KeyType) -> RootCertStore { |
| let mut roots = kt.get_chain(); |
| // drop server cert |
| roots.drain(0..1); |
| let mut client_auth_roots = RootCertStore::empty(); |
| for root in roots { |
| client_auth_roots.add(&root).unwrap(); |
| } |
| client_auth_roots |
| } |
| |
| pub fn make_server_config_with_mandatory_client_auth_crls( |
| kt: KeyType, |
| crls: Vec<UnparsedCertRevocationList>, |
| ) -> ServerConfig { |
| let client_auth_roots = get_client_root_store(kt); |
| |
| let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots) |
| .with_crls(crls) |
| .unwrap(); |
| |
| ServerConfig::builder() |
| .with_safe_defaults() |
| .with_client_cert_verifier(Arc::new(client_auth)) |
| .with_single_cert(kt.get_chain(), kt.get_key()) |
| .unwrap() |
| } |
| |
| pub fn make_server_config_with_mandatory_client_auth(kt: KeyType) -> ServerConfig { |
| make_server_config_with_mandatory_client_auth_crls(kt, Vec::new()) |
| } |
| |
| pub fn make_server_config_with_optional_client_auth( |
| kt: KeyType, |
| crls: Vec<UnparsedCertRevocationList>, |
| ) -> ServerConfig { |
| let client_auth_roots = get_client_root_store(kt); |
| |
| let client_auth = AllowAnyAnonymousOrAuthenticatedClient::new(client_auth_roots) |
| .with_crls(crls) |
| .unwrap(); |
| |
| ServerConfig::builder() |
| .with_safe_defaults() |
| .with_client_cert_verifier(Arc::new(client_auth)) |
| .with_single_cert(kt.get_chain(), kt.get_key()) |
| .unwrap() |
| } |
| |
| pub fn finish_client_config( |
| kt: KeyType, |
| config: rustls::ConfigBuilder<ClientConfig, rustls::WantsVerifier>, |
| ) -> ClientConfig { |
| let mut root_store = RootCertStore::empty(); |
| let mut rootbuf = io::BufReader::new(kt.bytes_for("ca.cert")); |
| root_store.add_parsable_certificates(&rustls_pemfile::certs(&mut rootbuf).unwrap()); |
| |
| config |
| .with_root_certificates(root_store) |
| .with_no_client_auth() |
| } |
| |
| pub fn finish_client_config_with_creds( |
| kt: KeyType, |
| config: rustls::ConfigBuilder<ClientConfig, rustls::WantsVerifier>, |
| ) -> ClientConfig { |
| let mut root_store = RootCertStore::empty(); |
| let mut rootbuf = io::BufReader::new(kt.bytes_for("ca.cert")); |
| root_store.add_parsable_certificates(&rustls_pemfile::certs(&mut rootbuf).unwrap()); |
| |
| config |
| .with_root_certificates(root_store) |
| .with_client_auth_cert(kt.get_client_chain(), kt.get_client_key()) |
| .unwrap() |
| } |
| |
| pub fn make_client_config(kt: KeyType) -> ClientConfig { |
| finish_client_config(kt, ClientConfig::builder().with_safe_defaults()) |
| } |
| |
| pub fn make_client_config_with_kx_groups( |
| kt: KeyType, |
| kx_groups: &[&'static rustls::SupportedKxGroup], |
| ) -> ClientConfig { |
| let builder = ClientConfig::builder() |
| .with_safe_default_cipher_suites() |
| .with_kx_groups(kx_groups) |
| .with_safe_default_protocol_versions() |
| .unwrap(); |
| finish_client_config(kt, builder) |
| } |
| |
| pub fn make_client_config_with_versions( |
| kt: KeyType, |
| versions: &[&'static rustls::SupportedProtocolVersion], |
| ) -> ClientConfig { |
| let builder = ClientConfig::builder() |
| .with_safe_default_cipher_suites() |
| .with_safe_default_kx_groups() |
| .with_protocol_versions(versions) |
| .unwrap(); |
| finish_client_config(kt, builder) |
| } |
| |
| pub fn make_client_config_with_auth(kt: KeyType) -> ClientConfig { |
| finish_client_config_with_creds(kt, ClientConfig::builder().with_safe_defaults()) |
| } |
| |
| pub fn make_client_config_with_versions_with_auth( |
| kt: KeyType, |
| versions: &[&'static rustls::SupportedProtocolVersion], |
| ) -> ClientConfig { |
| let builder = ClientConfig::builder() |
| .with_safe_default_cipher_suites() |
| .with_safe_default_kx_groups() |
| .with_protocol_versions(versions) |
| .unwrap(); |
| finish_client_config_with_creds(kt, builder) |
| } |
| |
| pub fn make_pair(kt: KeyType) -> (ClientConnection, ServerConnection) { |
| make_pair_for_configs(make_client_config(kt), make_server_config(kt)) |
| } |
| |
| pub fn make_pair_for_configs( |
| client_config: ClientConfig, |
| server_config: ServerConfig, |
| ) -> (ClientConnection, ServerConnection) { |
| make_pair_for_arc_configs(&Arc::new(client_config), &Arc::new(server_config)) |
| } |
| |
| pub fn make_pair_for_arc_configs( |
| client_config: &Arc<ClientConfig>, |
| server_config: &Arc<ServerConfig>, |
| ) -> (ClientConnection, ServerConnection) { |
| ( |
| ClientConnection::new(Arc::clone(client_config), dns_name("localhost")).unwrap(), |
| ServerConnection::new(Arc::clone(server_config)).unwrap(), |
| ) |
| } |
| |
| pub fn do_handshake( |
| client: &mut (impl DerefMut + Deref<Target = ConnectionCommon<impl SideData>>), |
| server: &mut (impl DerefMut + Deref<Target = ConnectionCommon<impl SideData>>), |
| ) -> (usize, usize) { |
| let (mut to_client, mut to_server) = (0, 0); |
| while server.is_handshaking() || client.is_handshaking() { |
| to_server += transfer(client, server); |
| server.process_new_packets().unwrap(); |
| to_client += transfer(server, client); |
| client.process_new_packets().unwrap(); |
| } |
| (to_server, to_client) |
| } |
| |
| #[derive(PartialEq, Debug)] |
| pub enum ErrorFromPeer { |
| Client(Error), |
| Server(Error), |
| } |
| |
| pub fn do_handshake_until_error( |
| client: &mut ClientConnection, |
| server: &mut ServerConnection, |
| ) -> Result<(), ErrorFromPeer> { |
| while server.is_handshaking() || client.is_handshaking() { |
| transfer(client, server); |
| server |
| .process_new_packets() |
| .map_err(ErrorFromPeer::Server)?; |
| transfer(server, client); |
| client |
| .process_new_packets() |
| .map_err(ErrorFromPeer::Client)?; |
| } |
| |
| Ok(()) |
| } |
| |
| pub fn do_handshake_until_both_error( |
| client: &mut ClientConnection, |
| server: &mut ServerConnection, |
| ) -> Result<(), Vec<ErrorFromPeer>> { |
| match do_handshake_until_error(client, server) { |
| Err(server_err @ ErrorFromPeer::Server(_)) => { |
| let mut errors = vec![server_err]; |
| transfer(server, client); |
| let client_err = client |
| .process_new_packets() |
| .map_err(ErrorFromPeer::Client) |
| .expect_err("client didn't produce error after server error"); |
| errors.push(client_err); |
| Err(errors) |
| } |
| |
| Err(client_err @ ErrorFromPeer::Client(_)) => { |
| let mut errors = vec![client_err]; |
| transfer(client, server); |
| let server_err = server |
| .process_new_packets() |
| .map_err(ErrorFromPeer::Server) |
| .expect_err("server didn't produce error after client error"); |
| errors.push(server_err); |
| Err(errors) |
| } |
| |
| Ok(()) => Ok(()), |
| } |
| } |
| |
| pub fn dns_name(name: &'static str) -> rustls::ServerName { |
| name.try_into().unwrap() |
| } |
| |
/// An `io::Read` implementation whose every read fails with a fixed error kind.
#[derive(Debug, Clone)]
pub struct FailsReads {
    /// The kind of error returned by each `read` call.
    errkind: io::ErrorKind,
}

impl FailsReads {
    /// Creates a reader that fails every read with `errkind`.
    pub fn new(errkind: io::ErrorKind) -> Self {
        Self { errkind }
    }
}

impl io::Read for FailsReads {
    /// Always returns an error of the configured kind; the buffer is left untouched.
    fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
        Err(self.errkind.into())
    }
}