| // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| // Copyright by contributors to this project. |
| // SPDX-License-Identifier: (Apache-2.0 OR MIT) |
| |
| use alloc::vec; |
| use alloc::vec::Vec; |
| #[cfg(feature = "std")] |
| use core::fmt::Display; |
| use itertools::Itertools; |
| use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; |
| use mls_rs_core::extension::ExtensionList; |
| |
| use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider}; |
| |
| #[cfg(feature = "tree_index")] |
| use mls_rs_core::identity::SigningIdentity; |
| |
| use math as tree_math; |
| use node::{LeafIndex, NodeIndex, NodeVec}; |
| |
| use self::leaf_node::LeafNode; |
| |
| use crate::client::MlsError; |
| use crate::crypto::{self, CipherSuiteProvider, HpkeSecretKey}; |
| |
| #[cfg(feature = "by_ref_proposal")] |
| use crate::group::proposal::{AddProposal, UpdateProposal}; |
| |
| #[cfg(any(test, feature = "by_ref_proposal"))] |
| use crate::group::proposal::RemoveProposal; |
| |
| use crate::group::proposal_filter::ProposalBundle; |
| use crate::tree_kem::tree_hash::TreeHashes; |
| |
| mod capabilities; |
| pub(crate) mod hpke_encryption; |
| mod lifetime; |
| pub(crate) mod math; |
| pub mod node; |
| pub mod parent_hash; |
| pub mod path_secret; |
| mod private; |
| mod tree_hash; |
| pub mod tree_validator; |
| pub mod update_path; |
| |
| pub use capabilities::*; |
| pub use lifetime::*; |
| pub(crate) use private::*; |
| pub use update_path::*; |
| |
| use tree_index::*; |
| |
| pub mod kem; |
| pub mod leaf_node; |
| pub mod leaf_node_validator; |
| mod tree_index; |
| |
| #[cfg(feature = "std")] |
| pub(crate) mod tree_utils; |
| |
| #[cfg(test)] |
| mod interop_test_vectors; |
| |
| #[cfg(feature = "custom_proposal")] |
| use crate::group::proposal::ProposalType; |
| |
/// Public view of the ratchet tree: the node array plus state derived from it.
///
/// Equality (see the `PartialEq` impl) only compares `nodes`.
#[derive(Clone, Debug, MlsEncode, MlsDecode, MlsSize, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TreeKemPublic {
    // Lookup index over the non-empty leaves (only with the `tree_index` feature). Not
    // serialized by serde; `initialize_index_if_necessary` rebuilds it from `nodes`.
    // NOTE(review): the codec derives presumably encode fields in declaration order — avoid
    // reordering these fields without checking the wire format.
    #[cfg(feature = "tree_index")]
    #[cfg_attr(feature = "serde", serde(skip))]
    index: TreeIndex,
    // Array representation of the tree (leaf and parent slots).
    pub(crate) nodes: NodeVec,
    // Cached tree hashes computed from `nodes`.
    tree_hashes: TreeHashes,
}
| |
| impl PartialEq for TreeKemPublic { |
| fn eq(&self, other: &Self) -> bool { |
| self.nodes == other.nodes |
| } |
| } |
| |
| impl TreeKemPublic { |
| pub fn new() -> TreeKemPublic { |
| Default::default() |
| } |
| |
| #[cfg_attr(not(feature = "tree_index"), allow(unused))] |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub(crate) async fn import_node_data<IP>( |
| nodes: NodeVec, |
| identity_provider: &IP, |
| extensions: &ExtensionList, |
| ) -> Result<TreeKemPublic, MlsError> |
| where |
| IP: IdentityProvider, |
| { |
| let mut tree = TreeKemPublic { |
| nodes, |
| ..Default::default() |
| }; |
| |
| #[cfg(feature = "tree_index")] |
| tree.initialize_index_if_necessary(identity_provider, extensions) |
| .await?; |
| |
| Ok(tree) |
| } |
| |
| #[cfg(feature = "tree_index")] |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub(crate) async fn initialize_index_if_necessary<IP: IdentityProvider>( |
| &mut self, |
| identity_provider: &IP, |
| extensions: &ExtensionList, |
| ) -> Result<(), MlsError> { |
| if !self.index.is_initialized() { |
| self.index = TreeIndex::new(); |
| |
| for (leaf_index, leaf) in self.nodes.non_empty_leaves() { |
| index_insert( |
| &mut self.index, |
| leaf, |
| leaf_index, |
| identity_provider, |
| extensions, |
| ) |
| .await?; |
| } |
| } |
| |
| Ok(()) |
| } |
| |
    /// Returns the leaf registered in the tree index under `identity`, if any.
    ///
    /// Only available with the `tree_index` feature; without it, an async variant that scans
    /// the leaves through an identity provider is compiled instead.
    #[cfg(feature = "tree_index")]
    pub(crate) fn get_leaf_node_with_identity(&self, identity: &[u8]) -> Option<LeafIndex> {
        self.index.get_leaf_index_with_identity(identity)
    }
| |
| #[cfg(not(feature = "tree_index"))] |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub(crate) async fn get_leaf_node_with_identity<I: IdentityProvider>( |
| &self, |
| identity: &[u8], |
| id_provider: &I, |
| extensions: &ExtensionList, |
| ) -> Result<Option<LeafIndex>, MlsError> { |
| for (i, leaf) in self.nodes.non_empty_leaves() { |
| let leaf_id = id_provider |
| .identity(&leaf.signing_identity, extensions) |
| .await |
| .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?; |
| |
| if leaf_id == identity { |
| return Ok(Some(i)); |
| } |
| } |
| |
| Ok(None) |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn derive<I: IdentityProvider>( |
| leaf_node: LeafNode, |
| secret_key: HpkeSecretKey, |
| identity_provider: &I, |
| extensions: &ExtensionList, |
| ) -> Result<(TreeKemPublic, TreeKemPrivate), MlsError> { |
| let mut public_tree = TreeKemPublic::new(); |
| |
| public_tree |
| .add_leaf(leaf_node, identity_provider, extensions, None) |
| .await?; |
| |
| let private_tree = TreeKemPrivate::new_self_leaf(LeafIndex(0), secret_key); |
| |
| Ok((public_tree, private_tree)) |
| } |
| |
    /// Total number of leaf slots as reported by the node vector (presumably counting blank
    /// leaves too; see `occupied_leaf_count` for non-blank leaves only).
    pub fn total_leaf_count(&self) -> u32 {
        self.nodes.total_leaf_count()
    }
| |
    /// Number of non-blank leaves, as reported by the node vector.
    #[cfg(any(test, all(feature = "custom_proposal", feature = "tree_index")))]
    pub fn occupied_leaf_count(&self) -> u32 {
        self.nodes.occupied_leaf_count()
    }
| |
    /// Borrows the leaf node stored at `index`.
    ///
    /// # Errors
    /// Returns the error produced by `NodeVec::borrow_as_leaf` when there is no leaf at `index`
    /// (presumably a blank or out-of-range slot — see that method for the exact variant).
    pub fn get_leaf_node(&self, index: LeafIndex) -> Result<&LeafNode, MlsError> {
        self.nodes.borrow_as_leaf(index)
    }
| |
| pub fn find_leaf_node(&self, leaf_node: &LeafNode) -> Option<LeafIndex> { |
| self.nodes.non_empty_leaves().find_map( |
| |(index, node)| { |
| if node == leaf_node { |
| Some(index) |
| } else { |
| None |
| } |
| }, |
| ) |
| } |
| |
    /// Returns `true` if every non-blank leaf lists `proposal_type` among its supported
    /// proposals.
    ///
    /// With `tree_index` this is answered from the index (number of supporting leaves equals the
    /// number of occupied leaves); otherwise every non-empty leaf's capabilities are scanned.
    #[cfg(feature = "custom_proposal")]
    pub fn can_support_proposal(&self, proposal_type: ProposalType) -> bool {
        #[cfg(feature = "tree_index")]
        return self.index.count_supporting_proposal(proposal_type) == self.occupied_leaf_count();

        #[cfg(not(feature = "tree_index"))]
        self.nodes
            .non_empty_leaves()
            .all(|(_, l)| l.capabilities.proposals.contains(&proposal_type))
    }
| |
| #[cfg(test)] |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn add_leaves<I: IdentityProvider, CP: CipherSuiteProvider>( |
| &mut self, |
| leaf_nodes: Vec<LeafNode>, |
| id_provider: &I, |
| cipher_suite_provider: &CP, |
| ) -> Result<Vec<LeafIndex>, MlsError> { |
| let mut start = LeafIndex(0); |
| let mut added = vec![]; |
| |
| for leaf in leaf_nodes.into_iter() { |
| start = self |
| .add_leaf(leaf, id_provider, &Default::default(), Some(start)) |
| .await?; |
| added.push(start); |
| } |
| |
| self.update_hashes(&added, cipher_suite_provider).await?; |
| |
| Ok(added) |
| } |
| |
    /// Iterates over `(index, leaf)` pairs for every non-blank leaf of the tree.
    pub fn non_empty_leaves(&self) -> impl Iterator<Item = (LeafIndex, &LeafNode)> + '_ {
        self.nodes.non_empty_leaves()
    }
| |
    /// Iterates over every leaf slot, yielding `None` for slots that hold no leaf
    /// (presumably blank leaves).
    #[cfg(feature = "prior_epoch")]
    pub fn leaves(&self) -> impl Iterator<Item = Option<&LeafNode>> + '_ {
        self.nodes.leaves()
    }
| |
| pub(crate) fn update_node( |
| &mut self, |
| pub_key: crypto::HpkePublicKey, |
| index: NodeIndex, |
| ) -> Result<(), MlsError> { |
| self.nodes |
| .borrow_or_fill_node_as_parent(index, &pub_key) |
| .map(|p| { |
| p.public_key = pub_key; |
| p.unmerged_leaves = vec![]; |
| }) |
| } |
| |
    /// Applies a validated update path sent by `sender` to the public tree.
    ///
    /// Steps, in order:
    /// 1. replace the sender's leaf with `update_path.leaf_node`;
    /// 2. write the path's public keys into the parent nodes of the sender's direct path
    ///    (via `update_node`, which also clears their unmerged leaves);
    /// 3. move the sender's index entry from the old leaf to the new one;
    /// 4. verify/update parent hashes starting from `sender`.
    ///
    /// # Errors
    /// Fails if `sender` has no leaf, if the identity provider fails, if the new leaf cannot be
    /// indexed, or if the parent-hash update fails. No rollback is performed here, so the tree
    /// may be partially modified when an error is returned.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn apply_update_path<IP, CP>(
        &mut self,
        sender: LeafIndex,
        update_path: &ValidatedUpdatePath,
        extensions: &ExtensionList,
        identity_provider: IP,
        cipher_suite_provider: &CP,
    ) -> Result<(), MlsError>
    where
        IP: IdentityProvider,
        CP: CipherSuiteProvider,
    {
        // Install the new leaf node
        let existing_leaf = self.nodes.borrow_as_leaf_mut(sender)?;

        // Keep a copy of the old leaf (and its identity) so its index entry can be removed
        // after the leaf has been overwritten.
        #[cfg(feature = "tree_index")]
        let original_leaf_node = existing_leaf.clone();

        #[cfg(feature = "tree_index")]
        let original_identity = identity_provider
            .identity(&original_leaf_node.signing_identity, extensions)
            .await
            .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;

        *existing_leaf = update_path.leaf_node.clone();

        // Update the rest of the nodes on the direct path
        let path = self.nodes.direct_copath(sender);

        // Path entries that are `None` leave the corresponding parent node untouched.
        for (node, pn) in update_path.nodes.iter().zip(path) {
            node.as_ref()
                .map(|n| self.update_node(n.public_key.clone(), pn.path))
                .transpose()?;
        }

        #[cfg(feature = "tree_index")]
        self.index.remove(&original_leaf_node, &original_identity);

        // Register the new leaf. Without `tree_index`, `index_insert` receives the node vector
        // instead of an index (presumably to check the new leaf against the existing leaves).
        index_insert(
            #[cfg(feature = "tree_index")]
            &mut self.index,
            #[cfg(not(feature = "tree_index"))]
            &self.nodes,
            &update_path.leaf_node,
            sender,
            &identity_provider,
            extensions,
        )
        .await?;

        // Verify the parent hash of the new sender leaf node and update the parent hash values
        // in the local tree
        self.update_parent_hashes(sender, true, cipher_suite_provider)
            .await?;

        Ok(())
    }
| |
| fn update_unmerged(&mut self, index: LeafIndex) -> Result<(), MlsError> { |
| // For a given leaf index, find parent nodes and add the leaf to the unmerged leaf |
| self.nodes.direct_copath(index).into_iter().for_each(|i| { |
| if let Ok(p) = self.nodes.borrow_as_parent_mut(i.path) { |
| p.unmerged_leaves.push(index) |
| } |
| }); |
| |
| Ok(()) |
| } |
| |
| #[cfg(feature = "by_ref_proposal")] |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn batch_edit<I, CP>( |
| &mut self, |
| proposal_bundle: &mut ProposalBundle, |
| extensions: &ExtensionList, |
| id_provider: &I, |
| cipher_suite_provider: &CP, |
| filter: bool, |
| ) -> Result<Vec<LeafIndex>, MlsError> |
| where |
| I: IdentityProvider, |
| CP: CipherSuiteProvider, |
| { |
| // Apply removes (they commute with updates because they don't touch the same leaves) |
| for i in (0..proposal_bundle.remove_proposals().len()).rev() { |
| let index = proposal_bundle.remove_proposals()[i].proposal.to_remove; |
| let res = self.nodes.blank_leaf_node(index); |
| |
| if res.is_ok() { |
| // This shouldn't fail if `blank_leaf_node` succedded. |
| self.nodes.blank_direct_path(index)?; |
| } |
| |
| #[cfg(feature = "tree_index")] |
| if let Ok(old_leaf) = &res { |
| // If this fails, it's not because the proposal is bad. |
| let identity = |
| identity(&old_leaf.signing_identity, id_provider, extensions).await?; |
| |
| self.index.remove(old_leaf, &identity); |
| } |
| |
| if proposal_bundle.remove_proposals()[i].is_by_value() || !filter { |
| res?; |
| } else if res.is_err() { |
| proposal_bundle.remove::<RemoveProposal>(i); |
| } |
| } |
| |
| // Remove from the tree old leaves from updates |
| let mut partial_updates = vec![]; |
| let senders = proposal_bundle.update_senders.iter().copied(); |
| |
| for (i, (p, index)) in proposal_bundle.updates.iter().zip(senders).enumerate() { |
| let new_leaf = p.proposal.leaf_node.clone(); |
| |
| match self.nodes.blank_leaf_node(index) { |
| Ok(old_leaf) => { |
| #[cfg(feature = "tree_index")] |
| let old_id = |
| identity(&old_leaf.signing_identity, id_provider, extensions).await?; |
| |
| #[cfg(feature = "tree_index")] |
| self.index.remove(&old_leaf, &old_id); |
| |
| partial_updates.push((index, old_leaf, new_leaf, i)); |
| } |
| _ => { |
| if !filter || !p.is_by_reference() { |
| return Err(MlsError::UpdatingNonExistingMember); |
| } |
| } |
| } |
| } |
| |
| #[cfg(feature = "tree_index")] |
| let index_clone = self.index.clone(); |
| |
| let mut removed_leaves = vec![]; |
| let mut updated_indices = vec![]; |
| let mut bad_indices = vec![]; |
| |
| // Apply updates one by one. If there's an update which we can't apply or revert, we revert |
| // all updates. |
| for (index, old_leaf, new_leaf, i) in partial_updates.into_iter() { |
| #[cfg(feature = "tree_index")] |
| let res = |
| index_insert(&mut self.index, &new_leaf, index, id_provider, extensions).await; |
| |
| #[cfg(not(feature = "tree_index"))] |
| let res = index_insert(&self.nodes, &new_leaf, index, id_provider, extensions).await; |
| |
| let err = res.is_err(); |
| |
| if !filter { |
| res?; |
| } |
| |
| if !err { |
| self.nodes.insert_leaf(index, new_leaf); |
| removed_leaves.push(old_leaf); |
| updated_indices.push(index); |
| } else { |
| #[cfg(feature = "tree_index")] |
| let res = |
| index_insert(&mut self.index, &old_leaf, index, id_provider, extensions).await; |
| |
| #[cfg(not(feature = "tree_index"))] |
| let res = |
| index_insert(&self.nodes, &old_leaf, index, id_provider, extensions).await; |
| |
| if res.is_ok() { |
| self.nodes.insert_leaf(index, old_leaf); |
| bad_indices.push(i); |
| } else { |
| // Revert all updates and stop. We're already in the "filter" case, so we don't throw an error. |
| #[cfg(feature = "tree_index")] |
| { |
| self.index = index_clone; |
| } |
| |
| removed_leaves |
| .into_iter() |
| .zip(updated_indices.iter()) |
| .for_each(|(leaf, index)| self.nodes.insert_leaf(*index, leaf)); |
| |
| updated_indices = vec![]; |
| break; |
| } |
| } |
| } |
| |
| // If we managed to update something, blank direct paths |
| updated_indices |
| .iter() |
| .try_for_each(|index| self.nodes.blank_direct_path(*index).map(|_| ()))?; |
| |
| // Remove rejected updates from applied proposals |
| if updated_indices.is_empty() { |
| // This takes care of the "revert all" scenario |
| proposal_bundle.updates = vec![]; |
| } else { |
| for i in bad_indices.into_iter().rev() { |
| proposal_bundle.remove::<UpdateProposal>(i); |
| proposal_bundle.update_senders.remove(i); |
| } |
| } |
| |
| // Apply adds |
| let mut start = LeafIndex(0); |
| let mut added = vec![]; |
| let mut bad_indexes = vec![]; |
| |
| for i in 0..proposal_bundle.additions.len() { |
| let leaf = proposal_bundle.additions[i] |
| .proposal |
| .key_package |
| .leaf_node |
| .clone(); |
| |
| let res = self |
| .add_leaf(leaf, id_provider, extensions, Some(start)) |
| .await; |
| |
| if let Ok(index) = res { |
| start = index; |
| added.push(start); |
| } else if proposal_bundle.additions[i].is_by_value() || !filter { |
| res?; |
| } else { |
| bad_indexes.push(i); |
| } |
| } |
| |
| for i in bad_indexes.into_iter().rev() { |
| proposal_bundle.remove::<AddProposal>(i); |
| } |
| |
| self.nodes.trim(); |
| |
| let updated_leaves = proposal_bundle |
| .remove_proposals() |
| .iter() |
| .map(|p| p.proposal.to_remove) |
| .chain(updated_indices) |
| .chain(added.iter().copied()) |
| .collect_vec(); |
| |
| self.update_hashes(&updated_leaves, cipher_suite_provider) |
| .await?; |
| |
| Ok(added) |
| } |
| |
    /// Variant of `batch_edit` compiled without the `by_ref_proposal` feature: applies the
    /// removes and then the adds from `proposal_bundle` with no filtering. The first failure is
    /// returned and the bundle is never modified.
    ///
    /// Returns the leaf indices at which new members were added.
    #[cfg(not(feature = "by_ref_proposal"))]
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn batch_edit_lite<I, CP>(
        &mut self,
        proposal_bundle: &ProposalBundle,
        extensions: &ExtensionList,
        id_provider: &I,
        cipher_suite_provider: &CP,
    ) -> Result<Vec<LeafIndex>, MlsError>
    where
        I: IdentityProvider,
        CP: CipherSuiteProvider,
    {
        // Apply removes
        for p in &proposal_bundle.removals {
            let index = p.proposal.to_remove;

            // With the index, the blanked leaf is also needed to drop its index entry.
            #[cfg(feature = "tree_index")]
            {
                // If this fails, it's not because the proposal is bad.
                let old_leaf = self.nodes.blank_leaf_node(index)?;

                let identity =
                    identity(&old_leaf.signing_identity, id_provider, extensions).await?;

                self.index.remove(&old_leaf, &identity);
            }

            #[cfg(not(feature = "tree_index"))]
            self.nodes.blank_leaf_node(index)?;

            self.nodes.blank_direct_path(index)?;
        }

        // Apply adds; each search for an empty slot starts at the previously used slot.
        let mut start = LeafIndex(0);
        let mut added = vec![];

        for p in &proposal_bundle.additions {
            let leaf = p.proposal.key_package.leaf_node.clone();
            start = self
                .add_leaf(leaf, id_provider, extensions, Some(start))
                .await?;
            added.push(start);
        }

        self.nodes.trim();

        // Recompute tree hashes for every removed and added leaf.
        let updated_leaves = proposal_bundle
            .remove_proposals()
            .iter()
            .map(|p| p.proposal.to_remove)
            .chain(added.iter().copied())
            .collect_vec();

        self.update_hashes(&updated_leaves, cipher_suite_provider)
            .await?;

        Ok(added)
    }
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub(crate) async fn add_leaf<I: IdentityProvider>( |
| &mut self, |
| leaf: LeafNode, |
| id_provider: &I, |
| extensions: &ExtensionList, |
| start: Option<LeafIndex>, |
| ) -> Result<LeafIndex, MlsError> { |
| let index = self.nodes.next_empty_leaf(start.unwrap_or(LeafIndex(0))); |
| |
| #[cfg(feature = "tree_index")] |
| index_insert(&mut self.index, &leaf, index, id_provider, extensions).await?; |
| |
| #[cfg(not(feature = "tree_index"))] |
| index_insert(&self.nodes, &leaf, index, id_provider, extensions).await?; |
| |
| self.nodes.insert_leaf(index, leaf); |
| self.update_unmerged(index)?; |
| |
| Ok(index) |
| } |
| } |
| |
/// Asks `provider` for the identity bound to `signing_id`, wrapping any provider failure in
/// [`MlsError::IdentityProviderError`].
#[cfg(feature = "tree_index")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn identity<I: IdentityProvider>(
    signing_id: &SigningIdentity,
    provider: &I,
    extensions: &ExtensionList,
) -> Result<Vec<u8>, MlsError> {
    let lookup = provider.identity(signing_id, extensions).await;

    lookup.map_err(|err| MlsError::IdentityProviderError(err.into_any_error()))
}
| |
/// Renders the node array as ASCII art using `tree_utils::build_ascii_tree` (`std` only).
#[cfg(feature = "std")]
impl Display for TreeKemPublic {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let rendered = tree_utils::build_ascii_tree(&self.nodes);
        write!(f, "{rendered}")
    }
}
| |
| #[cfg(test)] |
| use crate::group::{proposal::Proposal, proposal_filter::ProposalSource, Sender}; |
| |
| #[cfg(test)] |
| impl TreeKemPublic { |
| #[cfg(feature = "by_ref_proposal")] |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn update_leaf<I, CP>( |
| &mut self, |
| leaf_index: u32, |
| leaf_node: LeafNode, |
| identity_provider: &I, |
| cipher_suite_provider: &CP, |
| ) -> Result<(), MlsError> |
| where |
| I: IdentityProvider, |
| CP: CipherSuiteProvider, |
| { |
| let p = Proposal::Update(UpdateProposal { leaf_node }); |
| |
| let mut bundle = ProposalBundle::default(); |
| bundle.add(p, Sender::Member(leaf_index), ProposalSource::ByValue); |
| bundle.update_senders = vec![LeafIndex(leaf_index)]; |
| |
| self.batch_edit( |
| &mut bundle, |
| &Default::default(), |
| identity_provider, |
| cipher_suite_provider, |
| true, |
| ) |
| .await?; |
| |
| Ok(()) |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn remove_leaves<I, CP>( |
| &mut self, |
| indexes: Vec<LeafIndex>, |
| identity_provider: &I, |
| cipher_suite_provider: &CP, |
| ) -> Result<Vec<(LeafIndex, LeafNode)>, MlsError> |
| where |
| I: IdentityProvider, |
| CP: CipherSuiteProvider, |
| { |
| let old_tree = self.clone(); |
| |
| let proposals = indexes |
| .iter() |
| .copied() |
| .map(|to_remove| Proposal::Remove(RemoveProposal { to_remove })); |
| |
| let mut bundle = ProposalBundle::default(); |
| |
| for p in proposals { |
| bundle.add(p, Sender::Member(0), ProposalSource::ByValue); |
| } |
| |
| #[cfg(feature = "by_ref_proposal")] |
| self.batch_edit( |
| &mut bundle, |
| &Default::default(), |
| identity_provider, |
| cipher_suite_provider, |
| true, |
| ) |
| .await?; |
| |
| #[cfg(not(feature = "by_ref_proposal"))] |
| self.batch_edit_lite( |
| &bundle, |
| &Default::default(), |
| identity_provider, |
| cipher_suite_provider, |
| ) |
| .await?; |
| |
| bundle |
| .removals |
| .iter() |
| .map(|p| { |
| let index = p.proposal.to_remove; |
| let leaf = old_tree.get_leaf_node(index)?.clone(); |
| Ok((index, leaf)) |
| }) |
| .collect() |
| } |
| |
| pub fn get_leaf_nodes(&self) -> Vec<&LeafNode> { |
| self.nodes.non_empty_leaves().map(|(_, l)| l).collect() |
| } |
| } |
| |
| #[cfg(test)] |
| pub(crate) mod test_utils { |
| use crate::crypto::test_utils::TestCryptoProvider; |
| use crate::signer::Signable; |
| use alloc::vec::Vec; |
| use alloc::{format, vec}; |
| use mls_rs_core::crypto::CipherSuiteProvider; |
| use mls_rs_core::group::Capabilities; |
| use mls_rs_core::identity::BasicCredential; |
| |
| use crate::identity::test_utils::get_test_signing_identity; |
| use crate::{ |
| cipher_suite::CipherSuite, |
| crypto::{HpkeSecretKey, SignatureSecretKey}, |
| identity::basic::BasicIdentityProvider, |
| tree_kem::leaf_node::test_utils::get_basic_test_node_sig_key, |
| }; |
| |
| use super::leaf_node::{ConfigProperties, LeafNodeSigningContext}; |
| use super::node::LeafIndex; |
| use super::Lifetime; |
| use super::{ |
| leaf_node::{test_utils::get_basic_test_node, LeafNode}, |
| TreeKemPrivate, TreeKemPublic, |
| }; |
| |
    /// A single-member test tree (see `get_test_tree`) plus the creator's leaf and keys.
    #[derive(Debug)]
    pub(crate) struct TestTree {
        /// Public tree containing only the creator at leaf 0.
        pub public: TreeKemPublic,
        /// Private tree for the creator's leaf.
        pub private: TreeKemPrivate,
        /// The creator's leaf node as inserted into `public`.
        pub creator_leaf: LeafNode,
        /// Signing key matching `creator_leaf`.
        pub creator_signing_key: SignatureSecretKey,
        /// HPKE secret key stored in `private` for the creator.
        pub creator_hpke_secret: HpkeSecretKey,
    }
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub(crate) async fn get_test_tree(cipher_suite: CipherSuite) -> TestTree { |
| let (creator_leaf, creator_hpke_secret, creator_signing_key) = |
| get_basic_test_node_sig_key(cipher_suite, "creator").await; |
| |
| let (test_public, test_private) = TreeKemPublic::derive( |
| creator_leaf.clone(), |
| creator_hpke_secret.clone(), |
| &BasicIdentityProvider, |
| &Default::default(), |
| ) |
| .await |
| .unwrap(); |
| |
| TestTree { |
| public: test_public, |
| private: test_private, |
| creator_leaf, |
| creator_signing_key, |
| creator_hpke_secret, |
| } |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn get_test_leaf_nodes(cipher_suite: CipherSuite) -> Vec<LeafNode> { |
| [ |
| get_basic_test_node(cipher_suite, "A").await, |
| get_basic_test_node(cipher_suite, "B").await, |
| get_basic_test_node(cipher_suite, "C").await, |
| ] |
| .to_vec() |
| } |
| |
| impl TreeKemPublic { |
| #[cfg(feature = "tree_index")] |
| pub fn equal_internals(&self, other: &TreeKemPublic) -> bool { |
| self.tree_hashes == other.tree_hashes && self.index == other.index |
| } |
| } |
| |
    /// A test tree bundled with what is needed to re-sign its leaves.
    #[derive(Debug, Clone)]
    pub struct TreeWithSigners {
        /// The public tree being built.
        pub tree: TreeKemPublic,
        /// Signature key per leaf index; `None` once the member has been removed.
        pub signers: Vec<Option<SignatureSecretKey>>,
        /// Group id placed in the leaf-signing context.
        pub group_id: Vec<u8>,
    }
| |
| impl TreeWithSigners { |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn make_full_tree<P: CipherSuiteProvider>( |
| n_leaves: u32, |
| cs: &P, |
| ) -> TreeWithSigners { |
| let mut tree = TreeWithSigners { |
| tree: TreeKemPublic::new(), |
| signers: vec![], |
| group_id: cs.random_bytes_vec(cs.kdf_extract_size()).unwrap(), |
| }; |
| |
| tree.add_member("Alice", cs).await; |
| |
| // A adds B, B adds C, C adds D etc. |
| for i in 1..n_leaves { |
| tree.add_member(&format!("Alice{i}"), cs).await; |
| tree.update_committer_path(i - 1, cs).await; |
| } |
| |
| tree |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn add_member<P: CipherSuiteProvider>(&mut self, name: &str, cs: &P) { |
| let (leaf, signer) = make_leaf(name, cs).await; |
| let index = self.tree.nodes.next_empty_leaf(LeafIndex(0)); |
| self.tree.nodes.insert_leaf(index, leaf); |
| self.tree.update_unmerged(index).unwrap(); |
| let index = *index as usize; |
| |
| match self.signers.len() { |
| l if l == index => self.signers.push(Some(signer)), |
| l if l > index => self.signers[index] = Some(signer), |
| _ => panic!("signer tree size mismatch"), |
| } |
| } |
| |
| #[cfg(feature = "rfc_compliant")] |
| #[cfg_attr(coverage_nightly, coverage(off))] |
| pub fn remove_member(&mut self, member: u32) { |
| self.tree |
| .nodes |
| .blank_direct_path(LeafIndex(member)) |
| .unwrap(); |
| |
| self.tree.nodes.blank_leaf_node(LeafIndex(member)).unwrap(); |
| |
| *self |
| .signers |
| .get_mut(member as usize) |
| .expect("signer tree size mismatch") = None; |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn update_committer_path<P: CipherSuiteProvider>( |
| &mut self, |
| committer: u32, |
| cs: &P, |
| ) { |
| let committer = LeafIndex(committer); |
| |
| let path = self.tree.nodes.direct_copath(committer); |
| let filtered = self.tree.nodes.filtered(committer).unwrap(); |
| |
| for (n, f) in path.into_iter().zip(filtered) { |
| if !f { |
| self.tree |
| .update_node(cs.kem_generate().await.unwrap().1, n.path) |
| .unwrap(); |
| } |
| } |
| |
| self.tree.tree_hashes.current = vec![]; |
| self.tree.tree_hash(cs).await.unwrap(); |
| |
| self.tree |
| .update_parent_hashes(committer, false, cs) |
| .await |
| .unwrap(); |
| |
| self.tree.tree_hashes.current = vec![]; |
| self.tree.tree_hash(cs).await.unwrap(); |
| |
| let context = LeafNodeSigningContext { |
| group_id: Some(&self.group_id), |
| leaf_index: Some(*committer), |
| }; |
| |
| let signer = self.signers[*committer as usize].as_ref().unwrap(); |
| |
| self.tree |
| .nodes |
| .borrow_as_leaf_mut(committer) |
| .unwrap() |
| .sign(cs, signer, &context) |
| .await |
| .unwrap(); |
| |
| self.tree.tree_hashes.current = vec![]; |
| self.tree.tree_hash(cs).await.unwrap(); |
| } |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn make_leaf<P: CipherSuiteProvider>( |
| name: &str, |
| cs: &P, |
| ) -> (LeafNode, SignatureSecretKey) { |
| let (signing_identity, signature_key) = |
| get_test_signing_identity(cs.cipher_suite(), name.as_bytes()).await; |
| |
| let capabilities = Capabilities { |
| credentials: vec![BasicCredential::credential_type()], |
| cipher_suites: TestCryptoProvider::all_supported_cipher_suites(), |
| ..Default::default() |
| }; |
| |
| let properties = ConfigProperties { |
| capabilities, |
| extensions: Default::default(), |
| }; |
| |
| let (leaf, _) = LeafNode::generate( |
| cs, |
| properties, |
| signing_identity, |
| &signature_key, |
| Lifetime::years(1).unwrap(), |
| ) |
| .await |
| .unwrap(); |
| |
| (leaf, signature_key) |
| } |
| } |
| |
| #[cfg(test)] |
| mod tests { |
| use crate::client::test_utils::TEST_CIPHER_SUITE; |
| use crate::crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider}; |
| |
| #[cfg(feature = "custom_proposal")] |
| use crate::group::proposal::ProposalType; |
| |
| use crate::identity::basic::BasicIdentityProvider; |
| use crate::tree_kem::leaf_node::LeafNode; |
| use crate::tree_kem::node::{LeafIndex, Node, NodeIndex, NodeTypeResolver, Parent}; |
| use crate::tree_kem::parent_hash::ParentHash; |
| use crate::tree_kem::test_utils::{get_test_leaf_nodes, get_test_tree}; |
| use crate::tree_kem::{MlsError, TreeKemPublic}; |
| use alloc::borrow::ToOwned; |
| use alloc::vec; |
| use alloc::vec::Vec; |
| use assert_matches::assert_matches; |
| |
| #[cfg(feature = "by_ref_proposal")] |
| use alloc::boxed::Box; |
| |
| #[cfg(feature = "by_ref_proposal")] |
| use crate::{ |
| client::test_utils::TEST_PROTOCOL_VERSION, |
| group::{ |
| proposal::{Proposal, RemoveProposal, UpdateProposal}, |
| proposal_filter::{ProposalBundle, ProposalSource}, |
| proposal_ref::ProposalRef, |
| Sender, |
| }, |
| key_package::test_utils::test_key_package, |
| }; |
| |
| #[cfg(any(feature = "by_ref_proposal", feature = "custo_proposal"))] |
| use crate::tree_kem::leaf_node::test_utils::get_basic_test_node; |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_derive() { |
| for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() { |
| let test_tree = get_test_tree(cipher_suite).await; |
| |
| assert_eq!( |
| test_tree.public.nodes[0], |
| Some(Node::Leaf(test_tree.creator_leaf.clone())) |
| ); |
| |
| assert_eq!(test_tree.private.self_index, LeafIndex(0)); |
| |
| assert_eq!( |
| test_tree.private.secret_keys[0], |
| Some(test_tree.creator_hpke_secret) |
| ); |
| } |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_import_export() { |
| let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); |
| let mut test_tree = get_test_tree(TEST_CIPHER_SUITE).await; |
| |
| let additional_key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await; |
| |
| test_tree |
| .public |
| .add_leaves( |
| additional_key_packages, |
| &BasicIdentityProvider, |
| &cipher_suite_provider, |
| ) |
| .await |
| .unwrap(); |
| |
| let imported = TreeKemPublic::import_node_data( |
| test_tree.public.nodes.clone(), |
| &BasicIdentityProvider, |
| &Default::default(), |
| ) |
| .await |
| .unwrap(); |
| |
| assert_eq!(test_tree.public.nodes, imported.nodes); |
| |
| #[cfg(feature = "tree_index")] |
| assert_eq!(test_tree.public.index, imported.index); |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_add_leaf() { |
| let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); |
| let mut tree = TreeKemPublic::new(); |
| |
| let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await; |
| |
| let res = tree |
| .add_leaves( |
| leaf_nodes.clone(), |
| &BasicIdentityProvider, |
| &cipher_suite_provider, |
| ) |
| .await |
| .unwrap(); |
| |
| // The leaf count should be equal to the number of packages we added |
| assert_eq!(res.len(), leaf_nodes.len()); |
| assert_eq!(tree.occupied_leaf_count(), leaf_nodes.len() as u32); |
| |
| // Each added package should be at the proper index and searchable in the tree |
| res.into_iter().zip(leaf_nodes.clone()).for_each(|(r, kp)| { |
| assert_eq!(tree.get_leaf_node(r).unwrap(), &kp); |
| }); |
| |
| // Verify the underlying state |
| #[cfg(feature = "tree_index")] |
| assert_eq!(tree.index.len(), tree.occupied_leaf_count() as usize); |
| |
| assert_eq!(tree.nodes.len(), 5); |
| assert_eq!(tree.nodes[0], leaf_nodes[0].clone().into()); |
| assert_eq!(tree.nodes[1], None); |
| assert_eq!(tree.nodes[2], leaf_nodes[1].clone().into()); |
| assert_eq!(tree.nodes[3], None); |
| assert_eq!(tree.nodes[4], leaf_nodes[2].clone().into()); |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_get_key_packages() { |
| let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); |
| let mut tree = TreeKemPublic::new(); |
| |
| let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await; |
| |
| tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider) |
| .await |
| .unwrap(); |
| |
| let key_packages = tree.get_leaf_nodes(); |
| assert_eq!(key_packages, key_packages.to_owned()); |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_add_leaf_duplicate() { |
| let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); |
| let mut tree = TreeKemPublic::new(); |
| |
| let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await; |
| |
| tree.add_leaves( |
| key_packages.clone(), |
| &BasicIdentityProvider, |
| &cipher_suite_provider, |
| ) |
| .await |
| .unwrap(); |
| |
| let res = tree |
| .add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider) |
| .await; |
| |
| assert_matches!(res, Err(MlsError::DuplicateLeafData(_))); |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_add_leaf_empty_leaf() { |
| let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); |
| let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public; |
| let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await; |
| |
| tree.add_leaves( |
| [key_packages[0].clone()].to_vec(), |
| &BasicIdentityProvider, |
| &cipher_suite_provider, |
| ) |
| .await |
| .unwrap(); |
| |
| tree.nodes[0] = None; // Set the original first node to none |
| // |
| tree.add_leaves( |
| [key_packages[1].clone()].to_vec(), |
| &BasicIdentityProvider, |
| &cipher_suite_provider, |
| ) |
| .await |
| .unwrap(); |
| |
| assert_eq!(tree.nodes[0], key_packages[1].clone().into()); |
| assert_eq!(tree.nodes[1], None); |
| assert_eq!(tree.nodes[2], key_packages[0].clone().into()); |
| assert_eq!(tree.nodes.len(), 3) |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_add_leaf_unmerged() { |
| let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); |
| let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public; |
| let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await; |
| |
| tree.add_leaves( |
| [key_packages[0].clone(), key_packages[1].clone()].to_vec(), |
| &BasicIdentityProvider, |
| &cipher_suite_provider, |
| ) |
| .await |
| .unwrap(); |
| |
| tree.nodes[3] = Parent { |
| public_key: vec![].into(), |
| parent_hash: ParentHash::empty(), |
| unmerged_leaves: vec![], |
| } |
| .into(); |
| |
| tree.add_leaves( |
| [key_packages[2].clone()].to_vec(), |
| &BasicIdentityProvider, |
| &cipher_suite_provider, |
| ) |
| .await |
| .unwrap(); |
| |
| assert_eq!( |
| tree.nodes[3].as_parent().unwrap().unmerged_leaves, |
| vec![LeafIndex(3)] |
| ) |
| } |
| |
/// Updating a leaf must:
/// - replace the leaf node in place without changing the occupied leaf count;
/// - keep the identity index size in sync (with `tree_index`);
/// - blank every parent node on the *updated leaf's* direct path.
#[cfg(feature = "by_ref_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_update_leaf() {
    let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

    // Create a tree
    let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;

    let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

    tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
        .await
        .unwrap();

    let original_leaf_index = LeafIndex(1);

    // Fill in parent nodes along the updated leaf's direct path so we can detect them
    // being cleared by the update. Seeding the path of the leaf we actually update
    // (rather than a hard-coded sibling) keeps the test valid if the index changes.
    for n in tree.nodes.direct_copath(original_leaf_index) {
        tree.nodes
            .borrow_or_fill_node_as_parent(n.path, &b"pub_key".to_vec().into())
            .unwrap();
    }

    let original_size = tree.occupied_leaf_count();

    let updated_leaf = get_basic_test_node(TEST_CIPHER_SUITE, "A").await;

    tree.update_leaf(
        *original_leaf_index,
        updated_leaf.clone(),
        &BasicIdentityProvider,
        &cipher_suite_provider,
    )
    .await
    .unwrap();

    // The tree should not have grown due to an update
    assert_eq!(tree.occupied_leaf_count(), original_size);

    // The cache of tree package indexes should not have grown
    #[cfg(feature = "tree_index")]
    assert_eq!(tree.index.len() as u32, tree.occupied_leaf_count());

    // The key package should be updated in the tree
    assert_eq!(
        tree.get_leaf_node(original_leaf_index).unwrap(),
        &updated_leaf
    );

    // Verify that the updated leaf's direct path has been cleared
    for n in tree.nodes.direct_copath(original_leaf_index) {
        assert!(tree.nodes[n.path as usize].is_none());
    }
}
| |
/// Updating a leaf index that holds no member must fail with
/// `MlsError::UpdatingNonExistingMember`.
#[cfg(feature = "by_ref_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_update_leaf_not_found() {
    let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);

    let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
    let leaves = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

    tree.add_leaves(leaves, &BasicIdentityProvider, &cs)
        .await
        .unwrap();

    let replacement = get_basic_test_node(TEST_CIPHER_SUITE, "new").await;

    // Leaf 128 is well beyond the handful of leaves added above.
    let outcome = tree
        .update_leaf(128, replacement, &BasicIdentityProvider, &cs)
        .await;

    assert_matches!(outcome, Err(MlsError::UpdatingNonExistingMember));
}
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_remove_leaf() { |
| let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); |
| |
| // Create a tree |
| let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public; |
| let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await; |
| |
| let indexes = tree |
| .add_leaves( |
| key_packages.clone(), |
| &BasicIdentityProvider, |
| &cipher_suite_provider, |
| ) |
| .await |
| .unwrap(); |
| |
| let original_leaf_count = tree.occupied_leaf_count(); |
| |
| // Remove two leaves from the tree |
| let expected_result: Vec<(LeafIndex, LeafNode)> = |
| indexes.clone().into_iter().zip(key_packages).collect(); |
| |
| let res = tree |
| .remove_leaves( |
| indexes.clone(), |
| &BasicIdentityProvider, |
| &cipher_suite_provider, |
| ) |
| .await |
| .unwrap(); |
| |
| // The order may change |
| assert!(res.iter().all(|x| expected_result.contains(x))); |
| assert!(expected_result.iter().all(|x| res.contains(x))); |
| |
| // The leaves should be removed from the tree |
| assert_eq!( |
| tree.occupied_leaf_count(), |
| original_leaf_count - indexes.len() as u32 |
| ); |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_remove_leaf_middle() { |
| let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); |
| |
| // Create a tree |
| let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public; |
| let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await; |
| |
| let to_remove = tree |
| .add_leaves( |
| leaf_nodes.clone(), |
| &BasicIdentityProvider, |
| &cipher_suite_provider, |
| ) |
| .await |
| .unwrap()[0]; |
| |
| let original_leaf_count = tree.occupied_leaf_count(); |
| |
| let res = tree |
| .remove_leaves( |
| vec![to_remove], |
| &BasicIdentityProvider, |
| &cipher_suite_provider, |
| ) |
| .await |
| .unwrap(); |
| |
| assert_eq!(res, vec![(to_remove, leaf_nodes[0].clone())]); |
| |
| // The leaf count should have been reduced by 1 |
| assert_eq!(tree.occupied_leaf_count(), original_leaf_count - 1); |
| |
| // There should be a blank in the tree |
| assert_eq!( |
| tree.nodes.get(NodeIndex::from(to_remove) as usize).unwrap(), |
| &None |
| ); |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_create_blanks() { |
| let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); |
| |
| // Create a tree |
| let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public; |
| |
| let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await; |
| |
| tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider) |
| .await |
| .unwrap(); |
| |
| let original_leaf_count = tree.occupied_leaf_count(); |
| |
| let to_remove = vec![LeafIndex(2)]; |
| |
| // Remove the leaf from the tree |
| tree.remove_leaves(to_remove, &BasicIdentityProvider, &cipher_suite_provider) |
| .await |
| .unwrap(); |
| |
| // The occupied leaf count should have been reduced by 1 |
| assert_eq!(tree.occupied_leaf_count(), original_leaf_count - 1); |
| |
| // The total leaf count should remain unchanged |
| assert_eq!(tree.total_leaf_count(), original_leaf_count); |
| |
| // The location of key_packages[1] should now be blank |
| let removed_location = tree |
| .nodes |
| .get(NodeIndex::from(LeafIndex(2)) as usize) |
| .unwrap(); |
| |
| assert_eq!(removed_location, &None); |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_remove_leaf_failure() { |
| let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); |
| |
| // Create a tree |
| let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public; |
| |
| let res = tree |
| .remove_leaves( |
| vec![LeafIndex(128)], |
| &BasicIdentityProvider, |
| &cipher_suite_provider, |
| ) |
| .await; |
| |
| assert_matches!(res, Err(MlsError::InvalidNodeIndex(256))); |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_find_leaf_node() { |
| let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); |
| // Create a tree |
| let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public; |
| |
| let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await; |
| |
| tree.add_leaves( |
| leaf_nodes.clone(), |
| &BasicIdentityProvider, |
| &cipher_suite_provider, |
| ) |
| .await |
| .unwrap(); |
| |
| // Find each node |
| for (i, leaf_node) in leaf_nodes.iter().enumerate() { |
| let expected_index = LeafIndex(i as u32 + 1); |
| assert_eq!(tree.find_leaf_node(leaf_node), Some(expected_index)); |
| } |
| } |
| |
// TODO add test for the lite version
/// Runs `batch_edit` over a bundle with one add, one update and one remove proposal.
/// Afterwards the bundle must still hold exactly one proposal of each kind.
#[cfg(feature = "by_ref_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn batch_edit_works() {
    let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);

    let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
    let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

    tree.add_leaves(leaf_nodes, &BasicIdentityProvider, &cs)
        .await
        .unwrap();

    let mut bundle = ProposalBundle::default();

    // By-value add of a new member "D", sent by leaf 0.
    let key_package = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "D").await;
    bundle.add(
        Proposal::Add(Box::new(key_package.into())),
        Sender::Member(0),
        ProposalSource::ByValue,
    );

    // By-reference update sent by leaf 1.
    let update_proposal = UpdateProposal {
        leaf_node: get_basic_test_node(TEST_CIPHER_SUITE, "A").await,
    };
    bundle.add(
        Proposal::Update(update_proposal),
        Sender::Member(1),
        ProposalSource::ByReference(ProposalRef::new_fake(vec![1, 2, 3])),
    );

    bundle.update_senders = vec![LeafIndex(1)];

    // By-value removal of leaf 2, sent by leaf 0.
    bundle.add(
        Proposal::Remove(RemoveProposal {
            to_remove: LeafIndex(2),
        }),
        Sender::Member(0),
        ProposalSource::ByValue,
    );

    tree.batch_edit(
        &mut bundle,
        &Default::default(),
        &BasicIdentityProvider,
        &cs,
        true,
    )
    .await
    .unwrap();

    assert_eq!(bundle.add_proposals().len(), 1);
    assert_eq!(bundle.remove_proposals().len(), 1);
    assert_eq!(bundle.update_proposals().len(), 1);
}
| |
/// `can_support_proposal` reports a custom proposal type as supported when every added
/// leaf advertises it. It reports it as unsupported for a type nobody advertises, and
/// again after a leaf lacking the type joins.
#[cfg(feature = "custom_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn custom_proposal_support() {
    let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
    let mut tree = TreeKemPublic::new();

    let custom_type = ProposalType::from(42);

    // Every initial leaf advertises the custom proposal type.
    let mut leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
    for node in leaf_nodes.iter_mut() {
        node.capabilities.proposals.push(custom_type);
    }

    tree.add_leaves(leaf_nodes, &BasicIdentityProvider, &cs)
        .await
        .unwrap();

    assert!(tree.can_support_proposal(custom_type));
    assert!(!tree.can_support_proposal(ProposalType::from(43)));

    // A newcomer without the custom type in its capabilities revokes support.
    let newcomer = get_basic_test_node(TEST_CIPHER_SUITE, "another").await;

    tree.add_leaves(vec![newcomer], &BasicIdentityProvider, &cs)
        .await
        .unwrap();

    assert!(!tree.can_support_proposal(custom_type));
}
| } |