| // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| // Copyright by contributors to this project. |
| // SPDX-License-Identifier: (Apache-2.0 OR MIT) |
| |
| use crate::client::MlsError; |
| use crate::crypto::{CipherSuiteProvider, HpkePublicKey}; |
| use crate::tree_kem::math as tree_math; |
| use crate::tree_kem::node::{LeafIndex, Node, NodeIndex}; |
| use crate::tree_kem::TreeKemPublic; |
| use alloc::vec::Vec; |
| use core::{ |
| fmt::{self, Debug}, |
| ops::Deref, |
| }; |
| use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; |
| use mls_rs_core::error::IntoAnyError; |
| use tree_math::TreeIndex; |
| |
| use super::leaf_node::LeafNodeSource; |
| |
| #[cfg(feature = "std")] |
| use std::collections::HashSet; |
| |
| #[cfg(not(feature = "std"))] |
| use alloc::collections::BTreeSet; |
| |
| #[derive(Clone, Debug, MlsSize, MlsEncode)] |
| struct ParentHashInput<'a> { |
| #[mls_codec(with = "mls_rs_codec::byte_vec")] |
| public_key: &'a HpkePublicKey, |
| #[mls_codec(with = "mls_rs_codec::byte_vec")] |
| parent_hash: &'a [u8], |
| #[mls_codec(with = "mls_rs_codec::byte_vec")] |
| original_sibling_tree_hash: &'a [u8], |
| } |
| |
| #[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq, Eq)] |
| #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] |
| #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] |
| pub struct ParentHash( |
| #[mls_codec(with = "mls_rs_codec::byte_vec")] |
| #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] |
| Vec<u8>, |
| ); |
| |
| impl Debug for ParentHash { |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| mls_rs_core::debug::pretty_bytes(&self.0) |
| .named("ParentHash") |
| .fmt(f) |
| } |
| } |
| |
| impl From<Vec<u8>> for ParentHash { |
| fn from(v: Vec<u8>) -> Self { |
| Self(v) |
| } |
| } |
| |
| impl Deref for ParentHash { |
| type Target = Vec<u8>; |
| |
| fn deref(&self) -> &Self::Target { |
| &self.0 |
| } |
| } |
| |
| impl ParentHash { |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn new<P: CipherSuiteProvider>( |
| cipher_suite_provider: &P, |
| public_key: &HpkePublicKey, |
| parent_hash: &ParentHash, |
| original_sibling_tree_hash: &[u8], |
| ) -> Result<Self, MlsError> { |
| let input = ParentHashInput { |
| public_key, |
| parent_hash, |
| original_sibling_tree_hash, |
| }; |
| |
| let input_bytes = input.mls_encode_to_vec()?; |
| |
| let hash = cipher_suite_provider |
| .hash(&input_bytes) |
| .await |
| .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?; |
| |
| Ok(Self(hash)) |
| } |
| |
| pub fn empty() -> Self { |
| ParentHash(Vec::new()) |
| } |
| |
| pub fn matches(&self, hash: &ParentHash) -> bool { |
| //TODO: Constant time equals |
| hash == self |
| } |
| } |
| |
| impl Node { |
| fn get_parent_hash(&self) -> Option<ParentHash> { |
| match self { |
| Node::Parent(p) => Some(p.parent_hash.clone()), |
| Node::Leaf(l) => match &l.leaf_node_source { |
| LeafNodeSource::Commit(parent_hash) => Some(parent_hash.clone()), |
| _ => None, |
| }, |
| } |
| } |
| } |
| |
| impl TreeKemPublic { |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| async fn parent_hash_for_leaf<P: CipherSuiteProvider>( |
| &mut self, |
| cipher_suite_provider: &P, |
| index: LeafIndex, |
| ) -> Result<ParentHash, MlsError> { |
| let mut hash = ParentHash::empty(); |
| |
| for node in self.nodes.direct_copath(index).into_iter().rev() { |
| if self.nodes.is_resolution_empty(node.copath) { |
| continue; |
| } |
| |
| let parent = self.nodes.borrow_as_parent_mut(node.path)?; |
| |
| let calculated = ParentHash::new( |
| cipher_suite_provider, |
| &parent.public_key, |
| &hash, |
| &self.tree_hashes.current[node.copath as usize], |
| ) |
| .await?; |
| |
| (parent.parent_hash, hash) = (hash, calculated); |
| } |
| |
| Ok(hash) |
| } |
| |
| // Updates all of the required parent hash values, and returns the calculated parent hash value for the leaf node |
| // If an update path is provided, additionally verify that the calculated parent hash matches |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub(crate) async fn update_parent_hashes<P: CipherSuiteProvider>( |
| &mut self, |
| index: LeafIndex, |
| verify_leaf_hash: bool, |
| cipher_suite_provider: &P, |
| ) -> Result<(), MlsError> { |
| // First update the relevant original hashes used for parent hash computation. |
| self.update_hashes(&[index], cipher_suite_provider).await?; |
| |
| let leaf_hash = self |
| .parent_hash_for_leaf(cipher_suite_provider, index) |
| .await?; |
| |
| let leaf = self.nodes.borrow_as_leaf_mut(index)?; |
| |
| if verify_leaf_hash { |
| // Verify the parent hash of the new sender leaf node and update the parent hash values |
| // in the local tree |
| if let LeafNodeSource::Commit(parent_hash) = &leaf.leaf_node_source { |
| if !leaf_hash.matches(parent_hash) { |
| return Err(MlsError::ParentHashMismatch); |
| } |
| } else { |
| return Err(MlsError::InvalidLeafNodeSource); |
| } |
| } else { |
| leaf.leaf_node_source = LeafNodeSource::Commit(leaf_hash); |
| } |
| |
| // Update hashes after changes to the tree. |
| self.update_hashes(&[index], cipher_suite_provider).await |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub(super) async fn validate_parent_hashes<P: CipherSuiteProvider>( |
| &self, |
| cipher_suite_provider: &P, |
| ) -> Result<(), MlsError> { |
| let original_hashes = self.compute_original_hashes(cipher_suite_provider).await?; |
| |
| let nodes_to_validate = self |
| .nodes |
| .non_empty_parents() |
| .map(|(node_index, _)| node_index); |
| |
| #[cfg(feature = "std")] |
| let mut nodes_to_validate = nodes_to_validate.collect::<HashSet<_>>(); |
| #[cfg(not(feature = "std"))] |
| let mut nodes_to_validate = nodes_to_validate.collect::<BTreeSet<_>>(); |
| |
| let num_leaves = self.total_leaf_count(); |
| |
| // For each leaf l, validate all non-blank nodes on the chain from l up the tree. |
| for (leaf_index, _) in self.nodes.non_empty_leaves() { |
| let mut n = NodeIndex::from(leaf_index); |
| |
| while let Some(mut ps) = n.parent_sibling(&num_leaves) { |
| // Find the first non-blank ancestor p of n and p's co-path child s. |
| while self.nodes.is_blank(ps.parent)? { |
| // If we reached the root, we're done with this chain. |
| let Some(ps_parent) = ps.parent.parent_sibling(&num_leaves) else { |
| return Ok(()); |
| }; |
| |
| ps = ps_parent; |
| } |
| |
| // Check is n's parent_hash field matches the parent hash of p with co-path child s. |
| let p_parent = self.nodes.borrow_as_parent(ps.parent)?; |
| |
| let n_node = self |
| .nodes |
| .borrow_node(n)? |
| .as_ref() |
| .ok_or(MlsError::ExpectedNode)?; |
| |
| let calculated = ParentHash::new( |
| cipher_suite_provider, |
| &p_parent.public_key, |
| &p_parent.parent_hash, |
| &original_hashes[ps.sibling as usize], |
| ) |
| .await?; |
| |
| if n_node.get_parent_hash() == Some(calculated) { |
| // Check that "n is in the resolution of c, and the intersection of p's unmerged_leaves with the subtree |
| // under c is equal to the resolution of c with n removed". |
| let Some(cp) = ps.sibling.parent_sibling(&num_leaves) else { |
| return Err(MlsError::ParentHashMismatch); |
| }; |
| |
| let c = cp.sibling; |
| let c_resolution = self.nodes.get_resolution_index(c)?.into_iter(); |
| |
| #[cfg(feature = "std")] |
| let mut c_resolution = c_resolution.collect::<HashSet<_>>(); |
| #[cfg(not(feature = "std"))] |
| let mut c_resolution = c_resolution.collect::<BTreeSet<_>>(); |
| |
| let p_unmerged_in_c_subtree = self |
| .unmerged_in_subtree(ps.parent, c)? |
| .iter() |
| .copied() |
| .map(|x| *x * 2); |
| |
| #[cfg(feature = "std")] |
| let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::<HashSet<_>>(); |
| #[cfg(not(feature = "std"))] |
| let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::<BTreeSet<_>>(); |
| |
| if c_resolution.remove(&n) |
| && c_resolution == p_unmerged_in_c_subtree |
| && nodes_to_validate.remove(&ps.parent) |
| { |
| // If n's parent_hash field matches and p has not been validated yet, mark p as validated and continue. |
| n = ps.parent; |
| } else { |
| // If p is validated for the second time, the check fails ("all non-blank parent nodes are covered by exactly one such chain"). |
| return Err(MlsError::ParentHashMismatch); |
| } |
| } else { |
| // If n's parent_hash field doesn't match, we're done with this chain. |
| break; |
| } |
| } |
| } |
| |
| // The check passes iff all non-blank nodes are validated. |
| if nodes_to_validate.is_empty() { |
| Ok(()) |
| } else { |
| Err(MlsError::ParentHashMismatch) |
| } |
| } |
| } |
| |
| #[cfg(test)] |
| pub(crate) mod test_utils { |
| |
| use super::*; |
| use crate::{ |
| cipher_suite::CipherSuite, |
| crypto::test_utils::test_cipher_suite_provider, |
| identity::basic::BasicIdentityProvider, |
| tree_kem::{leaf_node::test_utils::get_basic_test_node, node::Parent}, |
| }; |
| |
| use alloc::vec; |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub(crate) async fn test_parent( |
| cipher_suite: CipherSuite, |
| unmerged_leaves: Vec<LeafIndex>, |
| ) -> Parent { |
| let (_, public_key) = test_cipher_suite_provider(cipher_suite) |
| .kem_generate() |
| .await |
| .unwrap(); |
| |
| Parent { |
| public_key, |
| parent_hash: ParentHash::empty(), |
| unmerged_leaves, |
| } |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub(crate) async fn test_parent_node( |
| cipher_suite: CipherSuite, |
| unmerged_leaves: Vec<LeafIndex>, |
| ) -> Node { |
| Node::Parent(test_parent(cipher_suite, unmerged_leaves).await) |
| } |
| |
| // Create figure 12 from MLS RFC |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub(crate) async fn get_test_tree_fig_12(cipher_suite: CipherSuite) -> TreeKemPublic { |
| let cipher_suite_provider = test_cipher_suite_provider(cipher_suite); |
| |
| let mut tree = TreeKemPublic::new(); |
| |
| let mut leaves = Vec::new(); |
| |
| for l in ["A", "B", "C", "D", "E", "F", "G"] { |
| leaves.push(get_basic_test_node(cipher_suite, l).await); |
| } |
| |
| tree.add_leaves(leaves, &BasicIdentityProvider, &cipher_suite_provider) |
| .await |
| .unwrap(); |
| |
| tree.nodes[1] = Some(test_parent_node(cipher_suite, vec![]).await); |
| tree.nodes[3] = Some(test_parent_node(cipher_suite, vec![LeafIndex(3)]).await); |
| |
| tree.nodes[7] = |
| Some(test_parent_node(cipher_suite, vec![LeafIndex(3), LeafIndex(6)]).await); |
| |
| tree.nodes[9] = Some(test_parent_node(cipher_suite, vec![LeafIndex(5)]).await); |
| |
| tree.nodes[11] = |
| Some(test_parent_node(cipher_suite, vec![LeafIndex(5), LeafIndex(6)]).await); |
| |
| tree.update_parent_hashes(LeafIndex(0), false, &cipher_suite_provider) |
| .await |
| .unwrap(); |
| |
| tree.update_parent_hashes(LeafIndex(4), false, &cipher_suite_provider) |
| .await |
| .unwrap(); |
| |
| tree |
| } |
| } |
| |
| #[cfg(test)] |
| mod tests { |
| use super::*; |
| use crate::client::test_utils::TEST_CIPHER_SUITE; |
| use crate::crypto::test_utils::test_cipher_suite_provider; |
| use crate::tree_kem::leaf_node::test_utils::get_basic_test_node; |
| use crate::tree_kem::leaf_node::LeafNodeSource; |
| use crate::tree_kem::test_utils::TreeWithSigners; |
| use crate::tree_kem::MlsError; |
| use assert_matches::assert_matches; |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_missing_parent_hash() { |
| let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE); |
| let mut test_tree = TreeWithSigners::make_full_tree(8, &cs).await.tree; |
| |
| *test_tree.nodes.borrow_as_leaf_mut(LeafIndex(0)).unwrap() = |
| get_basic_test_node(TEST_CIPHER_SUITE, "foo").await; |
| |
| let missing_parent_hash_res = test_tree |
| .update_parent_hashes( |
| LeafIndex(0), |
| true, |
| &test_cipher_suite_provider(TEST_CIPHER_SUITE), |
| ) |
| .await; |
| |
| assert_matches!( |
| missing_parent_hash_res, |
| Err(MlsError::InvalidLeafNodeSource) |
| ); |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_parent_hash_mismatch() { |
| let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE); |
| let mut test_tree = TreeWithSigners::make_full_tree(8, &cs).await.tree; |
| |
| let unexpected_parent_hash = ParentHash::from(hex!("f00d")); |
| |
| test_tree |
| .nodes |
| .borrow_as_leaf_mut(LeafIndex(0)) |
| .unwrap() |
| .leaf_node_source = LeafNodeSource::Commit(unexpected_parent_hash); |
| |
| let invalid_parent_hash_res = test_tree |
| .update_parent_hashes( |
| LeafIndex(0), |
| true, |
| &test_cipher_suite_provider(TEST_CIPHER_SUITE), |
| ) |
| .await; |
| |
| assert_matches!(invalid_parent_hash_res, Err(MlsError::ParentHashMismatch)); |
| } |
| |
| #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] |
| async fn test_parent_hash_invalid() { |
| let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE); |
| let mut test_tree = TreeWithSigners::make_full_tree(8, &cs).await.tree; |
| |
| test_tree.nodes[2] = None; |
| |
| let res = test_tree |
| .validate_parent_hashes(&test_cipher_suite_provider(TEST_CIPHER_SUITE)) |
| .await; |
| |
| assert_matches!(res, Err(MlsError::ParentHashMismatch)); |
| } |
| } |