| // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| // Copyright by contributors to this project. |
| // SPDX-License-Identifier: (Apache-2.0 OR MIT) |
| |
| use super::*; |
| #[cfg(feature = "tree_index")] |
| use core::fmt::{self, Debug}; |
| |
| #[cfg(all(feature = "tree_index", feature = "custom_proposal"))] |
| use crate::group::proposal::ProposalType; |
| |
| #[cfg(feature = "tree_index")] |
| use crate::identity::CredentialType; |
| |
| #[cfg(feature = "tree_index")] |
| use mls_rs_core::crypto::SignaturePublicKey; |
| |
| #[cfg(all(feature = "tree_index", feature = "std"))] |
| use itertools::Itertools; |
| |
| #[cfg(all(feature = "tree_index", not(feature = "std")))] |
| use alloc::collections::{btree_map::Entry, BTreeMap}; |
| |
| #[cfg(all(feature = "tree_index", feature = "std"))] |
| use std::collections::{hash_map::Entry, HashMap}; |
| |
| #[cfg(all(feature = "tree_index", not(feature = "std")))] |
| use alloc::collections::BTreeSet; |
| |
| #[cfg(feature = "tree_index")] |
| use mls_rs_core::crypto::HpkePublicKey; |
| |
| #[cfg(feature = "tree_index")] |
| #[derive(Clone, Default, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, Hash, PartialOrd, Ord)] |
| pub struct Identifier(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>); |
| |
| #[cfg(feature = "tree_index")] |
| impl Debug for Identifier { |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| mls_rs_core::debug::pretty_bytes(&self.0) |
| .named("Identifier") |
| .fmt(f) |
| } |
| } |
| |
| #[cfg(all(feature = "tree_index", feature = "std"))] |
| #[derive(Clone, Debug, Default, PartialEq, MlsSize, MlsEncode, MlsDecode)] |
| pub struct TreeIndex { |
| credential_signature_key: HashMap<SignaturePublicKey, LeafIndex>, |
| hpke_key: HashMap<HpkePublicKey, LeafIndex>, |
| identities: HashMap<Identifier, LeafIndex>, |
| credential_type_counters: HashMap<CredentialType, TypeCounter>, |
| #[cfg(feature = "custom_proposal")] |
| proposal_type_counter: HashMap<ProposalType, u32>, |
| } |
| |
| #[cfg(all(feature = "tree_index", not(feature = "std")))] |
| #[derive(Clone, Debug, Default, PartialEq, MlsSize, MlsEncode, MlsDecode)] |
| pub struct TreeIndex { |
| credential_signature_key: BTreeMap<SignaturePublicKey, LeafIndex>, |
| hpke_key: BTreeMap<HpkePublicKey, LeafIndex>, |
| identities: BTreeMap<Identifier, LeafIndex>, |
| credential_type_counters: BTreeMap<CredentialType, TypeCounter>, |
| #[cfg(feature = "custom_proposal")] |
| proposal_type_counter: BTreeMap<ProposalType, u32>, |
| } |
| |
| #[cfg(feature = "tree_index")] |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub(super) async fn index_insert<I: IdentityProvider>( |
| tree_index: &mut TreeIndex, |
| new_leaf: &LeafNode, |
| new_leaf_idx: LeafIndex, |
| id_provider: &I, |
| extensions: &ExtensionList, |
| ) -> Result<(), MlsError> { |
| let new_id = id_provider |
| .identity(&new_leaf.signing_identity, extensions) |
| .await |
| .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?; |
| |
| tree_index.insert(new_leaf_idx, new_leaf, new_id) |
| } |
| |
| #[cfg(not(feature = "tree_index"))] |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub(super) async fn index_insert<I: IdentityProvider>( |
| nodes: &NodeVec, |
| new_leaf: &LeafNode, |
| new_leaf_idx: LeafIndex, |
| id_provider: &I, |
| extensions: &ExtensionList, |
| ) -> Result<(), MlsError> { |
| let new_id = id_provider |
| .identity(&new_leaf.signing_identity, extensions) |
| .await |
| .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?; |
| |
| for (i, leaf) in nodes.non_empty_leaves().filter(|(i, _)| i != &new_leaf_idx) { |
| (new_leaf.public_key != leaf.public_key) |
| .then_some(()) |
| .ok_or(MlsError::DuplicateLeafData(*i))?; |
| |
| (new_leaf.signing_identity.signature_key != leaf.signing_identity.signature_key) |
| .then_some(()) |
| .ok_or(MlsError::DuplicateLeafData(*i))?; |
| |
| let id = id_provider |
| .identity(&leaf.signing_identity, extensions) |
| .await |
| .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?; |
| |
| (new_id != id) |
| .then_some(()) |
| .ok_or(MlsError::DuplicateLeafData(*i))?; |
| |
| let cred_type = leaf.signing_identity.credential.credential_type(); |
| |
| new_leaf |
| .capabilities |
| .credentials |
| .contains(&cred_type) |
| .then_some(()) |
| .ok_or(MlsError::InUseCredentialTypeUnsupportedByNewLeaf)?; |
| |
| let new_cred_type = new_leaf.signing_identity.credential.credential_type(); |
| |
| leaf.capabilities |
| .credentials |
| .contains(&new_cred_type) |
| .then_some(()) |
| .ok_or(MlsError::CredentialTypeOfNewLeafIsUnsupported)?; |
| } |
| |
| Ok(()) |
| } |
| |
| #[cfg(feature = "tree_index")] |
| impl TreeIndex { |
| pub fn new() -> Self { |
| Default::default() |
| } |
| |
| pub fn is_initialized(&self) -> bool { |
| !self.identities.is_empty() |
| } |
| |
| fn insert( |
| &mut self, |
| index: LeafIndex, |
| leaf_node: &LeafNode, |
| identity: Vec<u8>, |
| ) -> Result<(), MlsError> { |
| let old_leaf_count = self.credential_signature_key.len(); |
| |
| let pub_key = leaf_node.signing_identity.signature_key.clone(); |
| let credential_entry = self.credential_signature_key.entry(pub_key); |
| |
| if let Entry::Occupied(entry) = credential_entry { |
| return Err(MlsError::DuplicateLeafData(**entry.get())); |
| } |
| |
| let hpke_entry = self.hpke_key.entry(leaf_node.public_key.clone()); |
| |
| if let Entry::Occupied(entry) = hpke_entry { |
| return Err(MlsError::DuplicateLeafData(**entry.get())); |
| } |
| |
| let identity_entry = self.identities.entry(Identifier(identity)); |
| if let Entry::Occupied(entry) = identity_entry { |
| return Err(MlsError::DuplicateLeafData(**entry.get())); |
| } |
| |
| let in_use_cred_type_unsupported_by_new_leaf = self |
| .credential_type_counters |
| .iter() |
| .filter_map(|(cred_type, counters)| Some(*cred_type).filter(|_| counters.used > 0)) |
| .find(|cred_type| !leaf_node.capabilities.credentials.contains(cred_type)); |
| |
| if in_use_cred_type_unsupported_by_new_leaf.is_some() { |
| return Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf); |
| } |
| |
| let new_leaf_cred_type = leaf_node.signing_identity.credential.credential_type(); |
| |
| let cred_type_counters = self |
| .credential_type_counters |
| .entry(new_leaf_cred_type) |
| .or_default(); |
| |
| if cred_type_counters.supported != old_leaf_count as u32 { |
| return Err(MlsError::CredentialTypeOfNewLeafIsUnsupported); |
| } |
| |
| cred_type_counters.used += 1; |
| |
| let credential_type_iter = leaf_node.capabilities.credentials.iter().copied(); |
| |
| #[cfg(feature = "std")] |
| let credential_type_iter = credential_type_iter.unique(); |
| |
| #[cfg(not(feature = "std"))] |
| let credential_type_iter = credential_type_iter.collect::<BTreeSet<_>>().into_iter(); |
| |
| // Credential type counter updates |
| credential_type_iter.for_each(|cred_type| { |
| self.credential_type_counters |
| .entry(cred_type) |
| .or_default() |
| .supported += 1; |
| }); |
| |
| #[cfg(feature = "custom_proposal")] |
| { |
| let proposal_type_iter = leaf_node.capabilities.proposals.iter().copied(); |
| |
| #[cfg(feature = "std")] |
| let proposal_type_iter = proposal_type_iter.unique(); |
| |
| #[cfg(not(feature = "std"))] |
| let proposal_type_iter = proposal_type_iter.collect::<BTreeSet<_>>().into_iter(); |
| |
| // Proposal type counter update |
| proposal_type_iter.for_each(|proposal_type| { |
| *self.proposal_type_counter.entry(proposal_type).or_default() += 1; |
| }); |
| } |
| |
| identity_entry.or_insert(index); |
| credential_entry.or_insert(index); |
| hpke_entry.or_insert(index); |
| |
| Ok(()) |
| } |
| |
| pub(crate) fn get_leaf_index_with_identity(&self, identity: &[u8]) -> Option<LeafIndex> { |
| self.identities.get(&Identifier(identity.to_vec())).copied() |
| } |
| |
| pub fn remove(&mut self, leaf_node: &LeafNode, identity: &[u8]) { |
| let existed = self |
| .identities |
| .remove(&Identifier(identity.to_vec())) |
| .is_some(); |
| |
| self.credential_signature_key |
| .remove(&leaf_node.signing_identity.signature_key); |
| |
| self.hpke_key.remove(&leaf_node.public_key); |
| |
| if !existed { |
| return; |
| } |
| |
| // Decrement credential type counters |
| let leaf_cred_type = leaf_node.signing_identity.credential.credential_type(); |
| |
| if let Some(counters) = self.credential_type_counters.get_mut(&leaf_cred_type) { |
| counters.used -= 1; |
| } |
| |
| let credential_type_iter = leaf_node.capabilities.credentials.iter(); |
| |
| #[cfg(feature = "std")] |
| let credential_type_iter = credential_type_iter.unique(); |
| |
| #[cfg(not(feature = "std"))] |
| let credential_type_iter = credential_type_iter.collect::<BTreeSet<_>>().into_iter(); |
| |
| credential_type_iter.for_each(|cred_type| { |
| if let Some(counters) = self.credential_type_counters.get_mut(cred_type) { |
| counters.supported -= 1; |
| } |
| }); |
| |
| #[cfg(feature = "custom_proposal")] |
| { |
| let proposal_type_iter = leaf_node.capabilities.proposals.iter(); |
| |
| #[cfg(feature = "std")] |
| let proposal_type_iter = proposal_type_iter.unique(); |
| |
| #[cfg(not(feature = "std"))] |
| let proposal_type_iter = proposal_type_iter.collect::<BTreeSet<_>>().into_iter(); |
| |
| // Decrement proposal type counters |
| proposal_type_iter.for_each(|proposal_type| { |
| if let Some(supported) = self.proposal_type_counter.get_mut(proposal_type) { |
| *supported -= 1; |
| } |
| }) |
| } |
| } |
| |
| #[cfg(feature = "custom_proposal")] |
| pub fn count_supporting_proposal(&self, proposal_type: ProposalType) -> u32 { |
| self.proposal_type_counter |
| .get(&proposal_type) |
| .copied() |
| .unwrap_or_default() |
| } |
| |
| #[cfg(test)] |
| pub fn len(&self) -> usize { |
| self.credential_signature_key.len() |
| } |
| } |
| |
| #[cfg(feature = "tree_index")] |
| #[derive(Clone, Debug, Default, PartialEq, MlsEncode, MlsDecode, MlsSize)] |
| struct TypeCounter { |
| supported: u32, |
| used: u32, |
| } |
| |
| #[cfg(feature = "tree_index")] |
| #[cfg(test)] |
| mod tests { |
| use super::*; |
| use crate::{ |
| client::test_utils::TEST_CIPHER_SUITE, |
| tree_kem::leaf_node::test_utils::{get_basic_test_node, get_test_client_identity}, |
| }; |
| use alloc::format; |
| use assert_matches::assert_matches; |
| |
| #[derive(Clone, Debug)] |
| struct TestData { |
| pub leaf_node: LeafNode, |
| pub index: LeafIndex, |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| async fn get_test_data(index: LeafIndex) -> TestData { |
| let cipher_suite = TEST_CIPHER_SUITE; |
| let leaf_node = get_basic_test_node(cipher_suite, &format!("foo{}", index.0)).await; |
| |
| TestData { leaf_node, index } |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| async fn test_setup() -> (Vec<TestData>, TreeIndex) { |
| let mut test_data = Vec::new(); |
| |
| for i in 0..10 { |
| test_data.push(get_test_data(LeafIndex(i)).await); |
| } |
| |
| let mut test_index = TreeIndex::new(); |
| |
| test_data.clone().into_iter().for_each(|d| { |
| test_index |
| .insert( |
| d.index, |
| &d.leaf_node, |
| get_test_client_identity(&d.leaf_node), |
| ) |
| .unwrap() |
| }); |
| |
| (test_data, test_index) |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_insert() { |
| let (test_data, test_index) = test_setup().await; |
| |
| assert_eq!(test_index.credential_signature_key.len(), test_data.len()); |
| assert_eq!(test_index.hpke_key.len(), test_data.len()); |
| |
| test_data.into_iter().enumerate().for_each(|(i, d)| { |
| let pub_key = d.leaf_node.signing_identity.signature_key; |
| |
| assert_eq!( |
| test_index.credential_signature_key.get(&pub_key), |
| Some(&LeafIndex(i as u32)) |
| ); |
| |
| assert_eq!( |
| test_index.hpke_key.get(&d.leaf_node.public_key), |
| Some(&LeafIndex(i as u32)) |
| ); |
| }) |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_insert_duplicate_credential_key() { |
| let (test_data, mut test_index) = test_setup().await; |
| |
| let before_error = test_index.clone(); |
| |
| let mut new_key_package = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await; |
| new_key_package.signing_identity = test_data[1].leaf_node.signing_identity.clone(); |
| |
| let res = test_index.insert( |
| test_data[1].index, |
| &new_key_package, |
| get_test_client_identity(&new_key_package), |
| ); |
| |
| assert_matches!(res, Err(MlsError::DuplicateLeafData(index)) |
| if index == *test_data[1].index); |
| |
| assert_eq!(before_error, test_index); |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_insert_duplicate_hpke_key() { |
| let cipher_suite = TEST_CIPHER_SUITE; |
| let (test_data, mut test_index) = test_setup().await; |
| let before_error = test_index.clone(); |
| |
| let mut new_leaf_node = get_basic_test_node(cipher_suite, "foo").await; |
| new_leaf_node.public_key = test_data[1].leaf_node.public_key.clone(); |
| |
| let res = test_index.insert( |
| test_data[1].index, |
| &new_leaf_node, |
| get_test_client_identity(&new_leaf_node), |
| ); |
| |
| assert_matches!(res, Err(MlsError::DuplicateLeafData(index)) |
| if index == *test_data[1].index); |
| |
| assert_eq!(before_error, test_index); |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_remove() { |
| let (test_data, mut test_index) = test_setup().await; |
| |
| test_index.remove( |
| &test_data[1].leaf_node, |
| &get_test_client_identity(&test_data[1].leaf_node), |
| ); |
| |
| assert_eq!( |
| test_index.credential_signature_key.len(), |
| test_data.len() - 1 |
| ); |
| |
| assert_eq!(test_index.hpke_key.len(), test_data.len() - 1); |
| |
| assert_eq!( |
| test_index |
| .credential_signature_key |
| .get(&test_data[1].leaf_node.signing_identity.signature_key), |
| None |
| ); |
| |
| assert_eq!( |
| test_index.hpke_key.get(&test_data[1].leaf_node.public_key), |
| None |
| ); |
| } |
| |
| #[cfg(feature = "custom_proposal")] |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn custom_proposals() { |
| let test_proposal_id = ProposalType::new(42); |
| let other_proposal_id = ProposalType::new(45); |
| |
| let mut test_data_1 = get_test_data(LeafIndex(0)).await; |
| |
| test_data_1 |
| .leaf_node |
| .capabilities |
| .proposals |
| .push(test_proposal_id); |
| |
| let mut test_data_2 = get_test_data(LeafIndex(1)).await; |
| |
| test_data_2 |
| .leaf_node |
| .capabilities |
| .proposals |
| .push(test_proposal_id); |
| |
| test_data_2 |
| .leaf_node |
| .capabilities |
| .proposals |
| .push(other_proposal_id); |
| |
| let mut test_index = TreeIndex::new(); |
| |
| test_index |
| .insert(test_data_1.index, &test_data_1.leaf_node, vec![0]) |
| .unwrap(); |
| |
| assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 1); |
| |
| test_index |
| .insert(test_data_2.index, &test_data_2.leaf_node, vec![1]) |
| .unwrap(); |
| |
| assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 2); |
| assert_eq!(test_index.count_supporting_proposal(other_proposal_id), 1); |
| |
| test_index.remove(&test_data_2.leaf_node, &[1]); |
| |
| assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 1); |
| assert_eq!(test_index.count_supporting_proposal(other_proposal_id), 0); |
| } |
| } |