| // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| // Copyright by contributors to this project. |
| // SPDX-License-Identifier: (Apache-2.0 OR MIT) |
| |
| use alloc::vec::Vec; |
| use core::{ |
| fmt::{self, Debug}, |
| ops::{Deref, DerefMut}, |
| }; |
| |
| use zeroize::Zeroizing; |
| |
| use crate::{client::MlsError, map::LargeMap, tree_kem::math::TreeIndex, CipherSuiteProvider}; |
| |
| use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; |
| use mls_rs_core::error::IntoAnyError; |
| |
| use super::key_schedule::kdf_expand_with_label; |
| |
| pub(crate) const MAX_RATCHET_BACK_HISTORY: u32 = 1024; |
| |
/// A node stored in the secret tree.
///
/// Nodes start out holding a plain `Secret`. When a leaf is first used for messaging,
/// its secret is replaced by the pair of `Ratchet`s derived from it
/// (see `SecretTree::take_leaf_ratchet`).
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
enum SecretTreeNode {
    // Explicit discriminants: presumably used as the variant tag by the MLS codec derive,
    // so they are part of the encoded state. Confirm before reordering or renumbering.
    /// Secret of this node that has not yet been split into children or ratchets.
    Secret(TreeSecret) = 0u8,
    /// Application and handshake ratchets of a leaf that has already been used.
    Ratchet(SecretRatchets) = 1u8,
}
| |
| impl SecretTreeNode { |
| fn into_secret(self) -> Option<TreeSecret> { |
| if let SecretTreeNode::Secret(secret) = self { |
| Some(secret) |
| } else { |
| None |
| } |
| } |
| } |
| |
/// Raw secret bytes of a tree node or key ratchet.
///
/// The bytes are held in `Zeroizing`, so they are wiped when dropped. They are encoded with
/// `mls_rs_codec::byte_vec` and, with the `serde` feature, serialized through
/// `mls_rs_core::zeroizing_serde`.
#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct TreeSecret(
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    Zeroizing<Vec<u8>>,
);
| |
| impl Debug for TreeSecret { |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| mls_rs_core::debug::pretty_bytes(&self.0) |
| .named("TreeSecret") |
| .fmt(f) |
| } |
| } |
| |
| impl Deref for TreeSecret { |
| type Target = Vec<u8>; |
| |
| fn deref(&self) -> &Self::Target { |
| &self.0 |
| } |
| } |
| |
| impl DerefMut for TreeSecret { |
| fn deref_mut(&mut self) -> &mut Self::Target { |
| &mut self.0 |
| } |
| } |
| |
| impl AsRef<[u8]> for TreeSecret { |
| fn as_ref(&self) -> &[u8] { |
| &self.0 |
| } |
| } |
| |
| impl From<Vec<u8>> for TreeSecret { |
| fn from(vec: Vec<u8>) -> Self { |
| TreeSecret(Zeroizing::new(vec)) |
| } |
| } |
| |
| impl From<Zeroizing<Vec<u8>>> for TreeSecret { |
| fn from(vec: Zeroizing<Vec<u8>>) -> Self { |
| TreeSecret(vec) |
| } |
| } |
| |
/// Sparse storage of secret-tree nodes, keyed by node index.
///
/// Despite the name this is a map (`LargeMap`). Only nodes that currently hold a secret or
/// a ratchet pair are present; nodes are removed as they are consumed.
#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct TreeSecretsVec<T: TreeIndex> {
    inner: LargeMap<T, SecretTreeNode>,
}
| |
| impl<T: TreeIndex> TreeSecretsVec<T> { |
| fn set_node(&mut self, index: T, value: SecretTreeNode) { |
| self.inner.insert(index, value); |
| } |
| |
| fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> { |
| self.inner.remove(index) |
| } |
| } |
| |
/// Lazily derived secret tree.
///
/// Only the nodes whose secrets are currently known are stored. Intermediate secrets are
/// split into their children on demand, and leaves become key ratchets when first used.
#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SecretTree<T: TreeIndex> {
    /// Currently known node secrets / leaf ratchets.
    known_secrets: TreeSecretsVec<T>,
    /// Number of leaves; used to locate the root and the path to each leaf.
    leaf_count: T,
}
| |
| impl<T: TreeIndex> SecretTree<T> { |
| pub(crate) fn empty() -> SecretTree<T> { |
| SecretTree { |
| known_secrets: Default::default(), |
| leaf_count: T::zero(), |
| } |
| } |
| } |
| |
/// The two key ratchets of a single leaf, one per `KeyType`.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SecretRatchets {
    /// Ratchet used for `KeyType::Application`.
    pub application: SecretKeyRatchet,
    /// Ratchet used for `KeyType::Handshake`.
    pub handshake: SecretKeyRatchet,
}
| |
| impl SecretRatchets { |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn message_key_generation<P: CipherSuiteProvider>( |
| &mut self, |
| cipher_suite_provider: &P, |
| generation: u32, |
| key_type: KeyType, |
| ) -> Result<MessageKeyData, MlsError> { |
| match key_type { |
| KeyType::Handshake => { |
| self.handshake |
| .get_message_key(cipher_suite_provider, generation) |
| .await |
| } |
| KeyType::Application => { |
| self.application |
| .get_message_key(cipher_suite_provider, generation) |
| .await |
| } |
| } |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn next_message_key<P: CipherSuiteProvider>( |
| &mut self, |
| cipher_suite: &P, |
| key_type: KeyType, |
| ) -> Result<MessageKeyData, MlsError> { |
| match key_type { |
| KeyType::Handshake => self.handshake.next_message_key(cipher_suite).await, |
| KeyType::Application => self.application.next_message_key(cipher_suite).await, |
| } |
| } |
| } |
| |
| impl<T: TreeIndex> SecretTree<T> { |
| pub fn new(leaf_count: T, encryption_secret: Zeroizing<Vec<u8>>) -> SecretTree<T> { |
| let mut known_secrets = TreeSecretsVec::default(); |
| |
| let root_secret = SecretTreeNode::Secret(TreeSecret::from(encryption_secret)); |
| known_secrets.set_node(leaf_count.root(), root_secret); |
| |
| Self { |
| known_secrets, |
| leaf_count, |
| } |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| async fn consume_node<P: CipherSuiteProvider>( |
| &mut self, |
| cipher_suite_provider: &P, |
| index: &T, |
| ) -> Result<(), MlsError> { |
| let node = self.known_secrets.take_node(index); |
| |
| if let Some(secret) = node.and_then(|n| n.into_secret()) { |
| let left_index = index.left().ok_or(MlsError::LeafNodeNoChildren)?; |
| let right_index = index.right().ok_or(MlsError::LeafNodeNoChildren)?; |
| |
| let left_secret = |
| kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"left", None) |
| .await?; |
| |
| let right_secret = |
| kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"right", None) |
| .await?; |
| |
| self.known_secrets |
| .set_node(left_index, SecretTreeNode::Secret(left_secret.into())); |
| |
| self.known_secrets |
| .set_node(right_index, SecretTreeNode::Secret(right_secret.into())); |
| } |
| |
| Ok(()) |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| async fn take_leaf_ratchet<P: CipherSuiteProvider>( |
| &mut self, |
| cipher_suite: &P, |
| leaf_index: &T, |
| ) -> Result<SecretRatchets, MlsError> { |
| let node_index = leaf_index; |
| |
| let node = match self.known_secrets.take_node(node_index) { |
| Some(node) => node, |
| None => { |
| // Start at the root node and work your way down consuming any intermediates needed |
| for i in node_index.direct_copath(&self.leaf_count).into_iter().rev() { |
| self.consume_node(cipher_suite, &i.path).await?; |
| } |
| |
| self.known_secrets |
| .take_node(node_index) |
| .ok_or(MlsError::InvalidLeafConsumption)? |
| } |
| }; |
| |
| Ok(match node { |
| SecretTreeNode::Ratchet(ratchet) => ratchet, |
| SecretTreeNode::Secret(secret) => SecretRatchets { |
| application: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Application) |
| .await?, |
| handshake: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Handshake).await?, |
| }, |
| }) |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn next_message_key<P: CipherSuiteProvider>( |
| &mut self, |
| cipher_suite: &P, |
| leaf_index: T, |
| key_type: KeyType, |
| ) -> Result<MessageKeyData, MlsError> { |
| let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?; |
| let res = ratchet.next_message_key(cipher_suite, key_type).await?; |
| |
| self.known_secrets |
| .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet)); |
| |
| Ok(res) |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn message_key_generation<P: CipherSuiteProvider>( |
| &mut self, |
| cipher_suite: &P, |
| leaf_index: T, |
| key_type: KeyType, |
| generation: u32, |
| ) -> Result<MessageKeyData, MlsError> { |
| let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?; |
| |
| let res = ratchet |
| .message_key_generation(cipher_suite, generation, key_type) |
| .await?; |
| |
| self.known_secrets |
| .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet)); |
| |
| Ok(res) |
| } |
| } |
| |
/// Selects which of a leaf's two ratchets to use.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KeyType {
    /// Handshake ratchet (initialised with label `"handshake"`).
    Handshake,
    /// Application ratchet (initialised with label `"application"`).
    Application,
}
| |
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
/// AEAD key derived by the MLS secret tree.
///
/// When produced by `SecretKeyRatchet`, the nonce and key have the cipher suite's
/// `aead_nonce_size()` and `aead_key_size()` lengths respectively.
pub struct MessageKeyData {
    /// AEAD nonce bytes (zeroized on drop).
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    pub(crate) nonce: Zeroizing<Vec<u8>>,
    /// AEAD key bytes (zeroized on drop).
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    pub(crate) key: Zeroizing<Vec<u8>>,
    /// Ratchet generation at which this key was derived.
    pub(crate) generation: u32,
}
| |
| impl Debug for MessageKeyData { |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| f.debug_struct("MessageKeyData") |
| .field("nonce", &mls_rs_core::debug::pretty_bytes(&self.nonce)) |
| .field("key", &mls_rs_core::debug::pretty_bytes(&self.key)) |
| .field("generation", &self.generation) |
| .finish() |
| } |
| } |
| |
| #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)] |
| impl MessageKeyData { |
| /// AEAD nonce. |
| #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))] |
| pub fn nonce(&self) -> &[u8] { |
| &self.nonce |
| } |
| |
| /// AEAD key. |
| #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))] |
| pub fn key(&self) -> &[u8] { |
| &self.key |
| } |
| |
| /// Generation of this key within the key schedule. |
| #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))] |
| pub fn generation(&self) -> u32 { |
| self.generation |
| } |
| } |
| |
/// Symmetric key ratchet producing one `MessageKeyData` per generation.
///
/// `secret` is the current ratchet secret and `generation` is the generation of the next
/// key it will produce. With the `out_of_order` feature, keys that were skipped while
/// advancing are kept in `history` (keyed by generation) until they are requested.
/// Encoding/decoding is implemented by hand below.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SecretKeyRatchet {
    secret: TreeSecret,
    generation: u32,
    #[cfg(feature = "out_of_order")]
    history: LargeMap<u32, MessageKeyData>,
}
| |
| impl MlsSize for SecretKeyRatchet { |
| fn mls_encoded_len(&self) -> usize { |
| let len = mls_rs_codec::byte_vec::mls_encoded_len(&self.secret) |
| + self.generation.mls_encoded_len(); |
| |
| #[cfg(feature = "out_of_order")] |
| return len + mls_rs_codec::iter::mls_encoded_len(self.history.values()); |
| #[cfg(not(feature = "out_of_order"))] |
| return len; |
| } |
| } |
| |
#[cfg(feature = "out_of_order")]
impl MlsEncode for SecretKeyRatchet {
    /// Writes `secret` (byte vector), then `generation`, then the `history` values.
    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
        use mls_rs_codec::{byte_vec, iter};

        byte_vec::mls_encode(&self.secret, writer)?;
        self.generation.mls_encode(writer)?;
        iter::mls_encode(self.history.values(), writer)?;

        Ok(())
    }
}
| |
| #[cfg(not(feature = "out_of_order"))] |
| impl MlsEncode for SecretKeyRatchet { |
| fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> { |
| mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?; |
| self.generation.mls_encode(writer) |
| } |
| } |
| |
| impl MlsDecode for SecretKeyRatchet { |
| fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> { |
| Ok(Self { |
| secret: mls_rs_codec::byte_vec::mls_decode(reader)?, |
| generation: u32::mls_decode(reader)?, |
| #[cfg(feature = "out_of_order")] |
| history: mls_rs_codec::iter::mls_decode_collection(reader, |data| { |
| let mut items = LargeMap::default(); |
| |
| while !data.is_empty() { |
| let item = MessageKeyData::mls_decode(data)?; |
| items.insert(item.generation, item); |
| } |
| |
| Ok(items) |
| })?, |
| }) |
| } |
| } |
| |
| impl SecretKeyRatchet { |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| async fn new<P: CipherSuiteProvider>( |
| cipher_suite_provider: &P, |
| secret: &[u8], |
| key_type: KeyType, |
| ) -> Result<Self, MlsError> { |
| let label = match key_type { |
| KeyType::Handshake => b"handshake".as_slice(), |
| KeyType::Application => b"application".as_slice(), |
| }; |
| |
| let secret = kdf_expand_with_label(cipher_suite_provider, secret, label, &[], None) |
| .await |
| .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?; |
| |
| Ok(Self { |
| secret: TreeSecret::from(secret), |
| generation: 0, |
| #[cfg(feature = "out_of_order")] |
| history: Default::default(), |
| }) |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| async fn get_message_key<P: CipherSuiteProvider>( |
| &mut self, |
| cipher_suite_provider: &P, |
| generation: u32, |
| ) -> Result<MessageKeyData, MlsError> { |
| #[cfg(feature = "out_of_order")] |
| if generation < self.generation { |
| return self |
| .history |
| .remove_entry(&generation) |
| .map(|(_, mk)| mk) |
| .ok_or(MlsError::KeyMissing(generation)); |
| } |
| |
| #[cfg(not(feature = "out_of_order"))] |
| if generation < self.generation { |
| return Err(MlsError::KeyMissing(generation)); |
| } |
| |
| let max_generation_allowed = self.generation + MAX_RATCHET_BACK_HISTORY; |
| |
| if generation > max_generation_allowed { |
| return Err(MlsError::InvalidFutureGeneration(generation)); |
| } |
| |
| #[cfg(not(feature = "out_of_order"))] |
| while self.generation < generation { |
| self.next_message_key(cipher_suite_provider)?; |
| } |
| |
| #[cfg(feature = "out_of_order")] |
| while self.generation < generation { |
| let key_data = self.next_message_key(cipher_suite_provider).await?; |
| self.history.insert(key_data.generation, key_data); |
| } |
| |
| self.next_message_key(cipher_suite_provider).await |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| async fn next_message_key<P: CipherSuiteProvider>( |
| &mut self, |
| cipher_suite_provider: &P, |
| ) -> Result<MessageKeyData, MlsError> { |
| let generation = self.generation; |
| |
| let key = MessageKeyData { |
| nonce: self |
| .derive_secret( |
| cipher_suite_provider, |
| b"nonce", |
| cipher_suite_provider.aead_nonce_size(), |
| ) |
| .await?, |
| key: self |
| .derive_secret( |
| cipher_suite_provider, |
| b"key", |
| cipher_suite_provider.aead_key_size(), |
| ) |
| .await?, |
| generation, |
| }; |
| |
| self.secret = self |
| .derive_secret( |
| cipher_suite_provider, |
| b"secret", |
| cipher_suite_provider.kdf_extract_size(), |
| ) |
| .await? |
| .into(); |
| |
| self.generation = generation + 1; |
| |
| Ok(key) |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| async fn derive_secret<P: CipherSuiteProvider>( |
| &self, |
| cipher_suite_provider: &P, |
| label: &[u8], |
| len: usize, |
| ) -> Result<Zeroizing<Vec<u8>>, MlsError> { |
| kdf_expand_with_label( |
| cipher_suite_provider, |
| self.secret.as_ref(), |
| label, |
| &self.generation.to_be_bytes(), |
| Some(len), |
| ) |
| .await |
| .map_err(|e| MlsError::CryptoProviderError(e.into_any_error())) |
| } |
| } |
| |
// Test helpers for the secret tree, shared with other test modules in the crate.
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::{string::String, vec::Vec};
    use mls_rs_core::crypto::CipherSuiteProvider;
    use zeroize::Zeroizing;

    use crate::{crypto::test_utils::try_test_cipher_suite_provider, tree_kem::math::TreeIndex};

    use super::{KeyType, SecretKeyRatchet, SecretTree};

    /// Builds a `SecretTree` for `leaf_count` leaves whose root secret is `secret`.
    pub(crate) fn get_test_tree<T: TreeIndex>(secret: Vec<u8>, leaf_count: T) -> SecretTree<T> {
        SecretTree::new(leaf_count, Zeroizing::new(secret))
    }

    impl SecretTree<u32> {
        /// Returns a copy of the root secret.
        ///
        /// Works on a clone of `known_secrets`, so the tree itself is not modified. Panics if
        /// the root node is missing or no longer holds a plain secret.
        pub(crate) fn get_root_secret(&self) -> Vec<u8> {
            self.known_secrets
                .clone()
                .take_node(&self.leaf_count.root())
                .unwrap()
                .into_secret()
                .unwrap()
                .to_vec()
        }
    }

    /// `derive_tree_secret` case from the `basic_crypto` test vectors.
    ///
    /// Deriving `length` bytes with `label` from `secret` at `generation` must yield `out`.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct RatchetInteropTestCase {
        #[serde(with = "hex::serde")]
        secret: Vec<u8>,
        label: String,
        generation: u32,
        length: usize,
        #[serde(with = "hex::serde")]
        out: Vec<u8>,
    }

    /// One entry of the `basic_crypto` test vectors (only the part checked here).
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct InteropTestCase {
        cipher_suite: u16,
        derive_tree_secret: RatchetInteropTestCase,
    }

    // Checks every `basic_crypto` case whose cipher suite has a test provider; others are skipped.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_basic_crypto_test_vectors() {
        let test_cases: Vec<InteropTestCase> =
            load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());

        for test_case in test_cases {
            if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) {
                test_case.derive_tree_secret.verify(&cs).await
            }
        }
    }

    impl RatchetInteropTestCase {
        // Builds a ratchet, overwrites its secret and generation with the vector's values, and
        // asserts that `derive_secret` reproduces `out`. The secret passed to `new` is irrelevant
        // because it is replaced immediately.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
            let mut ratchet = SecretKeyRatchet::new(cs, &self.secret, KeyType::Application)
                .await
                .unwrap();

            ratchet.secret = self.secret.clone().into();
            ratchet.generation = self.generation;

            let computed = ratchet
                .derive_secret(cs, self.label.as_bytes(), self.length)
                .await
                .unwrap();

            assert_eq!(&computed.to_vec(), &self.out);
        }
    }
}
| |
// Unit tests for the secret tree and the key ratchets, plus the `secret_tree` test vectors.
#[cfg(test)]
mod tests {
    use alloc::vec;

    use crate::{
        cipher_suite::CipherSuite,
        client::test_utils::TEST_CIPHER_SUITE,
        crypto::test_utils::{
            test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
        },
        tree_kem::node::NodeIndex,
    };

    #[cfg(not(mls_build_async))]
    use crate::group::test_utils::random_bytes;

    use super::{test_utils::get_test_tree, *};

    use assert_matches::assert_matches;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    // Runs `test_secret_tree_custom` twice:
    // - on a 16-leaf `u32` tree, taking node indices 0, 2, ..., 30;
    // - on a 2^62-leaf `u64` tree, taking a sparse set of node indices (powers of two).
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_secret_tree() {
        test_secret_tree_custom(16u32, (0..16).map(|i| 2 * i).collect(), true).await;
        test_secret_tree_custom(1u64 << 62, (1..62).map(|i| 1u64 << i).collect(), false).await;
    }

    // For every supported cipher suite: takes the ratchets of each index in `leaves_to_check`
    // from a tree with an all-zero root secret, and checks that the results are pairwise
    // distinct (as compared by `dedup`). When `all_deleted` is set, also checks that no node is
    // left in the tree afterwards.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_secret_tree_custom<T: TreeIndex>(
        leaf_count: T,
        leaves_to_check: Vec<T>,
        all_deleted: bool,
    ) {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let cs_provider = test_cipher_suite_provider(cipher_suite);

            let test_secret = vec![0u8; cs_provider.kdf_extract_size()];
            let mut test_tree = get_test_tree(test_secret, leaf_count.clone());

            let mut secrets = Vec::<SecretRatchets>::new();

            for i in &leaves_to_check {
                let secret = test_tree
                    .take_leaf_ratchet(&test_cipher_suite_provider(cipher_suite), i)
                    .await
                    .unwrap();

                secrets.push(secret);
            }

            // Verify the tree is now completely empty
            assert!(!all_deleted || test_tree.known_secrets.inner.is_empty());

            // Verify that all the secrets are unique
            // (`dedup` only removes consecutive duplicates).
            let count = secrets.len();
            secrets.dedup();
            assert_eq!(count, secrets.len());
        }
    }

    // Application and handshake ratchets built from the same secret yield different keys, and
    // consecutive generations of one ratchet yield different keys.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_secret_key_ratchet() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let provider = test_cipher_suite_provider(cipher_suite);

            let mut app_ratchet = SecretKeyRatchet::new(
                &provider,
                &vec![0u8; provider.kdf_extract_size()],
                KeyType::Application,
            )
            .await
            .unwrap();

            let mut handshake_ratchet = SecretKeyRatchet::new(
                &provider,
                &vec![0u8; provider.kdf_extract_size()],
                KeyType::Handshake,
            )
            .await
            .unwrap();

            let app_key_one = app_ratchet.next_message_key(&provider).await.unwrap();
            let app_key_two = app_ratchet.next_message_key(&provider).await.unwrap();
            let app_keys = vec![app_key_one, app_key_two];

            let handshake_key_one = handshake_ratchet.next_message_key(&provider).await.unwrap();
            let handshake_key_two = handshake_ratchet.next_message_key(&provider).await.unwrap();
            let handshake_keys = vec![handshake_key_one, handshake_key_two];

            // Verify that the keys have different outcomes due to their different labels
            assert_ne!(app_keys, handshake_keys);

            // Verify that the keys at each generation are different
            assert_ne!(handshake_keys[0], handshake_keys[1]);
        }
    }

    // `get_message_key` jumping to a generation matches repeated `next_message_key` calls.
    // Asking for an already-consumed generation fails (with or without `out_of_order`, since
    // generation 0 was never stored in history).
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_get_key() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let provider = test_cipher_suite_provider(cipher_suite);

            let mut ratchet = SecretKeyRatchet::new(
                &test_cipher_suite_provider(cipher_suite),
                &vec![0u8; provider.kdf_extract_size()],
                KeyType::Application,
            )
            .await
            .unwrap();

            let mut ratchet_clone = ratchet.clone();

            // This will generate keys 0 and 1 in ratchet_clone
            let _ = ratchet_clone.next_message_key(&provider).await.unwrap();
            let clone_2 = ratchet_clone.next_message_key(&provider).await.unwrap();

            // Going back in time should result in an error
            let res = ratchet_clone.get_message_key(&provider, 0).await;
            assert!(res.is_err());

            // Calling get key should be the same as calling next until hitting the desired generation
            // (`ratchet_clone.generation - 1` is generation 1 here).
            let second_key = ratchet
                .get_message_key(&provider, ratchet_clone.generation - 1)
                .await
                .unwrap();

            assert_eq!(clone_2, second_key)
        }
    }

    // Advancing the ratchet replaces its internal secret.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_secret_ratchet() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let provider = test_cipher_suite_provider(cipher_suite);

            let mut ratchet = SecretKeyRatchet::new(
                &provider,
                &vec![0u8; provider.kdf_extract_size()],
                KeyType::Application,
            )
            .await
            .unwrap();

            let original_secret = ratchet.secret.clone();
            let _ = ratchet.next_message_key(&provider).await.unwrap();
            let new_secret = ratchet.secret;
            assert_ne!(original_secret, new_secret)
        }
    }

    // With `out_of_order`: jumping straight to generation MAX_RATCHET_BACK_HISTORY and then
    // fetching earlier generations from history gives the same keys as fetching in order.
    // NOTE(review): the back-history loop covers generations 0..MAX_RATCHET_BACK_HISTORY - 1,
    // so generation MAX_RATCHET_BACK_HISTORY - 1 is never compared.
    #[cfg(feature = "out_of_order")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_out_of_order_keys() {
        let cipher_suite = TEST_CIPHER_SUITE;
        let provider = test_cipher_suite_provider(cipher_suite);

        let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
            .await
            .unwrap();
        let mut ratchet_clone = ratchet.clone();

        // Ask for all the keys in order from the original ratchet
        let mut ordered_keys = Vec::<MessageKeyData>::new();

        for i in 0..=MAX_RATCHET_BACK_HISTORY {
            ordered_keys.push(ratchet.get_message_key(&provider, i).await.unwrap());
        }

        // Ask for a key at index MAX_RATCHET_BACK_HISTORY in the clone
        let last_key = ratchet_clone
            .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY)
            .await
            .unwrap();

        assert_eq!(last_key, ordered_keys[ordered_keys.len() - 1]);

        // Get all the other keys
        let mut back_history_keys = Vec::<MessageKeyData>::new();

        for i in 0..MAX_RATCHET_BACK_HISTORY - 1 {
            back_history_keys.push(ratchet_clone.get_message_key(&provider, i).await.unwrap());
        }

        assert_eq!(
            back_history_keys,
            ordered_keys[..(MAX_RATCHET_BACK_HISTORY as usize) - 1]
        );
    }

    // Without `out_of_order`: a generation behind the ratchet yields `KeyMissing`.
    #[cfg(not(feature = "out_of_order"))]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn out_of_order_keys_should_throw_error() {
        let cipher_suite = TEST_CIPHER_SUITE;
        let provider = test_cipher_suite_provider(cipher_suite);

        let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
            .await
            .unwrap();

        ratchet.get_message_key(&provider, 10).await.unwrap();
        let res = ratchet.get_message_key(&provider, 9).await;
        assert_matches!(res, Err(MlsError::KeyMissing(9)))
    }

    // Requesting one generation beyond MAX_RATCHET_BACK_HISTORY from a fresh ratchet yields
    // `InvalidFutureGeneration` carrying the requested generation.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_too_out_of_order() {
        let cipher_suite = TEST_CIPHER_SUITE;
        let provider = test_cipher_suite_provider(cipher_suite);

        let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
            .await
            .unwrap();

        let res = ratchet
            .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY + 1)
            .await;

        let invalid_generation = MAX_RATCHET_BACK_HISTORY + 1;

        assert_matches!(
            res,
            Err(MlsError::InvalidFutureGeneration(invalid))
                if invalid == invalid_generation
        )
    }

    /// Encoded keys produced for one leaf, as stored in the `secret_tree` test vectors.
    #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    struct Ratchet {
        application_keys: Vec<Vec<u8>>,
        handshake_keys: Vec<Vec<u8>>,
    }

    /// One `secret_tree` test vector: a root secret and the expected per-leaf key lists.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        encryption_secret: Vec<u8>,
        ratchets: Vec<Ratchet>,
    }

    // For each of the 16 leaves (node index `index * 2`), takes the leaf's ratchets and records
    // 20 + 20 MLS-encoded `MessageKeyData` values.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn get_ratchet_data(
        secret_tree: &mut SecretTree<NodeIndex>,
        cipher_suite: CipherSuite,
    ) -> Vec<Ratchet> {
        let provider = test_cipher_suite_provider(cipher_suite);
        let mut ratchet_data = Vec::new();

        for index in 0..16 {
            let mut ratchets = secret_tree
                .take_leaf_ratchet(&provider, &(index * 2))
                .await
                .unwrap();

            let mut application_keys = Vec::new();

            // NOTE(review): despite the name, `application_keys` is filled from
            // `ratchets.handshake`, so both lists come from the handshake ratchet
            // (generations 0..20 here, 20..40 below). The stored `secret_tree` vectors were
            // presumably generated by this same function, so switching to
            // `ratchets.application` would require regenerating them. Confirm first.
            for _ in 0..20 {
                let key = ratchets
                    .handshake
                    .next_message_key(&provider)
                    .await
                    .unwrap()
                    .mls_encode_to_vec()
                    .unwrap();

                application_keys.push(key);
            }

            let mut handshake_keys = Vec::new();

            for _ in 0..20 {
                let key = ratchets
                    .handshake
                    .next_message_key(&provider)
                    .await
                    .unwrap()
                    .mls_encode_to_vec()
                    .unwrap();

                handshake_keys.push(key);
            }

            ratchet_data.push(Ratchet {
                application_keys,
                handshake_keys,
            });
        }

        ratchet_data
    }

    // Generates fresh `secret_tree` vectors (random root secret per cipher suite); used when
    // no stored vector file is available.
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<TestCase> {
        CipherSuite::all()
            .map(|cipher_suite| {
                let provider = test_cipher_suite_provider(cipher_suite);
                let encryption_secret = random_bytes(provider.kdf_extract_size());

                let mut secret_tree =
                    SecretTree::new(16, Zeroizing::new(encryption_secret.clone()));

                TestCase {
                    cipher_suite: cipher_suite.into(),
                    encryption_secret,
                    ratchets: get_ratchet_data(&mut secret_tree, cipher_suite),
                }
            })
            .collect()
    }

    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<TestCase> {
        panic!("Tests cannot be generated in async mode");
    }

    // Recomputes the key lists for each stored vector and compares them; cipher suites without
    // a test provider are skipped.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_secret_tree_test_vectors() {
        let test_cases: Vec<TestCase> = load_test_case_json!(secret_tree, generate_test_vector());

        for case in test_cases {
            let Some(cs_provider) = try_test_cipher_suite_provider(case.cipher_suite) else {
                continue;
            };

            let mut secret_tree = SecretTree::new(16, Zeroizing::new(case.encryption_secret));
            let ratchet_data = get_ratchet_data(&mut secret_tree, cs_provider.cipher_suite()).await;

            assert_eq!(ratchet_data, case.ratchets);
        }
    }
}
| |
// Interop checks against the mlswg `secret-tree.json` format (requires `rfc_compliant` + `std`).
#[cfg(all(test, feature = "rfc_compliant", feature = "std"))]
mod interop_tests {
    #[cfg(not(mls_build_async))]
    use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};
    use zeroize::Zeroizing;

    use crate::{
        crypto::test_utils::try_test_cipher_suite_provider,
        group::{ciphertext_processor::InteropSenderData, secret_tree::KeyType},
    };

    use super::SecretTree;

    // For each case: verifies the sender-data part, then checks the application and handshake
    // key/nonce of every listed (leaf, generation) pair. Leaf `index` maps to node index
    // `index * 2`; cipher suites without a test provider are skipped.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn interop_test_vector() {
        // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/secret-tree.json
        let test_cases = load_interop_test_cases();

        for case in test_cases {
            let Some(cs) = try_test_cipher_suite_provider(case.cipher_suite) else {
                continue;
            };

            case.sender_data.verify(&cs).await;

            let mut tree = SecretTree::new(
                case.leaves.len() as u32,
                Zeroizing::new(case.encryption_secret),
            );

            for (index, leaves) in case.leaves.iter().enumerate() {
                for leaf in leaves.iter() {
                    let key = tree
                        .message_key_generation(
                            &cs,
                            (index as u32) * 2,
                            KeyType::Application,
                            leaf.generation,
                        )
                        .await
                        .unwrap();

                    assert_eq!(key.key.to_vec(), leaf.application_key);
                    assert_eq!(key.nonce.to_vec(), leaf.application_nonce);

                    let key = tree
                        .message_key_generation(
                            &cs,
                            (index as u32) * 2,
                            KeyType::Handshake,
                            leaf.generation,
                        )
                        .await
                        .unwrap();

                    assert_eq!(key.key.to_vec(), leaf.handshake_key);
                    assert_eq!(key.nonce.to_vec(), leaf.handshake_nonce);
                }
            }
        }
    }

    /// One interop case: root secret, sender-data check, and per-leaf expected keys.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct InteropTestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        encryption_secret: Vec<u8>,
        sender_data: InteropSenderData,
        leaves: Vec<Vec<InteropLeaf>>,
    }

    /// Expected application/handshake key and nonce for one leaf at one generation.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct InteropLeaf {
        generation: u32,
        #[serde(with = "hex::serde")]
        application_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        application_nonce: Vec<u8>,
        #[serde(with = "hex::serde")]
        handshake_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        handshake_nonce: Vec<u8>,
    }

    // Loads `secret_tree_interop` vectors, falling back to `generate_test_vector`.
    fn load_interop_test_cases() -> Vec<InteropTestCase> {
        load_test_case_json!(secret_tree_interop, generate_test_vector())
    }

    // Generates vectors for trees of 1, 8 and 32 leaves with a random root secret, recording
    // generations 0 and 15 of both ratchets for every leaf.
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        let mut test_cases = vec![];

        for cs in CipherSuite::all() {
            let Some(cs) = try_test_cipher_suite_provider(*cs) else {
                continue;
            };

            let gens = [0, 15];
            let tree_sizes = [1, 8, 32];

            for n_leaves in tree_sizes {
                let encryption_secret = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();

                let mut tree = SecretTree::new(n_leaves, Zeroizing::new(encryption_secret.clone()));

                let leaves = (0..n_leaves)
                    .map(|leaf| {
                        gens.into_iter()
                            .map(|gen| {
                                let index = leaf * 2u32;

                                let handshake_key = tree
                                    .message_key_generation(&cs, index, KeyType::Handshake, gen)
                                    .unwrap();

                                let app_key = tree
                                    .message_key_generation(&cs, index, KeyType::Application, gen)
                                    .unwrap();

                                InteropLeaf {
                                    generation: gen,
                                    application_key: app_key.key.to_vec(),
                                    application_nonce: app_key.nonce.to_vec(),
                                    handshake_key: handshake_key.key.to_vec(),
                                    handshake_nonce: handshake_key.nonce.to_vec(),
                                }
                            })
                            .collect()
                    })
                    .collect();

                let case = InteropTestCase {
                    cipher_suite: *cs.cipher_suite(),
                    encryption_secret,
                    sender_data: InteropSenderData::new(&cs),
                    leaves,
                };

                test_cases.push(case);
            }
        }

        test_cases
    }

    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        panic!("Tests cannot be generated in async mode");
    }
}