| // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| // Copyright by contributors to this project. |
| // SPDX-License-Identifier: (Apache-2.0 OR MIT) |
| |
| use crate::client::MlsError; |
| use crate::crypto::{CipherSuiteProvider, SignatureSecretKey}; |
| use crate::group::GroupContext; |
| use crate::identity::SigningIdentity; |
| use crate::iter::wrap_iter; |
| use crate::tree_kem::math as tree_math; |
| use alloc::vec; |
| use alloc::vec::Vec; |
| use itertools::Itertools; |
| use mls_rs_codec::MlsEncode; |
| use tree_math::{CopathNode, TreeIndex}; |
| |
| #[cfg(all(not(mls_build_async), feature = "rayon"))] |
| use {crate::iter::ParallelIteratorExt, rayon::prelude::*}; |
| |
| #[cfg(mls_build_async)] |
| use futures::{StreamExt, TryStreamExt}; |
| |
| #[cfg(feature = "std")] |
| use std::collections::HashSet; |
| |
| use super::hpke_encryption::HpkeEncryptable; |
| use super::leaf_node::ConfigProperties; |
| use super::node::NodeTypeResolver; |
| use super::{ |
| node::{LeafIndex, NodeIndex}, |
| path_secret::{PathSecret, PathSecretGenerator}, |
| TreeKemPrivate, TreeKemPublic, UpdatePath, UpdatePathNode, ValidatedUpdatePath, |
| }; |
| |
| #[cfg(test)] |
| use crate::{group::CommitModifiers, signer::Signable}; |
| |
| pub struct TreeKem<'a> { |
| tree_kem_public: &'a mut TreeKemPublic, |
| private_key: &'a mut TreeKemPrivate, |
| } |
| |
| pub struct EncapGeneration { |
| pub update_path: UpdatePath, |
| pub path_secrets: Vec<Option<PathSecret>>, |
| pub commit_secret: PathSecret, |
| } |
| |
| impl<'a> TreeKem<'a> { |
| pub fn new( |
| tree_kem_public: &'a mut TreeKemPublic, |
| private_key: &'a mut TreeKemPrivate, |
| ) -> Self { |
| TreeKem { |
| tree_kem_public, |
| private_key, |
| } |
| } |
| |
| #[allow(clippy::too_many_arguments)] |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn encap<P>( |
| self, |
| context: &mut GroupContext, |
| excluding: &[LeafIndex], |
| signer: &SignatureSecretKey, |
| update_leaf_properties: ConfigProperties, |
| signing_identity: Option<SigningIdentity>, |
| cipher_suite_provider: &P, |
| #[cfg(test)] commit_modifiers: &CommitModifiers, |
| ) -> Result<EncapGeneration, MlsError> |
| where |
| P: CipherSuiteProvider + Send + Sync, |
| { |
| let self_index = self.private_key.self_index; |
| let path = self.tree_kem_public.nodes.direct_copath(self_index); |
| let filtered = self.tree_kem_public.nodes.filtered(self_index)?; |
| |
| self.private_key.secret_keys.resize(path.len() + 1, None); |
| |
| let mut secret_generator = PathSecretGenerator::new(cipher_suite_provider); |
| let mut path_secrets = vec![]; |
| |
| for (i, (node, f)) in path.iter().zip(&filtered).enumerate() { |
| if !f { |
| let secret = secret_generator.next_secret().await?; |
| |
| let (secret_key, public_key) = |
| secret.to_hpke_key_pair(cipher_suite_provider).await?; |
| |
| self.private_key.secret_keys[i + 1] = Some(secret_key); |
| self.tree_kem_public.update_node(public_key, node.path)?; |
| path_secrets.push(Some(secret)); |
| } else { |
| self.private_key.secret_keys[i + 1] = None; |
| path_secrets.push(None); |
| } |
| } |
| |
| #[cfg(test)] |
| (commit_modifiers.modify_tree)(self.tree_kem_public); |
| |
| self.tree_kem_public |
| .update_parent_hashes(self_index, false, cipher_suite_provider) |
| .await?; |
| |
| let update_path_leaf = { |
| let own_leaf = self.tree_kem_public.nodes.borrow_as_leaf_mut(self_index)?; |
| |
| self.private_key.secret_keys[0] = Some( |
| own_leaf |
| .commit( |
| cipher_suite_provider, |
| &context.group_id, |
| *self_index, |
| update_leaf_properties, |
| signing_identity, |
| signer, |
| ) |
| .await?, |
| ); |
| |
| #[cfg(test)] |
| if let Some(signer) = (commit_modifiers.modify_leaf)(own_leaf, signer) { |
| let context = &(context.group_id.as_slice(), *self_index).into(); |
| |
| own_leaf |
| .sign(cipher_suite_provider, &signer, context) |
| .await |
| .unwrap(); |
| } |
| |
| own_leaf.clone() |
| }; |
| |
| // Tree modifications are all done so we can update the tree hash and encrypt with the new context |
| self.tree_kem_public |
| .update_hashes(&[self_index], cipher_suite_provider) |
| .await?; |
| |
| context.tree_hash = self |
| .tree_kem_public |
| .tree_hash(cipher_suite_provider) |
| .await?; |
| |
| let context_bytes = context.mls_encode_to_vec()?; |
| |
| let node_updates = self |
| .encrypt_path_secrets( |
| path, |
| &path_secrets, |
| &context_bytes, |
| cipher_suite_provider, |
| excluding, |
| ) |
| .await?; |
| |
| #[cfg(test)] |
| let node_updates = (commit_modifiers.modify_path)(node_updates); |
| |
| // Create an update path with the new node and parent node updates |
| let update_path = UpdatePath { |
| leaf_node: update_path_leaf, |
| nodes: node_updates, |
| }; |
| |
| Ok(EncapGeneration { |
| update_path, |
| path_secrets, |
| commit_secret: secret_generator.next_secret().await?, |
| }) |
| } |
| |
| #[cfg(any(mls_build_async, not(feature = "rayon")))] |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| async fn encrypt_path_secrets<P: CipherSuiteProvider>( |
| &self, |
| path: Vec<CopathNode<NodeIndex>>, |
| path_secrets: &[Option<PathSecret>], |
| context_bytes: &[u8], |
| cipher_suite: &P, |
| excluding: &[LeafIndex], |
| ) -> Result<Vec<UpdatePathNode>, MlsError> { |
| let excluding = excluding.iter().copied().map(NodeIndex::from); |
| |
| #[cfg(feature = "std")] |
| let excluding = excluding.collect::<HashSet<NodeIndex>>(); |
| #[cfg(not(feature = "std"))] |
| let excluding = excluding.collect::<Vec<NodeIndex>>(); |
| |
| let mut node_updates = Vec::new(); |
| |
| for (index, path_secret) in path.into_iter().zip(path_secrets.iter()) { |
| if let Some(path_secret) = path_secret { |
| node_updates.push( |
| self.encrypt_copath_node_resolution( |
| cipher_suite, |
| path_secret, |
| index.copath, |
| context_bytes, |
| &excluding, |
| ) |
| .await?, |
| ); |
| } |
| } |
| |
| Ok(node_updates) |
| } |
| |
| #[cfg(all(not(mls_build_async), feature = "rayon"))] |
| fn encrypt_path_secrets<P: CipherSuiteProvider>( |
| &self, |
| path: Vec<CopathNode<NodeIndex>>, |
| path_secrets: &[Option<PathSecret>], |
| context_bytes: &[u8], |
| cipher_suite: &P, |
| excluding: &[LeafIndex], |
| ) -> Result<Vec<UpdatePathNode>, MlsError> { |
| let excluding = excluding.iter().copied().map(NodeIndex::from); |
| |
| #[cfg(feature = "std")] |
| let excluding = excluding.collect::<HashSet<NodeIndex>>(); |
| #[cfg(not(feature = "std"))] |
| let excluding = excluding.collect::<Vec<NodeIndex>>(); |
| |
| path.into_par_iter() |
| .zip(path_secrets.par_iter()) |
| .filter_map(|(node, path_secret)| { |
| path_secret.as_ref().map(|path_secret| { |
| self.encrypt_copath_node_resolution( |
| cipher_suite, |
| path_secret, |
| node.copath, |
| context_bytes, |
| &excluding, |
| ) |
| }) |
| }) |
| .collect() |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn decap<CP>( |
| self, |
| sender_index: LeafIndex, |
| update_path: &ValidatedUpdatePath, |
| added_leaves: &[LeafIndex], |
| context_bytes: &[u8], |
| cipher_suite_provider: &CP, |
| ) -> Result<PathSecret, MlsError> |
| where |
| CP: CipherSuiteProvider, |
| { |
| let self_index = self.private_key.self_index; |
| |
| let lca_index = |
| tree_math::leaf_lca_level(self_index.into(), sender_index.into()) as usize - 2; |
| |
| let mut path = self.tree_kem_public.nodes.direct_copath(self_index); |
| let leaf = CopathNode::new(self_index.into(), 0); |
| path.insert(0, leaf); |
| let resolved_pos = self.find_resolved_pos(&path, lca_index)?; |
| |
| let ct_pos = |
| self.find_ciphertext_pos(path[lca_index].path, path[resolved_pos].path, added_leaves)?; |
| |
| let lca_node = update_path.nodes[lca_index] |
| .as_ref() |
| .ok_or(MlsError::LcaNotFoundInDirectPath)?; |
| |
| let ct = lca_node |
| .encrypted_path_secret |
| .get(ct_pos) |
| .ok_or(MlsError::LcaNotFoundInDirectPath)?; |
| |
| let secret = self.private_key.secret_keys[resolved_pos] |
| .as_ref() |
| .ok_or(MlsError::UpdateErrorNoSecretKey)?; |
| |
| let public = self |
| .tree_kem_public |
| .nodes |
| .borrow_node(path[resolved_pos].path)? |
| .as_ref() |
| .ok_or(MlsError::UpdateErrorNoSecretKey)? |
| .public_key(); |
| |
| let lca_path_secret = |
| PathSecret::decrypt(cipher_suite_provider, secret, public, context_bytes, ct).await?; |
| |
| // Derive the rest of the secrets for the tree and assign to the proper nodes |
| let mut node_secret_gen = |
| PathSecretGenerator::starting_with(cipher_suite_provider, lca_path_secret); |
| |
| // Update secrets based on the decrypted path secret in the update |
| self.private_key.secret_keys.resize(path.len() + 1, None); |
| |
| for (i, update) in update_path.nodes.iter().enumerate().skip(lca_index) { |
| if let Some(update) = update { |
| let secret = node_secret_gen.next_secret().await?; |
| |
| // Verify the private key we calculated properly matches the public key we inserted into the tree. This guarantees |
| // that we will be able to decrypt later. |
| let (hpke_private, hpke_public) = |
| secret.to_hpke_key_pair(cipher_suite_provider).await?; |
| |
| if hpke_public != update.public_key { |
| return Err(MlsError::PubKeyMismatch); |
| } |
| |
| self.private_key.secret_keys[i + 1] = Some(hpke_private); |
| } else { |
| self.private_key.secret_keys[i + 1] = None; |
| } |
| } |
| |
| node_secret_gen.next_secret().await |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| async fn encrypt_copath_node_resolution<P: CipherSuiteProvider>( |
| &self, |
| cipher_suite_provider: &P, |
| path_secret: &PathSecret, |
| copath_index: NodeIndex, |
| context: &[u8], |
| #[cfg(feature = "std")] excluding: &HashSet<NodeIndex>, |
| #[cfg(not(feature = "std"))] excluding: &[NodeIndex], |
| ) -> Result<UpdatePathNode, MlsError> { |
| let reso = self |
| .tree_kem_public |
| .nodes |
| .get_resolution_index(copath_index)?; |
| |
| let make_ctxt = |idx| async move { |
| let node = self |
| .tree_kem_public |
| .nodes |
| .borrow_node(idx)? |
| .as_non_empty()?; |
| |
| path_secret |
| .encrypt(cipher_suite_provider, node.public_key(), context) |
| .await |
| }; |
| |
| let ctxts = wrap_iter(reso).filter(|&idx| async move { !excluding.contains(&idx) }); |
| |
| #[cfg(not(mls_build_async))] |
| let ctxts = ctxts.map(make_ctxt); |
| |
| #[cfg(mls_build_async)] |
| let ctxts = ctxts.then(make_ctxt); |
| |
| let ctxts = ctxts.try_collect().await?; |
| |
| let path_index = copath_index |
| .parent_sibling(&self.tree_kem_public.total_leaf_count()) |
| .ok_or(MlsError::ExpectedNode)? |
| .parent; |
| |
| Ok(UpdatePathNode { |
| public_key: self |
| .tree_kem_public |
| .nodes |
| .borrow_as_parent(path_index)? |
| .public_key |
| .clone(), |
| encrypted_path_secret: ctxts, |
| }) |
| } |
| |
| #[inline] |
| fn find_resolved_pos( |
| &self, |
| path: &[CopathNode<NodeIndex>], |
| mut lca_index: usize, |
| ) -> Result<usize, MlsError> { |
| while self.tree_kem_public.nodes.is_blank(path[lca_index].path)? { |
| lca_index -= 1; |
| } |
| |
| // If we don't have the key, we should be an unmerged leaf at the resolved node. (If |
| // we're not, an error will be thrown later.) |
| if self.private_key.secret_keys[lca_index].is_none() { |
| lca_index = 0; |
| } |
| |
| Ok(lca_index) |
| } |
| |
| #[inline] |
| fn find_ciphertext_pos( |
| &self, |
| lca: NodeIndex, |
| resolved: NodeIndex, |
| excluding: &[LeafIndex], |
| ) -> Result<usize, MlsError> { |
| let reso = self.tree_kem_public.nodes.get_resolution_index(lca)?; |
| |
| let (ct_pos, _) = reso |
| .iter() |
| .filter(|idx| **idx % 2 == 1 || !excluding.contains(&LeafIndex(**idx / 2))) |
| .find_position(|idx| idx == &&resolved) |
| .ok_or(MlsError::UpdateErrorNoSecretKey)?; |
| |
| Ok(ct_pos) |
| } |
| } |
| |
/// Round-trip tests for `TreeKem::encap` / `TreeKem::decap`.
#[cfg(test)]
mod tests {
    use super::{tree_math, TreeKem};
    use crate::{
        cipher_suite::CipherSuite,
        client::test_utils::TEST_CIPHER_SUITE,
        crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
        extension::test_utils::TestExtension,
        group::test_utils::{get_test_group_context, random_bytes},
        identity::basic::BasicIdentityProvider,
        tree_kem::{
            leaf_node::{
                test_utils::{get_basic_test_node_sig_key, get_test_capabilities},
                ConfigProperties,
            },
            node::LeafIndex,
            Capabilities, TreeKemPrivate, TreeKemPublic, UpdatePath, ValidatedUpdatePath,
        },
        ExtensionList,
    };
    use alloc::{format, vec, vec::Vec};
    use mls_rs_codec::MlsEncode;
    use mls_rs_core::crypto::CipherSuiteProvider;
    use tree_math::TreeIndex;

    /// Checks that `tree` reflects `update_path` after an encap from leaf `index`.
    ///
    /// The checks are:
    /// - the public key of each direct-path node equals the key at the same position in
    ///   `update_path.nodes`;
    /// - the leaf stored in the tree equals the update-path leaf;
    /// - the optional capabilities and extensions were installed;
    /// - the root is non-blank.
    ///
    /// The positional comparison relies on no direct-path node having been filtered out.
    fn verify_tree_update_path(
        tree: &TreeKemPublic,
        update_path: &UpdatePath,
        index: LeafIndex,
        capabilities: Option<Capabilities>,
        extensions: Option<ExtensionList>,
    ) {
        // Make sure the update path is based on the direct path of the sender
        let direct_path = tree.nodes.direct_copath(index);

        for (i, n) in direct_path.iter().enumerate() {
            assert_eq!(
                *tree
                    .nodes
                    .borrow_node(n.path)
                    .unwrap()
                    .as_ref()
                    .unwrap()
                    .public_key(),
                update_path.nodes[i].public_key
            );
        }

        // Verify that the leaf from the update path has been installed
        assert_eq!(
            tree.nodes.borrow_as_leaf(index).unwrap(),
            &update_path.leaf_node
        );

        // Verify that updated capabilities were installed
        if let Some(capabilities) = capabilities {
            assert_eq!(update_path.leaf_node.capabilities, capabilities);
        }

        // Verify that update extensions were installed
        if let Some(extensions) = extensions {
            assert_eq!(update_path.leaf_node.extensions, extensions);
        }

        // Verify that we have public keys up to the root
        let root = tree.total_leaf_count().root();
        assert!(tree.nodes.borrow_node(root).unwrap().is_some());
    }

    /// Checks that `private_tree` belongs to leaf `index` and holds, for every direct-path
    /// node, a private key matching that node's public key in `public_tree`.
    ///
    /// Each key pair is checked by HPKE-sealing 32 random bytes to the public key and opening
    /// them with the private key.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn verify_tree_private_path(
        cipher_suite: &CipherSuite,
        public_tree: &TreeKemPublic,
        private_tree: &TreeKemPrivate,
        index: LeafIndex,
    ) {
        let provider = test_cipher_suite_provider(*cipher_suite);

        assert_eq!(private_tree.self_index, index);

        // Make sure we have private values along the direct path, and the public keys match
        let path_iter = public_tree
            .nodes
            .direct_copath(index)
            .into_iter()
            .enumerate();

        for (i, n) in path_iter {
            // Slot 0 holds the leaf key, so direct-path node i lives in slot i + 1.
            let secret_key = private_tree.secret_keys[i + 1].as_ref().unwrap();

            let public_key = public_tree
                .nodes
                .borrow_node(n.path)
                .unwrap()
                .as_ref()
                .unwrap()
                .public_key();

            let test_data = random_bytes(32);

            let sealed = provider
                .hpke_seal(public_key, &[], None, &test_data)
                .await
                .unwrap();

            let opened = provider
                .hpke_open(&sealed, secret_key, public_key, &[], None)
                .await
                .unwrap();

            assert_eq!(test_data, opened);
        }
    }

    /// End-to-end TreeKEM round trip over a tree of `size` leaves.
    ///
    /// Leaf 0 ("encap") generates an update path with the given leaf `capabilities` /
    /// `extensions` (test defaults when `None`). Every other leaf (1..size) applies that path
    /// and decaps it. Each receiver's public tree must end up equal to the sender's.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn encap_decap(
        cipher_suite: CipherSuite,
        size: usize,
        capabilities: Option<Capabilities>,
        extensions: Option<ExtensionList>,
    ) {
        let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);

        // Generate leaf nodes and private keys for the receiving participants (leaves
        // 1..size). private_keys[i] belongs to LeafIndex(i + 1).

        let mut leaf_nodes = Vec::new();
        let mut private_keys = Vec::new();

        for index in 1..size {
            let (leaf_node, hpke_secret, _) =
                get_basic_test_node_sig_key(cipher_suite, &format!("{index}")).await;

            let private_key = TreeKemPrivate::new_self_leaf(LeafIndex(index as u32), hpke_secret);

            leaf_nodes.push(leaf_node);
            private_keys.push(private_key);
        }

        let (encap_node, encap_hpke_secret, encap_signer) =
            get_basic_test_node_sig_key(cipher_suite, "encap").await;

        // Build a test tree with the encap leaf at index 0; it is cloned for every participant
        let (mut test_tree, mut encap_private_key) = TreeKemPublic::derive(
            encap_node,
            encap_hpke_secret,
            &BasicIdentityProvider,
            &Default::default(),
        )
        .await
        .unwrap();

        test_tree
            .add_leaves(leaf_nodes, &BasicIdentityProvider, &cipher_suite_provider)
            .await
            .unwrap();

        // Clone the tree for the sender (leaf 0); encap will mutate this copy
        let mut encap_tree = test_tree.clone();

        let update_leaf_properties = ConfigProperties {
            capabilities: capabilities.clone().unwrap_or_else(get_test_capabilities),
            extensions: extensions.clone().unwrap_or_default(),
        };

        // Perform the encap function
        let encap_gen = TreeKem::new(&mut encap_tree, &mut encap_private_key)
            .encap(
                &mut get_test_group_context(42, cipher_suite).await,
                &[],
                &encap_signer,
                update_leaf_properties,
                None,
                &cipher_suite_provider,
                #[cfg(test)]
                &Default::default(),
            )
            .await
            .unwrap();

        // Verify that the state of the tree matches the produced update path
        verify_tree_update_path(
            &encap_tree,
            &encap_gen.update_path,
            LeafIndex(0),
            capabilities,
            extensions,
        );

        // Verify that the private key matches the data in the public key
        verify_tree_private_path(&cipher_suite, &encap_tree, &encap_private_key, LeafIndex(0))
            .await;

        // Spread the (filter-compacted) update-path nodes back onto their direct-path slots:
        // the k-th non-filtered slot receives the k-th update-path node, filtered slots stay None.
        let filtered = test_tree.nodes.filtered(LeafIndex(0)).unwrap();
        let mut unfiltered_nodes = vec![None; filtered.len()];
        filtered
            .into_iter()
            .enumerate()
            .filter(|(_, f)| !*f)
            .zip(encap_gen.update_path.nodes.iter())
            .for_each(|((i, _), node)| {
                unfiltered_nodes[i] = Some(node.clone());
            });

        // Apply the update path to the rest of the leaf nodes using the decap function
        let validated_update_path = ValidatedUpdatePath {
            leaf_node: encap_gen.update_path.leaf_node,
            nodes: unfiltered_nodes,
        };

        // Mirror the update_hashes call each receiver makes after decap, so the trees are
        // compared after the same hash step.
        encap_tree
            .update_hashes(&[LeafIndex(0)], &cipher_suite_provider)
            .await
            .unwrap();

        let mut receiver_trees: Vec<TreeKemPublic> = (1..size).map(|_| test_tree.clone()).collect();

        for (i, tree) in receiver_trees.iter_mut().enumerate() {
            tree.apply_update_path(
                LeafIndex(0),
                &validated_update_path,
                &Default::default(),
                BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap();

            // Decap must use the same group context (with the updated tree hash) the sender
            // encrypted under.
            let mut context = get_test_group_context(42, cipher_suite).await;
            context.tree_hash = tree.tree_hash(&cipher_suite_provider).await.unwrap();

            TreeKem::new(tree, &mut private_keys[i])
                .decap(
                    LeafIndex(0),
                    &validated_update_path,
                    &[],
                    &context.mls_encode_to_vec().unwrap(),
                    &cipher_suite_provider,
                )
                .await
                .unwrap();

            tree.update_hashes(&[LeafIndex(0)], &cipher_suite_provider)
                .await
                .unwrap();

            assert_eq!(tree, &encap_tree);
        }
    }

    // Round trip with default leaf properties for every supported cipher suite.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_encap_decap() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            encap_decap(cipher_suite, 10, None, None).await;
        }
    }

    // Round trip where the committed leaf advertises an extra extension type (42).
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_encap_capabilities() {
        let cipher_suite = TEST_CIPHER_SUITE;
        let mut capabilities = get_test_capabilities();
        capabilities.extensions.push(42.into());

        encap_decap(cipher_suite, 10, Some(capabilities.clone()), None).await;
    }

    // Round trip where the committed leaf carries a TestExtension.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_encap_extensions() {
        let cipher_suite = TEST_CIPHER_SUITE;
        let mut extensions = ExtensionList::default();
        extensions.set_from(TestExtension { foo: 10 }).unwrap();

        encap_decap(cipher_suite, 10, None, Some(extensions)).await;
    }

    // Round trip combining the extra capability and the TestExtension.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_encap_capabilities_extensions() {
        let cipher_suite = TEST_CIPHER_SUITE;
        let mut capabilities = get_test_capabilities();
        capabilities.extensions.push(42.into());

        let mut extensions = ExtensionList::default();
        extensions.set_from(TestExtension { foo: 10 }).unwrap();

        encap_decap(cipher_suite, 10, Some(capabilities), Some(extensions)).await;
    }
}