| // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| // Copyright by contributors to this project. |
| // SPDX-License-Identifier: (Apache-2.0 OR MIT) |
| |
| use alloc::vec; |
| use alloc::vec::Vec; |
| use mls_rs_codec::{MlsDecode, MlsEncode}; |
| use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider, key_package::KeyPackageData}; |
| |
| use crate::client::MlsError; |
| use crate::{ |
| crypto::{HpkeSecretKey, SignatureSecretKey}, |
| group::framing::MlsMessagePayload, |
| identity::SigningIdentity, |
| protocol_version::ProtocolVersion, |
| signer::Signable, |
| tree_kem::{ |
| leaf_node::{ConfigProperties, LeafNode}, |
| Capabilities, Lifetime, |
| }, |
| CipherSuiteProvider, ExtensionList, MlsMessage, |
| }; |
| |
| use super::{KeyPackage, KeyPackageRef}; |
| |
| #[derive(Clone, Debug)] |
| pub struct KeyPackageGenerator<'a, IP, CP> |
| where |
| IP: IdentityProvider, |
| CP: CipherSuiteProvider, |
| { |
| pub protocol_version: ProtocolVersion, |
| pub cipher_suite_provider: &'a CP, |
| pub signing_identity: &'a SigningIdentity, |
| pub signing_key: &'a SignatureSecretKey, |
| pub identity_provider: &'a IP, |
| } |
| |
| #[derive(Clone, Debug)] |
| pub struct KeyPackageGeneration { |
| pub(crate) reference: KeyPackageRef, |
| pub(crate) key_package: KeyPackage, |
| pub(crate) init_secret_key: HpkeSecretKey, |
| pub(crate) leaf_node_secret_key: HpkeSecretKey, |
| } |
| |
| impl KeyPackageGeneration { |
| pub fn to_storage(&self) -> Result<(Vec<u8>, KeyPackageData), MlsError> { |
| let id = self.reference.to_vec(); |
| |
| let data = KeyPackageData::new( |
| self.key_package.mls_encode_to_vec()?, |
| self.init_secret_key.clone(), |
| self.leaf_node_secret_key.clone(), |
| self.key_package.expiration()?, |
| ); |
| |
| Ok((id, data)) |
| } |
| |
| pub fn from_storage(id: Vec<u8>, data: KeyPackageData) -> Result<Self, MlsError> { |
| Ok(KeyPackageGeneration { |
| reference: KeyPackageRef::from(id), |
| key_package: KeyPackage::mls_decode(&mut &*data.key_package_bytes)?, |
| init_secret_key: data.init_key, |
| leaf_node_secret_key: data.leaf_node_key, |
| }) |
| } |
| |
| pub fn key_package_message(&self) -> MlsMessage { |
| MlsMessage::new( |
| self.key_package.version, |
| MlsMessagePayload::KeyPackage(self.key_package.clone()), |
| ) |
| } |
| } |
| |
| impl<'a, IP, CP> KeyPackageGenerator<'a, IP, CP> |
| where |
| IP: IdentityProvider, |
| CP: CipherSuiteProvider, |
| { |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub(super) async fn sign(&self, package: &mut KeyPackage) -> Result<(), MlsError> { |
| package |
| .sign(self.cipher_suite_provider, self.signing_key, &()) |
| .await |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn generate( |
| &self, |
| lifetime: Lifetime, |
| capabilities: Capabilities, |
| key_package_extensions: ExtensionList, |
| leaf_node_extensions: ExtensionList, |
| ) -> Result<KeyPackageGeneration, MlsError> { |
| let (init_secret_key, public_init) = self |
| .cipher_suite_provider |
| .kem_generate() |
| .await |
| .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?; |
| |
| let properties = ConfigProperties { |
| capabilities, |
| extensions: leaf_node_extensions, |
| }; |
| |
| let (leaf_node, leaf_node_secret) = LeafNode::generate( |
| self.cipher_suite_provider, |
| properties, |
| self.signing_identity.clone(), |
| self.signing_key, |
| lifetime, |
| ) |
| .await?; |
| |
| let mut package = KeyPackage { |
| version: self.protocol_version, |
| cipher_suite: self.cipher_suite_provider.cipher_suite(), |
| hpke_init_key: public_init, |
| leaf_node, |
| extensions: key_package_extensions, |
| signature: vec![], |
| }; |
| |
| package.grease(self.cipher_suite_provider)?; |
| |
| self.sign(&mut package).await?; |
| |
| let reference = package.to_reference(self.cipher_suite_provider).await?; |
| |
| Ok(KeyPackageGeneration { |
| key_package: package, |
| init_secret_key, |
| leaf_node_secret_key: leaf_node_secret, |
| reference, |
| }) |
| } |
| } |
| |
| #[cfg(test)] |
| mod tests { |
| use assert_matches::assert_matches; |
| use mls_rs_core::crypto::CipherSuiteProvider; |
| |
| use crate::{ |
| crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider}, |
| extension::test_utils::TestExtension, |
| group::test_utils::random_bytes, |
| identity::basic::BasicIdentityProvider, |
| identity::test_utils::get_test_signing_identity, |
| key_package::validate_key_package_properties, |
| protocol_version::ProtocolVersion, |
| tree_kem::{ |
| leaf_node::{test_utils::get_test_capabilities, LeafNodeSource}, |
| leaf_node_validator::{LeafNodeValidator, ValidationContext}, |
| Lifetime, |
| }, |
| ExtensionList, |
| }; |
| |
| use super::KeyPackageGenerator; |
| |
| fn test_key_package_ext(val: u8) -> ExtensionList { |
| let mut ext_list = ExtensionList::new(); |
| ext_list.set_from(TestExtension::from(val)).unwrap(); |
| ext_list |
| } |
| |
| fn test_leaf_node_ext(val: u8) -> ExtensionList { |
| let mut ext_list = ExtensionList::new(); |
| ext_list.set_from(TestExtension::from(val)).unwrap(); |
| ext_list |
| } |
| |
| fn test_lifetime() -> Lifetime { |
| Lifetime::years(1).unwrap() |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_key_generation() { |
| for (protocol_version, cipher_suite) in ProtocolVersion::all().flat_map(|p| { |
| TestCryptoProvider::all_supported_cipher_suites() |
| .into_iter() |
| .map(move |cs| (p, cs)) |
| }) { |
| let cipher_suite_provider = test_cipher_suite_provider(cipher_suite); |
| |
| let (signing_identity, signing_key) = |
| get_test_signing_identity(cipher_suite, b"foo").await; |
| |
| let key_package_ext = test_key_package_ext(32); |
| let leaf_node_ext = test_leaf_node_ext(42); |
| let lifetime = test_lifetime(); |
| |
| let test_generator = KeyPackageGenerator { |
| protocol_version, |
| cipher_suite_provider: &cipher_suite_provider, |
| signing_identity: &signing_identity, |
| signing_key: &signing_key, |
| identity_provider: &BasicIdentityProvider, |
| }; |
| |
| let mut capabilities = get_test_capabilities(); |
| capabilities.extensions.push(42.into()); |
| capabilities.extensions.push(43.into()); |
| capabilities.extensions.push(32.into()); |
| |
| let generated = test_generator |
| .generate( |
| lifetime.clone(), |
| capabilities.clone(), |
| key_package_ext.clone(), |
| leaf_node_ext.clone(), |
| ) |
| .await |
| .unwrap(); |
| |
| assert_matches!(generated.key_package.leaf_node.leaf_node_source, |
| LeafNodeSource::KeyPackage(ref lt) if lt == &lifetime); |
| |
| assert_eq!( |
| generated.key_package.leaf_node.ungreased_capabilities(), |
| capabilities |
| ); |
| |
| assert_eq!( |
| generated.key_package.leaf_node.ungreased_extensions(), |
| leaf_node_ext |
| ); |
| |
| assert_eq!( |
| generated.key_package.ungreased_extensions(), |
| key_package_ext |
| ); |
| |
| assert_ne!( |
| generated.key_package.hpke_init_key.as_ref(), |
| generated.key_package.leaf_node.public_key.as_ref() |
| ); |
| |
| assert_eq!(generated.key_package.cipher_suite, cipher_suite); |
| assert_eq!(generated.key_package.version, protocol_version); |
| |
| // Verify that the hpke key pair generated will work |
| let test_data = random_bytes(32); |
| |
| let sealed = cipher_suite_provider |
| .hpke_seal(&generated.key_package.hpke_init_key, &[], None, &test_data) |
| .await |
| .unwrap(); |
| |
| let opened = cipher_suite_provider |
| .hpke_open( |
| &sealed, |
| &generated.init_secret_key, |
| &generated.key_package.hpke_init_key, |
| &[], |
| None, |
| ) |
| .await |
| .unwrap(); |
| |
| assert_eq!(opened, test_data); |
| |
| let validator = |
| LeafNodeValidator::new(&cipher_suite_provider, &BasicIdentityProvider, None); |
| |
| validator |
| .check_if_valid( |
| &generated.key_package.leaf_node, |
| ValidationContext::Add(None), |
| ) |
| .await |
| .unwrap(); |
| |
| validate_key_package_properties( |
| &generated.key_package, |
| protocol_version, |
| &cipher_suite_provider, |
| ) |
| .await |
| .unwrap(); |
| } |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_randomness() { |
| for (protocol_version, cipher_suite) in ProtocolVersion::all().flat_map(|p| { |
| TestCryptoProvider::all_supported_cipher_suites() |
| .into_iter() |
| .map(move |cs| (p, cs)) |
| }) { |
| let (signing_identity, signing_key) = |
| get_test_signing_identity(cipher_suite, b"foo").await; |
| |
| let test_generator = KeyPackageGenerator { |
| protocol_version, |
| cipher_suite_provider: &test_cipher_suite_provider(cipher_suite), |
| signing_identity: &signing_identity, |
| signing_key: &signing_key, |
| identity_provider: &BasicIdentityProvider, |
| }; |
| |
| let first_key_package = test_generator |
| .generate( |
| test_lifetime(), |
| get_test_capabilities(), |
| ExtensionList::default(), |
| ExtensionList::default(), |
| ) |
| .await |
| .unwrap(); |
| |
| for _ in 0..100 { |
| let next_key_package = test_generator |
| .generate( |
| test_lifetime(), |
| get_test_capabilities(), |
| ExtensionList::default(), |
| ExtensionList::default(), |
| ) |
| .await |
| .unwrap(); |
| |
| assert_ne!( |
| first_key_package.key_package.hpke_init_key, |
| next_key_package.key_package.hpke_init_key |
| ); |
| |
| assert_ne!( |
| first_key_package.key_package.leaf_node.public_key, |
| next_key_package.key_package.leaf_node.public_key |
| ); |
| } |
| } |
| } |
| } |