Import 'mls-rs' crate
Request Document: go/android-rust-importing-crates
For CL Reviewers: go/android3p#cl-review
For Build Team: go/ab-third-party-imports
Bug: http://b/330708876
Test: m libmls_rs
Change-Id: Ib0a891a4d7bf582ebea9ba7a1447ea959e42e0d3
diff --git a/src/client.rs b/src/client.rs
new file mode 100644
index 0000000..a7031bb
--- /dev/null
+++ b/src/client.rs
@@ -0,0 +1,1049 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::cipher_suite::CipherSuite;
+use crate::client_builder::{recreate_config, BaseConfig, ClientBuilder, MakeConfig};
+use crate::client_config::ClientConfig;
+use crate::group::framing::MlsMessage;
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::group::{
+ framing::{Content, MlsMessagePayload, PublicMessage, Sender, WireFormat},
+ message_signature::AuthenticatedContent,
+ proposal::{AddProposal, Proposal},
+};
+use crate::group::{snapshot::Snapshot, ExportedTree, Group, NewMemberInfo};
+use crate::identity::SigningIdentity;
+use crate::key_package::{KeyPackageGeneration, KeyPackageGenerator};
+use crate::protocol_version::ProtocolVersion;
+use crate::tree_kem::node::NodeIndex;
+use alloc::vec::Vec;
+use mls_rs_codec::MlsDecode;
+use mls_rs_core::crypto::{CryptoProvider, SignatureSecretKey};
+use mls_rs_core::error::{AnyError, IntoAnyError};
+use mls_rs_core::extension::{ExtensionError, ExtensionList, ExtensionType};
+use mls_rs_core::group::{GroupStateStorage, ProposalType};
+use mls_rs_core::identity::CredentialType;
+use mls_rs_core::key_package::KeyPackageStorage;
+
+use crate::group::external_commit::ExternalCommitBuilder;
+
+#[cfg(feature = "by_ref_proposal")]
+use alloc::boxed::Box;
+
+#[derive(Debug)]
+#[cfg_attr(feature = "std", derive(thiserror::Error))]
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::enum_to_error_code)]
+#[non_exhaustive]
+pub enum MlsError {
+ #[cfg_attr(feature = "std", error(transparent))]
+ IdentityProviderError(AnyError),
+ #[cfg_attr(feature = "std", error(transparent))]
+ CryptoProviderError(AnyError),
+ #[cfg_attr(feature = "std", error(transparent))]
+ KeyPackageRepoError(AnyError),
+ #[cfg_attr(feature = "std", error(transparent))]
+ GroupStorageError(AnyError),
+ #[cfg_attr(feature = "std", error(transparent))]
+ PskStoreError(AnyError),
+ #[cfg_attr(feature = "std", error(transparent))]
+ MlsRulesError(AnyError),
+ #[cfg_attr(feature = "std", error(transparent))]
+ SerializationError(AnyError),
+ #[cfg_attr(feature = "std", error(transparent))]
+ ExtensionError(AnyError),
+ #[cfg_attr(feature = "std", error("Cipher suite does not match"))]
+ CipherSuiteMismatch,
+ #[cfg_attr(feature = "std", error("Invalid commit, missing required path"))]
+ CommitMissingPath,
+ #[cfg_attr(feature = "std", error("plaintext message for incorrect epoch"))]
+ InvalidEpoch,
+ #[cfg_attr(feature = "std", error("invalid signature found"))]
+ InvalidSignature,
+ #[cfg_attr(feature = "std", error("invalid confirmation tag"))]
+ InvalidConfirmationTag,
+ #[cfg_attr(feature = "std", error("invalid membership tag"))]
+ InvalidMembershipTag,
+ #[cfg_attr(feature = "std", error("corrupt private key, missing required values"))]
+ InvalidTreeKemPrivateKey,
+ #[cfg_attr(feature = "std", error("key package not found, unable to process"))]
+ WelcomeKeyPackageNotFound,
+ #[cfg_attr(feature = "std", error("leaf not found in tree for index {0}"))]
+ LeafNotFound(u32),
+ #[cfg_attr(feature = "std", error("message from self can't be processed"))]
+ CantProcessMessageFromSelf,
+ #[cfg_attr(
+ feature = "std",
+ error("pending proposals found, commit required before application messages can be sent")
+ )]
+ CommitRequired,
+ #[cfg_attr(
+ feature = "std",
+ error("ratchet tree not provided or discovered in GroupInfo")
+ )]
+ RatchetTreeNotFound,
+ #[cfg_attr(feature = "std", error("External sender cannot commit"))]
+ ExternalSenderCannotCommit,
+ #[cfg_attr(feature = "std", error("Unsupported protocol version {0:?}"))]
+ UnsupportedProtocolVersion(ProtocolVersion),
+ #[cfg_attr(feature = "std", error("Protocol version mismatch"))]
+ ProtocolVersionMismatch,
+ #[cfg_attr(feature = "std", error("Unsupported cipher suite {0:?}"))]
+ UnsupportedCipherSuite(CipherSuite),
+ #[cfg_attr(feature = "std", error("Signing key of external sender is unknown"))]
+ UnknownSigningIdentityForExternalSender,
+ #[cfg_attr(
+ feature = "std",
+ error("External proposals are disabled for this group")
+ )]
+ ExternalProposalsDisabled,
+ #[cfg_attr(
+ feature = "std",
+ error("Signing identity is not allowed to externally propose")
+ )]
+ InvalidExternalSigningIdentity,
+ #[cfg_attr(feature = "std", error("Missing ExternalPub extension"))]
+ MissingExternalPubExtension,
+ #[cfg_attr(feature = "std", error("Epoch not found"))]
+ EpochNotFound,
+ #[cfg_attr(feature = "std", error("Unencrypted application message"))]
+ UnencryptedApplicationMessage,
+ #[cfg_attr(
+ feature = "std",
+ error("NewMemberCommit sender type can only be used to send Commit content")
+ )]
+ ExpectedCommitForNewMemberCommit,
+ #[cfg_attr(
+ feature = "std",
+ error("NewMemberProposal sender type can only be used to send add proposals")
+ )]
+ ExpectedAddProposalForNewMemberProposal,
+ #[cfg_attr(
+ feature = "std",
+ error("External commit missing ExternalInit proposal")
+ )]
+ ExternalCommitMissingExternalInit,
+ #[cfg_attr(
+ feature = "std",
+ error(
+ "A ReIinit has been applied. The next action must be creating or receiving a welcome."
+ )
+ )]
+ GroupUsedAfterReInit,
+ #[cfg_attr(feature = "std", error("Pending ReIinit not found."))]
+ PendingReInitNotFound,
+ #[cfg_attr(
+ feature = "std",
+ error("The extensions in the welcome message and in the reinit do not match.")
+ )]
+ ReInitExtensionsMismatch,
+ #[cfg_attr(feature = "std", error("signer not found for given identity"))]
+ SignerNotFound,
+ #[cfg_attr(feature = "std", error("commit already pending"))]
+ ExistingPendingCommit,
+ #[cfg_attr(feature = "std", error("pending commit not found"))]
+ PendingCommitNotFound,
+ #[cfg_attr(feature = "std", error("unexpected message type for action"))]
+ UnexpectedMessageType,
+ #[cfg_attr(
+ feature = "std",
+ error("membership tag on MlsPlaintext for non-member sender")
+ )]
+ MembershipTagForNonMember,
+ #[cfg_attr(feature = "std", error("No member found for given identity id."))]
+ MemberNotFound,
+ #[cfg_attr(feature = "std", error("group not found"))]
+ GroupNotFound,
+ #[cfg_attr(feature = "std", error("unexpected PSK ID"))]
+ UnexpectedPskId,
+ #[cfg_attr(feature = "std", error("invalid sender for content type"))]
+ InvalidSender,
+ #[cfg_attr(feature = "std", error("GroupID mismatch"))]
+ GroupIdMismatch,
+ #[cfg_attr(feature = "std", error("storage retention can not be zero"))]
+ NonZeroRetentionRequired,
+ #[cfg_attr(feature = "std", error("Too many PSK IDs to compute PSK secret"))]
+ TooManyPskIds,
+ #[cfg_attr(feature = "std", error("Missing required Psk"))]
+ MissingRequiredPsk,
+ #[cfg_attr(feature = "std", error("Old group state not found"))]
+ OldGroupStateNotFound,
+ #[cfg_attr(feature = "std", error("leaf secret already consumed"))]
+ InvalidLeafConsumption,
+ #[cfg_attr(feature = "std", error("key not available, invalid generation {0}"))]
+ KeyMissing(u32),
+ #[cfg_attr(
+ feature = "std",
+ error("requested generation {0} is too far ahead of current generation")
+ )]
+ InvalidFutureGeneration(u32),
+ #[cfg_attr(feature = "std", error("leaf node has no children"))]
+ LeafNodeNoChildren,
+ #[cfg_attr(feature = "std", error("root node has no parent"))]
+ LeafNodeNoParent,
+ #[cfg_attr(feature = "std", error("index out of range"))]
+ InvalidTreeIndex,
+ #[cfg_attr(feature = "std", error("time overflow"))]
+ TimeOverflow,
+ #[cfg_attr(feature = "std", error("invalid leaf_node_source"))]
+ InvalidLeafNodeSource,
+ #[cfg_attr(feature = "std", error("key package has expired or is not valid yet"))]
+ InvalidLifetime,
+ #[cfg_attr(feature = "std", error("required extension not found"))]
+ RequiredExtensionNotFound(ExtensionType),
+ #[cfg_attr(feature = "std", error("required proposal not found"))]
+ RequiredProposalNotFound(ProposalType),
+ #[cfg_attr(feature = "std", error("required credential not found"))]
+ RequiredCredentialNotFound(CredentialType),
+ #[cfg_attr(feature = "std", error("capabilities must describe extensions used"))]
+ ExtensionNotInCapabilities(ExtensionType),
+ #[cfg_attr(feature = "std", error("expected non-blank node"))]
+ ExpectedNode,
+ #[cfg_attr(feature = "std", error("node index is out of bounds {0}"))]
+ InvalidNodeIndex(NodeIndex),
+ #[cfg_attr(feature = "std", error("unexpected empty node found"))]
+ UnexpectedEmptyNode,
+ #[cfg_attr(
+ feature = "std",
+ error("duplicate signature key, hpke key or identity found at index {0}")
+ )]
+ DuplicateLeafData(u32),
+ #[cfg_attr(
+ feature = "std",
+ error("In-use credential type not supported by new leaf at index")
+ )]
+ InUseCredentialTypeUnsupportedByNewLeaf,
+ #[cfg_attr(
+ feature = "std",
+ error("Not all members support the credential type used by new leaf")
+ )]
+ CredentialTypeOfNewLeafIsUnsupported,
+ #[cfg_attr(
+ feature = "std",
+ error("the length of the update path is different than the length of the direct path")
+ )]
+ WrongPathLen,
+ #[cfg_attr(
+ feature = "std",
+ error("same HPKE leaf key before and after applying the update path for leaf {0}")
+ )]
+ SameHpkeKey(u32),
+ #[cfg_attr(feature = "std", error("init key is not valid for cipher suite"))]
+ InvalidInitKey,
+ #[cfg_attr(
+ feature = "std",
+ error("init key can not be equal to leaf node public key")
+ )]
+ InitLeafKeyEquality,
+ #[cfg_attr(feature = "std", error("different identity in update for leaf {0}"))]
+ DifferentIdentityInUpdate(u32),
+ #[cfg_attr(feature = "std", error("update path pub key mismatch"))]
+ PubKeyMismatch,
+ #[cfg_attr(feature = "std", error("tree hash mismatch"))]
+ TreeHashMismatch,
+ #[cfg_attr(feature = "std", error("bad update: no suitable secret key"))]
+ UpdateErrorNoSecretKey,
+ #[cfg_attr(feature = "std", error("invalid lca, not found on direct path"))]
+ LcaNotFoundInDirectPath,
+ #[cfg_attr(feature = "std", error("update path parent hash mismatch"))]
+ ParentHashMismatch,
+ #[cfg_attr(feature = "std", error("unexpected pattern of unmerged leaves"))]
+ UnmergedLeavesMismatch,
+ #[cfg_attr(feature = "std", error("empty tree"))]
+ UnexpectedEmptyTree,
+ #[cfg_attr(feature = "std", error("trailing blanks"))]
+ UnexpectedTrailingBlanks,
+ // Proposal Rules errors
+ #[cfg_attr(
+ feature = "std",
+ error("Commiter must not include any update proposals generated by the commiter")
+ )]
+ InvalidCommitSelfUpdate,
+ #[cfg_attr(feature = "std", error("A PreSharedKey proposal must have a PSK of type External or type Resumption and usage Application"))]
+ InvalidTypeOrUsageInPreSharedKeyProposal,
+ #[cfg_attr(feature = "std", error("psk nonce length does not match cipher suite"))]
+ InvalidPskNonceLength,
+ #[cfg_attr(
+ feature = "std",
+ error("ReInit proposal protocol version is less than the version of the original group")
+ )]
+ InvalidProtocolVersionInReInit,
+ #[cfg_attr(feature = "std", error("More than one proposal applying to leaf: {0}"))]
+ MoreThanOneProposalForLeaf(u32),
+ #[cfg_attr(
+ feature = "std",
+ error("More than one GroupContextExtensions proposal")
+ )]
+ MoreThanOneGroupContextExtensionsProposal,
+ #[cfg_attr(feature = "std", error("Invalid proposal type for sender"))]
+ InvalidProposalTypeForSender,
+ #[cfg_attr(
+ feature = "std",
+ error("External commit must have exactly one ExternalInit proposal")
+ )]
+ ExternalCommitMustHaveExactlyOneExternalInit,
+ #[cfg_attr(feature = "std", error("External commit must have a new leaf"))]
+ ExternalCommitMustHaveNewLeaf,
+ #[cfg_attr(
+ feature = "std",
+ error("External commit contains removal of other identity")
+ )]
+ ExternalCommitRemovesOtherIdentity,
+ #[cfg_attr(
+ feature = "std",
+ error("External commit contains more than one Remove proposal")
+ )]
+ ExternalCommitWithMoreThanOneRemove,
+ #[cfg_attr(feature = "std", error("Duplicate PSK IDs"))]
+ DuplicatePskIds,
+ #[cfg_attr(
+ feature = "std",
+ error("Invalid proposal type {0:?} in external commit")
+ )]
+ InvalidProposalTypeInExternalCommit(ProposalType),
+ #[cfg_attr(feature = "std", error("Committer can not remove themselves"))]
+ CommitterSelfRemoval,
+ #[cfg_attr(
+ feature = "std",
+ error("Only members can commit proposals by reference")
+ )]
+ OnlyMembersCanCommitProposalsByRef,
+ #[cfg_attr(feature = "std", error("Other proposal with ReInit"))]
+ OtherProposalWithReInit,
+ #[cfg_attr(feature = "std", error("Unsupported group extension {0:?}"))]
+ UnsupportedGroupExtension(ExtensionType),
+ #[cfg_attr(feature = "std", error("Unsupported custom proposal type {0:?}"))]
+ UnsupportedCustomProposal(ProposalType),
+ #[cfg_attr(feature = "std", error("by-ref proposal not found"))]
+ ProposalNotFound,
+ #[cfg_attr(
+ feature = "std",
+ error("Removing non-existing member (or removing a member twice)")
+ )]
+ RemovingNonExistingMember,
+ #[cfg_attr(feature = "std", error("Updated identity not a valid successor"))]
+ InvalidSuccessor,
+ #[cfg_attr(
+ feature = "std",
+ error("Updating non-existing member (or updating a member twice)")
+ )]
+ UpdatingNonExistingMember,
+ #[cfg_attr(feature = "std", error("Failed generating next path secret"))]
+ FailedGeneratingPathSecret,
+ #[cfg_attr(feature = "std", error("Invalid group info"))]
+ InvalidGroupInfo,
+ #[cfg_attr(feature = "std", error("Invalid welcome message"))]
+ InvalidWelcomeMessage,
+}
+
+impl IntoAnyError for MlsError {
+    /// Boxes this error as a thread-safe `std::error::Error` trait object.
+    ///
+    /// Only overridden with the `std` feature; otherwise the trait's default
+    /// implementation applies.
+    #[cfg(feature = "std")]
+    fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
+        Ok(Box::new(self))
+    }
+}
+
+impl From<mls_rs_codec::Error> for MlsError {
+    /// Wraps an encoding/decoding failure as [`MlsError::SerializationError`].
+    #[inline]
+    fn from(codec_err: mls_rs_codec::Error) -> Self {
+        let source = codec_err.into_any_error();
+        Self::SerializationError(source)
+    }
+}
+
+impl From<ExtensionError> for MlsError {
+    /// Wraps an extension handling failure as [`MlsError::ExtensionError`].
+    #[inline]
+    fn from(ext_err: ExtensionError) -> Self {
+        let source = ext_err.into_any_error();
+        Self::ExtensionError(source)
+    }
+}
+
+/// MLS client used to create key packages and manage groups.
+///
+/// [`Client::builder`] can be used to instantiate it.
+///
+/// A client holds at most one signing identity (together with the cipher
+/// suite it is used with) and one protocol version, which are used when it
+/// creates groups and generates key packages. Applications may decide to
+/// create one or many clients depending on their specific needs, e.g. one
+/// client per identity.
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))]
+#[derive(Clone, Debug)]
+pub struct Client<C> {
+ /// Configuration supplying the crypto, identity and storage providers
+ /// used by this client (see `ClientConfig`).
+ pub(crate) config: C,
+ /// Identity placed in generated key packages and new groups, paired with
+ /// the cipher suite it is used with. `None` if no identity was configured,
+ /// in which case operations needing it fail with `MlsError::SignerNotFound`.
+ pub(crate) signing_identity: Option<(SigningIdentity, CipherSuite)>,
+ /// Secret key used for signing. `None` if no signer was configured, in
+ /// which case operations needing it fail with `MlsError::SignerNotFound`.
+ pub(crate) signer: Option<SignatureSecretKey>,
+ /// Protocol version used for new groups and generated key packages.
+ pub(crate) version: ProtocolVersion,
+}
+
+impl Client<()> {
+    /// Entry point for constructing a client.
+    ///
+    /// The returned [`ClientBuilder`] is used to configure client preferences
+    /// and providers before the client is built.
+    pub fn builder() -> ClientBuilder<BaseConfig> {
+        <ClientBuilder<BaseConfig>>::new()
+    }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl<C> Client<C>
+where
+    C: ClientConfig + Clone,
+{
+    pub(crate) fn new(
+        config: C,
+        signer: Option<SignatureSecretKey>,
+        signing_identity: Option<(SigningIdentity, CipherSuite)>,
+        version: ProtocolVersion,
+    ) -> Self {
+        Client {
+            config,
+            signer,
+            signing_identity,
+            version,
+        }
+    }
+
+    /// Returns a [`ClientBuilder`] initialized with a copy of this client's
+    /// configuration, signer, signing identity and protocol version.
+    #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn to_builder(&self) -> ClientBuilder<MakeConfig<C>> {
+        ClientBuilder::from_config(recreate_config(
+            self.config.clone(),
+            self.signer.clone(),
+            self.signing_identity.clone(),
+            self.version,
+        ))
+    }
+
+    /// Creates a new key package message that can be used to add this
+    /// client to a [Group](crate::group::Group). Each call to this function
+    /// will produce a unique value that is signed by this client's signing
+    /// identity.
+    ///
+    /// The secret keys for the resulting key package message will be stored in
+    /// the [KeyPackageStorage](crate::KeyPackageStorage)
+    /// that was used to configure the client and will
+    /// automatically be erased when this key package is used to
+    /// [join a group](Client::join_group).
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MlsError::SignerNotFound`] if no signing identity or signer is
+    /// configured, [`MlsError::UnsupportedCipherSuite`] if the crypto provider
+    /// does not support the identity's cipher suite,
+    /// [`MlsError::KeyPackageRepoError`] if storing the key package fails, or
+    /// any error raised while generating the key package.
+    ///
+    /// # Warning
+    ///
+    /// A key package message may only be used once.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn generate_key_package_message(&self) -> Result<MlsMessage, MlsError> {
+        Ok(self.generate_key_package().await?.key_package_message())
+    }
+
+    /// Generates a key package for this client's signing identity and stores
+    /// its secret data in the configured key package repository.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn generate_key_package(&self) -> Result<KeyPackageGeneration, MlsError> {
+        let (signing_identity, cipher_suite) = self.signing_identity()?;
+
+        let cipher_suite_provider = self
+            .config
+            .crypto_provider()
+            .cipher_suite_provider(cipher_suite)
+            .ok_or(MlsError::UnsupportedCipherSuite(cipher_suite))?;
+
+        let key_package_generator = KeyPackageGenerator {
+            protocol_version: self.version,
+            cipher_suite_provider: &cipher_suite_provider,
+            signing_key: self.signer()?,
+            signing_identity,
+            identity_provider: &self.config.identity_provider(),
+        };
+
+        let key_pkg_gen = key_package_generator
+            .generate(
+                self.config.lifetime(),
+                self.config.capabilities(),
+                self.config.key_package_extensions(),
+                self.config.leaf_node_extensions(),
+            )
+            .await?;
+
+        let (id, key_package_data) = key_pkg_gen.to_storage()?;
+
+        self.config
+            .key_package_repo()
+            .insert(id, key_package_data)
+            .await
+            .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;
+
+        Ok(key_pkg_gen)
+    }
+
+    /// Create a group with a specific group_id.
+    ///
+    /// This function behaves the same way as
+    /// [create_group](Client::create_group) except that it
+    /// specifies a specific unique group identifier to be used.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MlsError::SignerNotFound`] if no signing identity or signer is
+    /// configured, or any error raised while creating the group.
+    ///
+    /// # Warning
+    ///
+    /// It is recommended to use [create_group](Client::create_group)
+    /// instead of this function because it guarantees that group_id values
+    /// are globally unique.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn create_group_with_id(
+        &self,
+        group_id: Vec<u8>,
+        group_context_extensions: ExtensionList,
+    ) -> Result<Group<C>, MlsError> {
+        let (signing_identity, cipher_suite) = self.signing_identity()?;
+
+        Group::new(
+            self.config.clone(),
+            Some(group_id),
+            cipher_suite,
+            self.version,
+            signing_identity.clone(),
+            group_context_extensions,
+            self.signer()?.clone(),
+        )
+        .await
+    }
+
+    /// Create a MLS group.
+    ///
+    /// The group uses this client's protocol version and the cipher suite
+    /// associated with its signing identity. That cipher suite must be
+    /// supported by the [CipherSuiteProvider](crate::CipherSuiteProvider)
+    /// that was used to build the client.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MlsError::SignerNotFound`] if no signing identity or signer is
+    /// configured, or any error raised while creating the group.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn create_group(
+        &self,
+        group_context_extensions: ExtensionList,
+    ) -> Result<Group<C>, MlsError> {
+        let (signing_identity, cipher_suite) = self.signing_identity()?;
+
+        Group::new(
+            self.config.clone(),
+            None,
+            cipher_suite,
+            self.version,
+            signing_identity.clone(),
+            group_context_extensions,
+            self.signer()?.clone(),
+        )
+        .await
+    }
+
+    /// Join a MLS group via a welcome message created by a
+    /// [Commit](crate::group::CommitOutput).
+    ///
+    /// `tree_data` is required to be provided out of band if the client that
+    /// created `welcome_message` did not use the `ratchet_tree_extension`
+    /// according to [`MlsRules::commit_options`](`crate::MlsRules::commit_options`)
+    /// at the time the welcome message was created. `tree_data` can
+    /// be exported from a group using the
+    /// [export tree function](crate::group::Group::export_tree).
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MlsError::SignerNotFound`] if no signer is configured, or any
+    /// error raised while processing the welcome message.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn join_group(
+        &self,
+        tree_data: Option<ExportedTree<'_>>,
+        welcome_message: &MlsMessage,
+    ) -> Result<(Group<C>, NewMemberInfo), MlsError> {
+        Group::join(
+            welcome_message,
+            tree_data,
+            self.config.clone(),
+            self.signer()?.clone(),
+        )
+        .await
+    }
+
+    /// 0-RTT add to an existing [group](crate::group::Group)
+    ///
+    /// External commits allow for immediate entry into a
+    /// [group](crate::group::Group), even if all of the group members
+    /// are currently offline and unable to process messages. Sending an
+    /// external commit is only allowed for groups that have provided
+    /// a public `group_info_msg` containing an
+    /// [ExternalPubExt](crate::extension::ExternalPubExt), which can be
+    /// generated by an existing group member using the
+    /// [group_info_message](crate::group::Group::group_info_message)
+    /// function.
+    ///
+    /// Returns the joined group together with the commit message to be
+    /// delivered to the existing members.
+    ///
+    /// This is a shorthand for
+    /// [`external_commit_builder`](Client::external_commit_builder) with no
+    /// extra options. Use the builder directly to, for example, inject
+    /// external PSKs ([`ExternalCommitBuilder::with_external_psk`]) or remove
+    /// an existing member ([`ExternalCommitBuilder::with_removal`]). A removed
+    /// member's identity must be a
+    /// [valid successor](crate::IdentityProvider::valid_successor) of this
+    /// client's signing identity as defined by the
+    /// [IdentityProvider](crate::IdentityProvider) this client was configured
+    /// with.
+    ///
+    /// # Warning
+    ///
+    /// Only one external commit can be performed against a given group info.
+    /// There may also be security trade-offs to this approach.
+    ///
+    // TODO: Add a comment about forward secrecy and a pointer to the future
+    // book chapter on this topic
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn commit_external(
+        &self,
+        group_info_msg: MlsMessage,
+    ) -> Result<(Group<C>, MlsMessage), MlsError> {
+        self.external_commit_builder()?
+            .build(group_info_msg)
+            .await
+    }
+
+    /// Creates an [`ExternalCommitBuilder`] that joins a group via an external
+    /// commit using this client's signer and signing identity.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MlsError::SignerNotFound`] if no signer or signing identity is
+    /// configured.
+    pub fn external_commit_builder(&self) -> Result<ExternalCommitBuilder<C>, MlsError> {
+        Ok(ExternalCommitBuilder::new(
+            self.signer()?.clone(),
+            self.signing_identity()?.0.clone(),
+            self.config.clone(),
+        ))
+    }
+
+    /// Load an existing group state into this client using the
+    /// [GroupStateStorage](crate::GroupStateStorage) that
+    /// this client was configured to use.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MlsError::GroupStorageError`] if the storage lookup fails,
+    /// [`MlsError::GroupNotFound`] if no state is stored for `group_id`, or
+    /// [`MlsError::SerializationError`] if the stored snapshot cannot be
+    /// decoded.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    #[inline(never)]
+    pub async fn load_group(&self, group_id: &[u8]) -> Result<Group<C>, MlsError> {
+        let snapshot = self
+            .config
+            .group_state_storage()
+            .state(group_id)
+            .await
+            .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
+            .ok_or(MlsError::GroupNotFound)?;
+
+        let snapshot = Snapshot::mls_decode(&mut &*snapshot)?;
+
+        Group::from_snapshot(self.config.clone(), snapshot).await
+    }
+
+    /// Request to join an existing [group](crate::group::Group).
+    ///
+    /// An existing group member will need to perform a
+    /// [commit](crate::Group::commit) to complete the add and the resulting
+    /// welcome message can be used by [join_group](Client::join_group).
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MlsError::UnsupportedProtocolVersion`] if the group's protocol
+    /// version is not supported or differs from this client's version,
+    /// [`MlsError::UnexpectedMessageType`] if `group_info` is not a group info
+    /// message, and [`MlsError::UnsupportedCipherSuite`] if the group's cipher
+    /// suite is unsupported or differs from this client's signing identity's
+    /// cipher suite.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn external_add_proposal(
+        &self,
+        group_info: &MlsMessage,
+        tree_data: Option<crate::group::ExportedTree<'_>>,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let protocol_version = group_info.version;
+
+        // The key package inside the proposal is generated with `self.version`,
+        // so the group must use that same version, and it must be supported.
+        if !self.config.version_supported(protocol_version) || protocol_version != self.version {
+            return Err(MlsError::UnsupportedProtocolVersion(protocol_version));
+        }
+
+        let group_info = group_info
+            .as_group_info()
+            .ok_or(MlsError::UnexpectedMessageType)?;
+
+        let cipher_suite = group_info.group_context.cipher_suite;
+
+        let cipher_suite_provider = self
+            .config
+            .crypto_provider()
+            .cipher_suite_provider(cipher_suite)
+            .ok_or(MlsError::UnsupportedCipherSuite(cipher_suite))?;
+
+        crate::group::validate_group_info_joiner(
+            protocol_version,
+            group_info,
+            tree_data,
+            &self.config.identity_provider(),
+            &cipher_suite_provider,
+        )
+        .await?;
+
+        // Key packages are generated for the cipher suite of this client's
+        // signing identity. Check compatibility before generating, so that an
+        // unusable key package is never written to the key package repository.
+        if self.signing_identity()?.1 != cipher_suite {
+            return Err(MlsError::UnsupportedCipherSuite(cipher_suite));
+        }
+
+        let key_package = self.generate_key_package().await?.key_package;
+
+        let message = AuthenticatedContent::new_signed(
+            &cipher_suite_provider,
+            &group_info.group_context,
+            Sender::NewMemberProposal,
+            Content::Proposal(Box::new(Proposal::Add(Box::new(AddProposal {
+                key_package,
+            })))),
+            self.signer()?,
+            WireFormat::PublicMessage,
+            authenticated_data,
+        )
+        .await?;
+
+        let plaintext = PublicMessage {
+            content: message.content,
+            auth: message.auth,
+            membership_tag: None,
+        };
+
+        Ok(MlsMessage {
+            version: protocol_version,
+            payload: MlsMessagePayload::Plain(plaintext),
+        })
+    }
+
+    /// Returns the configured signing secret key, or
+    /// [`MlsError::SignerNotFound`] if none was configured.
+    fn signer(&self) -> Result<&SignatureSecretKey, MlsError> {
+        self.signer.as_ref().ok_or(MlsError::SignerNotFound)
+    }
+
+    /// Returns the configured signing identity and its cipher suite, or
+    /// [`MlsError::SignerNotFound`] if none was configured.
+    #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn signing_identity(&self) -> Result<(&SigningIdentity, CipherSuite), MlsError> {
+        self.signing_identity
+            .as_ref()
+            .map(|(id, cs)| (id, *cs))
+            .ok_or(MlsError::SignerNotFound)
+    }
+
+    /// Returns key package extensions used by this client
+    pub fn key_package_extensions(&self) -> ExtensionList {
+        self.config.key_package_extensions()
+    }
+
+    /// The [KeyPackageStorage] that this client was configured to use.
+    #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn key_package_store(&self) -> <C as ClientConfig>::KeyPackageRepository {
+        self.config.key_package_repo()
+    }
+
+    /// The [PreSharedKeyStorage](crate::PreSharedKeyStorage) that
+    /// this client was configured to use.
+    #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn secret_store(&self) -> <C as ClientConfig>::PskStore {
+        self.config.secret_store()
+    }
+
+    /// The [GroupStateStorage] that this client was configured to use.
+    #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn group_state_storage(&self) -> <C as ClientConfig>::GroupStateStorage {
+        self.config.group_state_storage()
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use super::*;
+    use crate::identity::test_utils::get_test_signing_identity;
+
+    pub use crate::client_builder::test_utils::{TestClientBuilder, TestClientConfig};
+
+    /// Protocol version used by default in client tests.
+    pub const TEST_PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion::MLS_10;
+    /// Cipher suite used by default in client tests.
+    pub const TEST_CIPHER_SUITE: CipherSuite = CipherSuite::P256_AES128;
+    /// Proposal type value used by tests that exercise custom proposals.
+    pub const TEST_CUSTOM_PROPOSAL_TYPE: ProposalType = ProposalType::new(65001);
+
+    /// Builds a test client for `identity` with the default test
+    /// configuration and generates one key package message for it.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn test_client_with_key_pkg(
+        protocol_version: ProtocolVersion,
+        cipher_suite: CipherSuite,
+        identity: &str,
+    ) -> (Client<TestClientConfig>, MlsMessage) {
+        let keep_defaults = |_: &mut TestClientConfig| {};
+
+        test_client_with_key_pkg_custom(protocol_version, cipher_suite, identity, keep_defaults)
+            .await
+    }
+
+    /// Builds a test client for `identity`, lets `config` adjust the client's
+    /// [`TestClientConfig`], then generates one key package message.
+    ///
+    /// Panics if key package generation fails.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn test_client_with_key_pkg_custom<F>(
+        protocol_version: ProtocolVersion,
+        cipher_suite: CipherSuite,
+        identity: &str,
+        mut config: F,
+    ) -> (Client<TestClientConfig>, MlsMessage)
+    where
+        F: FnMut(&mut TestClientConfig),
+    {
+        let (signing_identity, signer) =
+            get_test_signing_identity(cipher_suite, identity.as_bytes()).await;
+
+        let mut client = TestClientBuilder::new_for_test()
+            .used_protocol_version(protocol_version)
+            .signing_identity(signing_identity, signer, cipher_suite)
+            .build();
+
+        // Customization is applied to the built client's config, before the
+        // key package is generated.
+        config(&mut client.config);
+
+        let key_pkg_msg = client
+            .generate_key_package_message()
+            .await
+            .unwrap();
+
+        (client, key_pkg_msg)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::test_utils::*;
+
+ use super::*;
+ use crate::{
+ crypto::test_utils::TestCryptoProvider,
+ identity::test_utils::{get_test_basic_credential, get_test_signing_identity},
+ tree_kem::leaf_node::LeafNodeSource,
+ };
+ use assert_matches::assert_matches;
+
+ use crate::{
+ group::{
+ message_processor::ProposalMessageDescription,
+ proposal::Proposal,
+ test_utils::{test_group, test_group_custom_config},
+ ReceivedMessage,
+ },
+ psk::{ExternalPskId, PreSharedKey},
+ };
+
+ use alloc::vec;
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_keygen() {
+        // This is meant to test the inputs to the internal key package generator
+        // See KeyPackageGenerator tests for key generation specific tests
+        for (protocol_version, cipher_suite) in ProtocolVersion::all().flat_map(|p| {
+            TestCryptoProvider::all_supported_cipher_suites()
+                .into_iter()
+                .map(move |cs| (p, cs))
+        }) {
+            let (identity, secret_key) = get_test_signing_identity(cipher_suite, b"foo").await;
+
+            // Configure the loop's protocol version explicitly; otherwise the
+            // version assertion below only holds for the builder's default.
+            let client = TestClientBuilder::new_for_test()
+                .used_protocol_version(protocol_version)
+                .signing_identity(identity.clone(), secret_key, cipher_suite)
+                .build();
+
+            // TODO: Tests around extensions
+            let key_package = client.generate_key_package_message().await.unwrap();
+
+            assert_eq!(key_package.version, protocol_version);
+
+            let key_package = key_package.into_key_package().unwrap();
+
+            assert_eq!(key_package.cipher_suite, cipher_suite);
+
+            assert_eq!(
+                &key_package.leaf_node.signing_identity.credential,
+                &get_test_basic_credential(b"foo".to_vec())
+            );
+
+            assert_eq!(key_package.leaf_node.signing_identity, identity);
+
+            let capabilities = key_package.leaf_node.ungreased_capabilities();
+            assert_eq!(capabilities, client.config.capabilities());
+
+            // Only the length of the validity window is compared; the absolute
+            // timestamps are presumably derived from the current time.
+            let client_lifetime = client.config.lifetime();
+            assert_matches!(
+                key_package.leaf_node.leaf_node_source,
+                LeafNodeSource::KeyPackage(lifetime)
+                    if (lifetime.not_after - lifetime.not_before)
+                        == (client_lifetime.not_after - client_lifetime.not_before)
+            );
+        }
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_add_proposal_adds_to_group() {
+        let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let (bob_identity, bob_signer) =
+            get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
+
+        let bob = TestClientBuilder::new_for_test()
+            .signing_identity(bob_identity.clone(), bob_signer, TEST_CIPHER_SUITE)
+            .build();
+
+        // Bob requests to be added, based on Alice's published group info.
+        let alice_group_info = alice_group.group.group_info_message(true).await.unwrap();
+
+        let add_request = bob
+            .external_add_proposal(&alice_group_info, None, vec![])
+            .await
+            .unwrap();
+
+        let received = alice_group
+            .group
+            .process_incoming_message(add_request)
+            .await
+            .unwrap();
+
+        assert_matches!(
+            received,
+            ReceivedMessage::Proposal(ProposalMessageDescription {
+                proposal: Proposal::Add(p), ..}
+            ) if p.key_package.leaf_node.signing_identity == bob_identity
+        );
+
+        // Commit without additional proposals; the received add proposal is
+        // expected to be picked up by this commit.
+        alice_group.group.commit(vec![]).await.unwrap();
+        alice_group.group.apply_pending_commit().await.unwrap();
+
+        let bob_is_member = alice_group
+            .group
+            .roster()
+            .members_iter()
+            .any(|member| member.signing_identity == bob_identity);
+
+        assert!(bob_is_member)
+    }
+
+    /// Adds a client to Alice's group (which already contains Bob) through an
+    /// external commit and checks that all parties end up in sync.
+    ///
+    /// * `do_remove` - the joining client reuses Bob's identity and removes the
+    ///   member at leaf index 1 (Bob) as part of the external commit.
+    /// * `with_psk` - the external commit additionally injects an external PSK
+    ///   that Alice, Bob and the joining client all know.
+    #[cfg(feature = "psk")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn join_via_external_commit(do_remove: bool, with_psk: bool) -> Result<(), MlsError> {
+        // An external commit cannot be the first commit in a group as it requires
+        // interim_transcript_hash to be computed from the confirmed_transcript_hash and
+        // confirmation_tag, which is not the case for the initial interim_transcript_hash.
+
+        let psk = PreSharedKey::from(b"psk".to_vec());
+        let psk_id = ExternalPskId::new(b"psk id".to_vec());
+
+        let mut alice_group =
+            test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |c| {
+                c.psk(psk_id.clone(), psk.clone())
+            })
+            .await;
+
+        let (mut bob_group, _) = alice_group
+            .join_with_custom_config("bob", false, |c| {
+                c.0.psk_store.insert(psk_id.clone(), psk.clone());
+            })
+            .await
+            .unwrap();
+
+        let group_info_msg = alice_group
+            .group
+            .group_info_message_allowing_ext_commit(true)
+            .await
+            .unwrap();
+
+        let new_client_id = if do_remove { "bob" } else { "charlie" };
+
+        let (new_client_identity, secret_key) =
+            get_test_signing_identity(TEST_CIPHER_SUITE, new_client_id.as_bytes()).await;
+
+        let new_client = TestClientBuilder::new_for_test()
+            .psk(psk_id.clone(), psk)
+            .signing_identity(new_client_identity.clone(), secret_key, TEST_CIPHER_SUITE)
+            .build();
+
+        let mut builder = new_client.external_commit_builder().unwrap();
+
+        if do_remove {
+            builder = builder.with_removal(1);
+        }
+
+        if with_psk {
+            builder = builder.with_external_psk(psk_id);
+        }
+
+        let (new_group, external_commit) = builder.build(group_info_msg).await?;
+
+        let num_members = if do_remove { 2 } else { 3 };
+
+        assert_eq!(new_group.roster().members_iter().count(), num_members);
+
+        let _ = alice_group
+            .group
+            .process_incoming_message(external_commit.clone())
+            .await
+            .unwrap();
+
+        let bob_current_epoch = bob_group.group.current_epoch();
+
+        let message = bob_group
+            .group
+            .process_incoming_message(external_commit)
+            .await
+            .unwrap();
+
+        assert_eq!(alice_group.group.roster().members_iter().count(), num_members);
+
+        if !do_remove {
+            assert_eq!(bob_group.group.roster().members_iter().count(), num_members);
+        } else {
+            // Bob was removed so his epoch must stay the same
+            assert_eq!(bob_group.group.current_epoch(), bob_current_epoch);
+
+            #[cfg(feature = "state_update")]
+            assert_matches!(message, ReceivedMessage::Commit(desc) if !desc.state_update.active);
+
+            #[cfg(not(feature = "state_update"))]
+            assert_matches!(message, ReceivedMessage::Commit(_));
+        }
+
+        // Comparing epoch authenticators is sufficient to check that members are in sync.
+        assert_eq!(
+            alice_group.group.epoch_authenticator().unwrap(),
+            new_group.epoch_authenticator().unwrap()
+        );
+
+        Ok(())
+    }
+
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_external_commit() {
+        // Cases are (do_remove, with_psk), run in order:
+        //   (false, false) -> a new member can join
+        //   (true, false)  -> a new member can remove an old copy of themselves
+        //   (false, true)  -> a new member can inject a PSK
+        //   (true, true)   -> everything works together
+        let cases = [(false, false), (true, false), (false, true), (true, true)];
+
+        for &(do_remove, with_psk) in cases.iter() {
+            join_via_external_commit(do_remove, with_psk).await.unwrap();
+        }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn creating_an_external_commit_requires_a_group_info_message() {
+        // A key package message is handed to `commit_external` instead of a group info message;
+        // it must be rejected with `UnexpectedMessageType`.
+        let (identity, signer) = get_test_signing_identity(TEST_CIPHER_SUITE, b"alice").await;
+
+        let alice = TestClientBuilder::new_for_test()
+            .signing_identity(identity, signer, TEST_CIPHER_SUITE)
+            .build();
+
+        let key_package_msg = alice.generate_key_package_message().await.unwrap();
+        let outcome = alice.commit_external(key_package_msg).await;
+
+        // Drop the success payload so only the error needs to be matched (and printed).
+        assert_matches!(outcome.map(|_| ()), Err(MlsError::UnexpectedMessageType));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_commit_with_invalid_group_info_fails() {
+        // Two unrelated groups: the external commit is built from Bob's group info but delivered
+        // to Alice's group, which must reject it.
+        let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let mut bob_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        // Move Bob's group forward by one commit before exporting its group info.
+        bob_group.group.commit(vec![]).await.unwrap();
+        bob_group.group.apply_pending_commit().await.unwrap();
+
+        let foreign_group_info = bob_group
+            .group
+            .group_info_message_allowing_ext_commit(true)
+            .await
+            .unwrap();
+
+        let (carol_identity, carol_signer) =
+            get_test_signing_identity(TEST_CIPHER_SUITE, b"carol").await;
+
+        let carol = TestClientBuilder::new_for_test()
+            .signing_identity(carol_identity, carol_signer, TEST_CIPHER_SUITE)
+            .build();
+
+        let commit_builder = carol.external_commit_builder().unwrap();
+        let (_, external_commit) = commit_builder.build(foreign_group_info).await.unwrap();
+
+        let outcome = alice_group
+            .group
+            .process_incoming_message(external_commit)
+            .await;
+
+        assert!(outcome.is_err());
+    }
+
+    #[test]
+    fn builder_can_be_obtained_from_client_to_edit_properties_for_new_client() {
+        // `to_builder` must carry over Alice's extension types, so Bob ends up with the inherited
+        // type followed by the one added afterwards.
+        let alice = TestClientBuilder::new_for_test()
+            .extension_type(33.into())
+            .build();
+
+        let bob = alice
+            .to_builder()
+            .extension_type(34.into())
+            .build();
+
+        let supported = bob.config.supported_extensions();
+        assert_eq!(supported, [33, 34].map(Into::into));
+    }
+}
diff --git a/src/client_builder.rs b/src/client_builder.rs
new file mode 100644
index 0000000..186c436
--- /dev/null
+++ b/src/client_builder.rs
@@ -0,0 +1,1029 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+//! Definitions to build a [`Client`].
+//!
+//! See [`ClientBuilder`].
+
+use crate::{
+ cipher_suite::CipherSuite,
+ client::Client,
+ client_config::ClientConfig,
+ extension::{ExtensionType, MlsExtension},
+ group::{
+ mls_rules::{DefaultMlsRules, MlsRules},
+ proposal::ProposalType,
+ },
+ identity::CredentialType,
+ identity::SigningIdentity,
+ protocol_version::ProtocolVersion,
+ psk::{ExternalPskId, PreSharedKey},
+ storage_provider::in_memory::{
+ InMemoryGroupStateStorage, InMemoryKeyPackageStorage, InMemoryPreSharedKeyStorage,
+ },
+ tree_kem::{Capabilities, Lifetime},
+ Sealed,
+};
+
+#[cfg(feature = "std")]
+use crate::time::MlsTime;
+
+use alloc::vec::Vec;
+
+#[cfg(feature = "sqlite")]
+use mls_rs_provider_sqlite::{
+ SqLiteDataStorageEngine, SqLiteDataStorageError,
+ {
+ connection_strategy::ConnectionStrategy,
+ storage::{SqLiteGroupStateStorage, SqLiteKeyPackageStorage, SqLitePreSharedKeyStorage},
+ },
+};
+
+#[cfg(feature = "private_message")]
+pub use crate::group::padding::PaddingMode;
+
+/// Base client configuration type when instantiating `ClientBuilder`
+///
+/// Uses in-memory key package, PSK and group state storage together with [`DefaultMlsRules`].
+/// The identity provider and crypto provider are left [`Missing`] and have to be supplied with
+/// [`ClientBuilder::identity_provider`] and [`ClientBuilder::crypto_provider`].
+pub type BaseConfig = Config<
+    InMemoryKeyPackageStorage,
+    InMemoryPreSharedKeyStorage,
+    InMemoryGroupStateStorage,
+    Missing,
+    DefaultMlsRules,
+    Missing,
+>;
+
+/// Client configuration type with in-memory storage and no MLS rules.
+///
+/// Like [`BaseConfig`], this uses in-memory key package, PSK and group state storage, but the MLS
+/// rules are left [`Missing`] as well (in addition to the identity provider and crypto provider),
+/// so they have to be supplied with [`ClientBuilder::mls_rules`] before building a client.
+pub type BaseInMemoryConfig = Config<
+    InMemoryKeyPackageStorage,
+    InMemoryPreSharedKeyStorage,
+    InMemoryGroupStateStorage,
+    Missing,
+    Missing,
+    Missing,
+>;
+
+/// Client configuration type with every service left [`Missing`], as produced by
+/// [`ClientBuilder::new_empty`].
+pub type EmptyConfig = Config<Missing, Missing, Missing, Missing, Missing, Missing>;
+
+/// Base client configuration that is backed by SQLite storage.
+///
+/// Key packages, PSKs and group state go through the SQLite storage providers and
+/// [`DefaultMlsRules`] are used; the identity provider and crypto provider are left [`Missing`].
+/// This is the configuration produced by [`ClientBuilder::new_sqlite`].
+#[cfg(feature = "sqlite")]
+pub type BaseSqlConfig = Config<
+    SqLiteKeyPackageStorage,
+    SqLitePreSharedKeyStorage,
+    SqLiteGroupStateStorage,
+    Missing,
+    DefaultMlsRules,
+    Missing,
+>;
+
+/// Builder for [`Client`]
+///
+/// This is returned by [`Client::builder`] and allows tweaking the settings the `Client` will use. At a
+/// minimum, the builder must be told the [`CryptoProvider`] and [`IdentityProvider`] to use. Other
+/// settings have default values. This means that the following
+/// methods must be called before [`ClientBuilder::build`]:
+///
+/// - To specify the [`CryptoProvider`]: [`ClientBuilder::crypto_provider`]
+/// - To specify the [`IdentityProvider`]: [`ClientBuilder::identity_provider`]
+///
+/// # Example
+///
+/// ```
+/// use mls_rs::{
+/// Client,
+/// identity::{SigningIdentity, basic::{BasicIdentityProvider, BasicCredential}},
+/// CipherSuite,
+/// };
+///
+/// use mls_rs_crypto_openssl::OpensslCryptoProvider;
+///
+/// // Replace by code to load the certificate and secret key
+/// let secret_key = b"never hard-code secrets".to_vec().into();
+/// let public_key = b"test invalid public key".to_vec().into();
+/// let basic_identity = BasicCredential::new(b"name".to_vec());
+/// let signing_identity = SigningIdentity::new(basic_identity.into_credential(), public_key);
+///
+///
+/// let _client = Client::builder()
+/// .crypto_provider(OpensslCryptoProvider::default())
+/// .identity_provider(BasicIdentityProvider::new())
+/// .signing_identity(signing_identity, secret_key, CipherSuite::CURVE25519_AES128)
+/// .build();
+/// ```
+///
+/// # Spelling out a `Client` type
+///
+/// There are two main ways to spell out a `Client` type if needed (e.g. function return type).
+///
+/// The first option uses `impl MlsConfig`:
+/// ```
+/// use mls_rs::{
+/// Client,
+/// client_builder::MlsConfig,
+/// identity::{SigningIdentity, basic::{BasicIdentityProvider, BasicCredential}},
+/// CipherSuite,
+/// };
+///
+/// use mls_rs_crypto_openssl::OpensslCryptoProvider;
+///
+/// fn make_client() -> Client<impl MlsConfig> {
+/// // Replace by code to load the certificate and secret key
+/// let secret_key = b"never hard-code secrets".to_vec().into();
+/// let public_key = b"test invalid public key".to_vec().into();
+/// let basic_identity = BasicCredential::new(b"name".to_vec());
+/// let signing_identity = SigningIdentity::new(basic_identity.into_credential(), public_key);
+///
+/// Client::builder()
+/// .crypto_provider(OpensslCryptoProvider::default())
+/// .identity_provider(BasicIdentityProvider::new())
+/// .signing_identity(signing_identity, secret_key, CipherSuite::CURVE25519_AES128)
+/// .build()
+/// }
+///```
+///
+/// The second option is more verbose and consists of writing the full `Client` type:
+/// ```
+/// use mls_rs::{
+/// Client,
+/// client_builder::{BaseConfig, WithIdentityProvider, WithCryptoProvider},
+/// identity::{SigningIdentity, basic::{BasicIdentityProvider, BasicCredential}},
+/// CipherSuite,
+/// };
+///
+/// use mls_rs_crypto_openssl::OpensslCryptoProvider;
+///
+/// type MlsClient = Client<
+/// WithIdentityProvider<
+/// BasicIdentityProvider,
+/// WithCryptoProvider<OpensslCryptoProvider, BaseConfig>,
+/// >,
+/// >;
+///
+/// fn make_client_2() -> MlsClient {
+/// // Replace by code to load the certificate and secret key
+/// let secret_key = b"never hard-code secrets".to_vec().into();
+/// let public_key = b"test invalid public key".to_vec().into();
+/// let basic_identity = BasicCredential::new(b"name".to_vec());
+/// let signing_identity = SigningIdentity::new(basic_identity.into_credential(), public_key);
+///
+/// Client::builder()
+/// .crypto_provider(OpensslCryptoProvider::default())
+/// .identity_provider(BasicIdentityProvider::new())
+/// .signing_identity(signing_identity, secret_key, CipherSuite::CURVE25519_AES128)
+/// .build()
+/// }
+///
+/// ```
+#[derive(Debug)]
+// The wrapped value is the configuration accumulated so far. Every configuration method consumes
+// the builder and returns a new one whose configuration type may differ (e.g. after
+// `key_package_repo` or `crypto_provider`).
+pub struct ClientBuilder<C>(C);
+
+impl Default for ClientBuilder<BaseConfig> {
+    /// Same as [`ClientBuilder::new`].
+    fn default() -> Self {
+        ClientBuilder::<BaseConfig>::new()
+    }
+}
+
+impl<C> ClientBuilder<C> {
+    /// Wrap an already assembled configuration `c` in a builder.
+    pub(crate) fn from_config(c: C) -> Self {
+        ClientBuilder(c)
+    }
+}
+
+impl ClientBuilder<BaseConfig> {
+    /// Create a new client builder with default in-memory providers
+    ///
+    /// Storage uses the in-memory implementations and MLS rules default to [`DefaultMlsRules`].
+    /// No signer or signing identity is set and the protocol version is MLS 1.0. The identity
+    /// provider and crypto provider still have to be configured.
+    pub fn new() -> Self {
+        let inner = ConfigInner {
+            settings: Settings::default(),
+            key_package_repo: InMemoryKeyPackageStorage::default(),
+            psk_store: InMemoryPreSharedKeyStorage::default(),
+            group_state_storage: InMemoryGroupStateStorage::default(),
+            identity_provider: Missing,
+            mls_rules: DefaultMlsRules::new(),
+            crypto_provider: Missing,
+            signer: None,
+            signing_identity: None,
+            version: ProtocolVersion::MLS_10,
+        };
+
+        ClientBuilder(Config(inner))
+    }
+}
+
+impl ClientBuilder<EmptyConfig> {
+    /// Create a new client builder with no service configured.
+    ///
+    /// Key package storage, PSK storage, group state storage, identity provider, MLS rules and
+    /// crypto provider all start out [`Missing`] and must be set through the corresponding
+    /// builder methods. No signer or signing identity is set and the protocol version is MLS 1.0.
+    pub fn new_empty() -> Self {
+        let settings = Settings::default();
+
+        ClientBuilder(Config(ConfigInner {
+            settings,
+            key_package_repo: Missing,
+            psk_store: Missing,
+            group_state_storage: Missing,
+            identity_provider: Missing,
+            mls_rules: Missing,
+            crypto_provider: Missing,
+            signer: None,
+            signing_identity: None,
+            version: ProtocolVersion::MLS_10,
+        }))
+    }
+}
+
+#[cfg(feature = "sqlite")]
+impl ClientBuilder<BaseSqlConfig> {
+    /// Create a new client builder with SQLite storage providers.
+    ///
+    /// Key package, PSK and group state storage are all obtained from `storage`. MLS rules default
+    /// to [`DefaultMlsRules`] and the protocol version to MLS 1.0.
+    ///
+    /// # Errors
+    ///
+    /// Returns the first [`SqLiteDataStorageError`] raised while creating one of the storages.
+    pub fn new_sqlite<CS: ConnectionStrategy>(
+        storage: SqLiteDataStorageEngine<CS>,
+    ) -> Result<Self, SqLiteDataStorageError> {
+        // Storages are created in the same order as before: key packages, PSKs, group state.
+        let key_package_repo = storage.key_package_storage()?;
+        let psk_store = storage.pre_shared_key_storage()?;
+        let group_state_storage = storage.group_state_storage()?;
+
+        Ok(ClientBuilder(Config(ConfigInner {
+            settings: Settings::default(),
+            key_package_repo,
+            psk_store,
+            group_state_storage,
+            identity_provider: Missing,
+            mls_rules: DefaultMlsRules::new(),
+            crypto_provider: Missing,
+            signer: None,
+            signing_identity: None,
+            version: ProtocolVersion::MLS_10,
+        })))
+    }
+}
+
+impl<C: IntoConfig> ClientBuilder<C> {
+    /// Add an extension type to the list of extension types supported by the client.
+    pub fn extension_type(self, type_: ExtensionType) -> ClientBuilder<IntoConfigOutput<C>> {
+        self.extension_types(Some(type_))
+    }
+
+    /// Add multiple extension types to the list of extension types supported by the client.
+    pub fn extension_types<I>(self, types: I) -> ClientBuilder<IntoConfigOutput<C>>
+    where
+        I: IntoIterator<Item = ExtensionType>,
+    {
+        let Config(mut inner) = self.0.into_config();
+        inner.settings.extension_types.extend(types);
+        ClientBuilder(Config(inner))
+    }
+
+    /// Add a custom proposal type to the list of proposal types supported by the client.
+    pub fn custom_proposal_type(self, type_: ProposalType) -> ClientBuilder<IntoConfigOutput<C>> {
+        self.custom_proposal_types(Some(type_))
+    }
+
+    /// Add multiple custom proposal types to the list of proposal types supported by the client.
+    pub fn custom_proposal_types<I>(self, types: I) -> ClientBuilder<IntoConfigOutput<C>>
+    where
+        I: IntoIterator<Item = ProposalType>,
+    {
+        let Config(mut inner) = self.0.into_config();
+        inner.settings.custom_proposal_types.extend(types);
+        ClientBuilder(Config(inner))
+    }
+
+    /// Add a protocol version to the list of protocol versions supported by the client.
+    ///
+    /// If no protocol version is explicitly added, the client will support all protocol versions
+    /// supported by this crate.
+    pub fn protocol_version(self, version: ProtocolVersion) -> ClientBuilder<IntoConfigOutput<C>> {
+        self.protocol_versions(Some(version))
+    }
+
+    /// Add multiple protocol versions to the list of protocol versions supported by the client.
+    ///
+    /// If no protocol version is explicitly added, the client will support all protocol versions
+    /// supported by this crate.
+    pub fn protocol_versions<I>(self, versions: I) -> ClientBuilder<IntoConfigOutput<C>>
+    where
+        I: IntoIterator<Item = ProtocolVersion>,
+    {
+        let Config(mut inner) = self.0.into_config();
+        inner.settings.protocol_versions.extend(versions);
+        ClientBuilder(Config(inner))
+    }
+
+    /// Add a key package extension to the client's list of key package extensions.
+    ///
+    /// # Errors
+    ///
+    /// Returns the [`ExtensionError`] produced by `ExtensionList::set_from` for `extension`.
+    pub fn key_package_extension<T>(
+        self,
+        extension: T,
+    ) -> Result<ClientBuilder<IntoConfigOutput<C>>, ExtensionError>
+    where
+        T: MlsExtension,
+        Self: Sized,
+    {
+        let Config(mut inner) = self.0.into_config();
+        inner.settings.key_package_extensions.set_from(extension)?;
+        Ok(ClientBuilder(Config(inner)))
+    }
+
+    /// Append multiple key package extensions to the client's list of key package extensions.
+    pub fn key_package_extensions(
+        self,
+        extensions: ExtensionList,
+    ) -> ClientBuilder<IntoConfigOutput<C>> {
+        let Config(mut inner) = self.0.into_config();
+        inner.settings.key_package_extensions.append(extensions);
+        ClientBuilder(Config(inner))
+    }
+
+    /// Add a leaf node extension to the client's list of leaf node extensions.
+    ///
+    /// # Errors
+    ///
+    /// Returns the [`ExtensionError`] produced by `ExtensionList::set_from` for `extension`.
+    pub fn leaf_node_extension<T>(
+        self,
+        extension: T,
+    ) -> Result<ClientBuilder<IntoConfigOutput<C>>, ExtensionError>
+    where
+        T: MlsExtension,
+        Self: Sized,
+    {
+        let Config(mut inner) = self.0.into_config();
+        inner.settings.leaf_node_extensions.set_from(extension)?;
+        Ok(ClientBuilder(Config(inner)))
+    }
+
+    /// Append multiple leaf node extensions to the client's list of leaf node extensions.
+    pub fn leaf_node_extensions(
+        self,
+        extensions: ExtensionList,
+    ) -> ClientBuilder<IntoConfigOutput<C>> {
+        let Config(mut inner) = self.0.into_config();
+        inner.settings.leaf_node_extensions.append(extensions);
+        ClientBuilder(Config(inner))
+    }
+
+    /// Set the lifetime duration in seconds of key packages generated by the client.
+    pub fn key_package_lifetime(self, duration_in_s: u64) -> ClientBuilder<IntoConfigOutput<C>> {
+        let Config(mut inner) = self.0.into_config();
+        inner.settings.lifetime_in_s = duration_in_s;
+        ClientBuilder(Config(inner))
+    }
+
+    /// Set the key package repository to be used by the client.
+    ///
+    /// [`ClientBuilder::new`] starts with an in-memory repository.
+    pub fn key_package_repo<K>(self, key_package_repo: K) -> ClientBuilder<WithKeyPackageRepo<K, C>>
+    where
+        K: KeyPackageStorage,
+    {
+        let Config(prev) = self.0.into_config();
+
+        ClientBuilder(Config(ConfigInner {
+            key_package_repo,
+            settings: prev.settings,
+            psk_store: prev.psk_store,
+            group_state_storage: prev.group_state_storage,
+            identity_provider: prev.identity_provider,
+            mls_rules: prev.mls_rules,
+            crypto_provider: prev.crypto_provider,
+            signer: prev.signer,
+            signing_identity: prev.signing_identity,
+            version: prev.version,
+        }))
+    }
+
+    /// Set the PSK store to be used by the client.
+    ///
+    /// [`ClientBuilder::new`] starts with an in-memory store.
+    pub fn psk_store<P>(self, psk_store: P) -> ClientBuilder<WithPskStore<P, C>>
+    where
+        P: PreSharedKeyStorage,
+    {
+        let Config(prev) = self.0.into_config();
+
+        ClientBuilder(Config(ConfigInner {
+            psk_store,
+            settings: prev.settings,
+            key_package_repo: prev.key_package_repo,
+            group_state_storage: prev.group_state_storage,
+            identity_provider: prev.identity_provider,
+            mls_rules: prev.mls_rules,
+            crypto_provider: prev.crypto_provider,
+            signer: prev.signer,
+            signing_identity: prev.signing_identity,
+            version: prev.version,
+        }))
+    }
+
+    /// Set the group state storage to be used by the client.
+    ///
+    /// [`ClientBuilder::new`] starts with an in-memory storage.
+    pub fn group_state_storage<G>(
+        self,
+        group_state_storage: G,
+    ) -> ClientBuilder<WithGroupStateStorage<G, C>>
+    where
+        G: GroupStateStorage,
+    {
+        let Config(prev) = self.0.into_config();
+
+        ClientBuilder(Config(ConfigInner {
+            group_state_storage,
+            settings: prev.settings,
+            key_package_repo: prev.key_package_repo,
+            psk_store: prev.psk_store,
+            identity_provider: prev.identity_provider,
+            mls_rules: prev.mls_rules,
+            crypto_provider: prev.crypto_provider,
+            signer: prev.signer,
+            signing_identity: prev.signing_identity,
+            version: prev.version,
+        }))
+    }
+
+    /// Set the identity provider to be used by the client.
+    pub fn identity_provider<I>(
+        self,
+        identity_provider: I,
+    ) -> ClientBuilder<WithIdentityProvider<I, C>>
+    where
+        I: IdentityProvider,
+    {
+        let Config(prev) = self.0.into_config();
+
+        ClientBuilder(Config(ConfigInner {
+            identity_provider,
+            settings: prev.settings,
+            key_package_repo: prev.key_package_repo,
+            psk_store: prev.psk_store,
+            group_state_storage: prev.group_state_storage,
+            mls_rules: prev.mls_rules,
+            crypto_provider: prev.crypto_provider,
+            signer: prev.signer,
+            signing_identity: prev.signing_identity,
+            version: prev.version,
+        }))
+    }
+
+    /// Set the crypto provider to be used by the client.
+    pub fn crypto_provider<Cp>(
+        self,
+        crypto_provider: Cp,
+    ) -> ClientBuilder<WithCryptoProvider<Cp, C>>
+    where
+        Cp: CryptoProvider,
+    {
+        let Config(prev) = self.0.into_config();
+
+        ClientBuilder(Config(ConfigInner {
+            crypto_provider,
+            settings: prev.settings,
+            key_package_repo: prev.key_package_repo,
+            psk_store: prev.psk_store,
+            group_state_storage: prev.group_state_storage,
+            identity_provider: prev.identity_provider,
+            mls_rules: prev.mls_rules,
+            signer: prev.signer,
+            signing_identity: prev.signing_identity,
+            version: prev.version,
+        }))
+    }
+
+    /// Set the user-defined proposal rules to be used by the client.
+    ///
+    /// User-defined rules are used when sending and receiving commits before
+    /// enforcing general MLS protocol rules. If the rule set returns an error when
+    /// receiving a commit, the entire commit is considered invalid. If the
+    /// rule set would return an error when sending a commit, individual proposals
+    /// may be filtered out to compensate.
+    pub fn mls_rules<Pr>(self, mls_rules: Pr) -> ClientBuilder<WithMlsRules<Pr, C>>
+    where
+        Pr: MlsRules,
+    {
+        let Config(prev) = self.0.into_config();
+
+        ClientBuilder(Config(ConfigInner {
+            mls_rules,
+            settings: prev.settings,
+            key_package_repo: prev.key_package_repo,
+            psk_store: prev.psk_store,
+            group_state_storage: prev.group_state_storage,
+            identity_provider: prev.identity_provider,
+            crypto_provider: prev.crypto_provider,
+            signer: prev.signer,
+            signing_identity: prev.signing_identity,
+            version: prev.version,
+        }))
+    }
+
+    /// Set the protocol version used by the client. By default, the client uses version MLS 1.0
+    pub fn used_protocol_version(
+        self,
+        version: ProtocolVersion,
+    ) -> ClientBuilder<IntoConfigOutput<C>> {
+        let Config(mut inner) = self.0.into_config();
+        inner.version = version;
+        ClientBuilder(Config(inner))
+    }
+
+    /// Set the signing identity used by the client as well as the matching signer and cipher suite.
+    /// This must be called in order to create groups and key packages.
+    pub fn signing_identity(
+        self,
+        signing_identity: SigningIdentity,
+        signer: SignatureSecretKey,
+        cipher_suite: CipherSuite,
+    ) -> ClientBuilder<IntoConfigOutput<C>> {
+        let Config(mut inner) = self.0.into_config();
+        inner.signer = Some(signer);
+        inner.signing_identity = Some((signing_identity, cipher_suite));
+        ClientBuilder(Config(inner))
+    }
+
+    /// Set the signer used by the client. This must be called in order to join groups.
+    pub fn signer(self, signer: SignatureSecretKey) -> ClientBuilder<IntoConfigOutput<C>> {
+        let Config(mut inner) = self.0.into_config();
+        inner.signer = Some(signer);
+        ClientBuilder(Config(inner))
+    }
+
+    /// Override the `not_before` timestamp stored in the settings. Only compiled for tests or with
+    /// the `test_util` feature.
+    #[cfg(any(test, feature = "test_util"))]
+    pub(crate) fn key_package_not_before(
+        self,
+        key_package_not_before: u64,
+    ) -> ClientBuilder<IntoConfigOutput<C>> {
+        let Config(mut inner) = self.0.into_config();
+        inner.settings.key_package_not_before = Some(key_package_not_before);
+        ClientBuilder(Config(inner))
+    }
+}
+
+impl<C: IntoConfig> ClientBuilder<C>
+where
+    C::KeyPackageRepository: KeyPackageStorage + Clone,
+    C::PskStore: PreSharedKeyStorage + Clone,
+    C::GroupStateStorage: GroupStateStorage + Clone,
+    C::IdentityProvider: IdentityProvider + Clone,
+    C::MlsRules: MlsRules + Clone,
+    C::CryptoProvider: CryptoProvider + Clone,
+{
+    /// Finalize the configuration. When no protocol version was added explicitly, every version
+    /// yielded by `ProtocolVersion::all()` becomes supported.
+    pub(crate) fn build_config(self) -> IntoConfigOutput<C> {
+        let Config(mut inner) = self.0.into_config();
+
+        if inner.settings.protocol_versions.is_empty() {
+            inner.settings.protocol_versions = ProtocolVersion::all().collect();
+        }
+
+        Config(inner)
+    }
+
+    /// Build a client.
+    ///
+    /// The signer and signing identity are taken out of the configuration, and the protocol
+    /// version is copied out of it; all three are passed to the client separately.
+    ///
+    /// See [`ClientBuilder`] documentation if the return type of this function needs to be spelled
+    /// out.
+    pub fn build(self) -> Client<IntoConfigOutput<C>> {
+        let mut config = self.build_config();
+        let inner = &mut config.0;
+        let (signer, signing_identity, version) =
+            (inner.signer.take(), inner.signing_identity.take(), inner.version);
+
+        Client::new(config, signer, signing_identity, version)
+    }
+}
+
+impl<C: IntoConfig<PskStore = InMemoryPreSharedKeyStorage>> ClientBuilder<C> {
+    /// Add a PSK to the in-memory PSK store.
+    ///
+    /// Only available while the configured PSK store is [`InMemoryPreSharedKeyStorage`].
+    pub fn psk(
+        self,
+        psk_id: ExternalPskId,
+        psk: PreSharedKey,
+    ) -> ClientBuilder<IntoConfigOutput<C>> {
+        let Config(mut inner) = self.0.into_config();
+        inner.psk_store.insert(psk_id, psk);
+        ClientBuilder(Config(inner))
+    }
+}
+
+/// Marker type for required `ClientBuilder` services that have not been specified yet.
+///
+/// It derives only `Debug`. [`ClientBuilder::build`] requires every service to implement its
+/// service trait and `Clone`.
+#[derive(Debug)]
+pub struct Missing;
+
+/// Change the key package repository used by a client configuration.
+///
+/// `K` replaces the key package repository of `C`; every other service of `C` is kept.
+///
+/// See [`ClientBuilder::key_package_repo`].
+pub type WithKeyPackageRepo<K, C> = Config<
+    K,
+    <C as IntoConfig>::PskStore,
+    <C as IntoConfig>::GroupStateStorage,
+    <C as IntoConfig>::IdentityProvider,
+    <C as IntoConfig>::MlsRules,
+    <C as IntoConfig>::CryptoProvider,
+>;
+
+/// Change the PSK store used by a client configuration.
+///
+/// `P` replaces the PSK store of `C`; every other service of `C` is kept.
+///
+/// See [`ClientBuilder::psk_store`].
+pub type WithPskStore<P, C> = Config<
+    <C as IntoConfig>::KeyPackageRepository,
+    P,
+    <C as IntoConfig>::GroupStateStorage,
+    <C as IntoConfig>::IdentityProvider,
+    <C as IntoConfig>::MlsRules,
+    <C as IntoConfig>::CryptoProvider,
+>;
+
+/// Change the group state storage used by a client configuration.
+///
+/// `G` replaces the group state storage of `C`; every other service of `C` is kept.
+///
+/// See [`ClientBuilder::group_state_storage`].
+pub type WithGroupStateStorage<G, C> = Config<
+    <C as IntoConfig>::KeyPackageRepository,
+    <C as IntoConfig>::PskStore,
+    G,
+    <C as IntoConfig>::IdentityProvider,
+    <C as IntoConfig>::MlsRules,
+    <C as IntoConfig>::CryptoProvider,
+>;
+
+/// Change the identity provider used by a client configuration.
+///
+/// `I` replaces the identity provider of `C`; every other service of `C` is kept.
+///
+/// See [`ClientBuilder::identity_provider`].
+pub type WithIdentityProvider<I, C> = Config<
+    <C as IntoConfig>::KeyPackageRepository,
+    <C as IntoConfig>::PskStore,
+    <C as IntoConfig>::GroupStateStorage,
+    I,
+    <C as IntoConfig>::MlsRules,
+    <C as IntoConfig>::CryptoProvider,
+>;
+
+/// Change the MLS rules used by a client configuration.
+///
+/// `Pr` replaces the MLS rules of `C`; every other service of `C` is kept.
+///
+/// See [`ClientBuilder::mls_rules`].
+pub type WithMlsRules<Pr, C> = Config<
+    <C as IntoConfig>::KeyPackageRepository,
+    <C as IntoConfig>::PskStore,
+    <C as IntoConfig>::GroupStateStorage,
+    <C as IntoConfig>::IdentityProvider,
+    Pr,
+    <C as IntoConfig>::CryptoProvider,
+>;
+
+/// Change the crypto provider used by a client configuration.
+///
+/// `Cp` replaces the crypto provider of `C`; every other service of `C` is kept.
+///
+/// See [`ClientBuilder::crypto_provider`].
+pub type WithCryptoProvider<Cp, C> = Config<
+    <C as IntoConfig>::KeyPackageRepository,
+    <C as IntoConfig>::PskStore,
+    <C as IntoConfig>::GroupStateStorage,
+    <C as IntoConfig>::IdentityProvider,
+    <C as IntoConfig>::MlsRules,
+    Cp,
+>;
+
+/// Helper alias for `Config`.
+///
+/// The `Config` carrying exactly the services of `C`; this is the return type of
+/// `IntoConfig::into_config`.
+pub type IntoConfigOutput<C> = Config<
+    <C as IntoConfig>::KeyPackageRepository,
+    <C as IntoConfig>::PskStore,
+    <C as IntoConfig>::GroupStateStorage,
+    <C as IntoConfig>::IdentityProvider,
+    <C as IntoConfig>::MlsRules,
+    <C as IntoConfig>::CryptoProvider,
+>;
+
+/// Helper alias to make a `Config` from a `ClientConfig`
+///
+/// The service types are taken from the associated types of `C: ClientConfig`; this is the
+/// return type of `recreate_config`.
+pub type MakeConfig<C> = Config<
+    <C as ClientConfig>::KeyPackageRepository,
+    <C as ClientConfig>::PskStore,
+    <C as ClientConfig>::GroupStateStorage,
+    <C as ClientConfig>::IdentityProvider,
+    <C as ClientConfig>::MlsRules,
+    <C as ClientConfig>::CryptoProvider,
+>;
+
+/// `ClientConfig` implementation for the configuration assembled by `ClientBuilder`.
+///
+/// Services are handed out as clones of the stored instances; the capability lists and
+/// extension lists are cloned from the settings.
+impl<Kpr, Ps, Gss, Ip, Pr, Cp> ClientConfig for ConfigInner<Kpr, Ps, Gss, Ip, Pr, Cp>
+where
+    Kpr: KeyPackageStorage + Clone,
+    Ps: PreSharedKeyStorage + Clone,
+    Gss: GroupStateStorage + Clone,
+    Ip: IdentityProvider + Clone,
+    Pr: MlsRules + Clone,
+    Cp: CryptoProvider + Clone,
+{
+    type KeyPackageRepository = Kpr;
+    type PskStore = Ps;
+    type GroupStateStorage = Gss;
+    type IdentityProvider = Ip;
+    type MlsRules = Pr;
+    type CryptoProvider = Cp;
+
+    fn supported_extensions(&self) -> Vec<ExtensionType> {
+        self.settings.extension_types.clone()
+    }
+
+    fn supported_protocol_versions(&self) -> Vec<ProtocolVersion> {
+        self.settings.protocol_versions.clone()
+    }
+
+    fn key_package_repo(&self) -> Self::KeyPackageRepository {
+        self.key_package_repo.clone()
+    }
+
+    fn mls_rules(&self) -> Self::MlsRules {
+        self.mls_rules.clone()
+    }
+
+    fn secret_store(&self) -> Self::PskStore {
+        self.psk_store.clone()
+    }
+
+    fn group_state_storage(&self) -> Self::GroupStateStorage {
+        self.group_state_storage.clone()
+    }
+
+    fn identity_provider(&self) -> Self::IdentityProvider {
+        self.identity_provider.clone()
+    }
+
+    fn crypto_provider(&self) -> Self::CryptoProvider {
+        self.crypto_provider.clone()
+    }
+
+    fn key_package_extensions(&self) -> ExtensionList {
+        self.settings.key_package_extensions.clone()
+    }
+
+    fn leaf_node_extensions(&self) -> ExtensionList {
+        self.settings.leaf_node_extensions.clone()
+    }
+
+    /// Lifetime starting now (or at the configured `not_before` override) and lasting
+    /// `lifetime_in_s` seconds.
+    fn lifetime(&self) -> Lifetime {
+        #[cfg(feature = "std")]
+        let now_timestamp = MlsTime::now().seconds_since_epoch();
+
+        // Without the `std` feature no current time is read here; the lifetime starts at 0.
+        #[cfg(not(feature = "std"))]
+        let now_timestamp: u64 = 0;
+
+        // Use the same cfg as the `key_package_not_before` field and its builder setter, so the
+        // override also takes effect when only the `test_util` feature is enabled.
+        #[cfg(any(test, feature = "test_util"))]
+        let now_timestamp = self
+            .settings
+            .key_package_not_before
+            .unwrap_or(now_timestamp);
+
+        Lifetime {
+            not_before: now_timestamp,
+            // Saturate rather than overflow when a very large lifetime is configured.
+            not_after: now_timestamp.saturating_add(self.settings.lifetime_in_s),
+        }
+    }
+
+    fn supported_custom_proposals(&self) -> Vec<crate::group::proposal::ProposalType> {
+        self.settings.custom_proposal_types.clone()
+    }
+}
+
+// Sealing `Config` lets it implement the sealed `MlsConfig` trait.
+impl<K, P, G, I, R, X> Sealed for Config<K, P, G, I, R, X> {}
+
+impl<Kpr, Ps, Gss, Ip, Pr, Cp> MlsConfig for Config<Kpr, Ps, Gss, Ip, Pr, Cp>
+where
+    Kpr: KeyPackageStorage + Clone,
+    Ps: PreSharedKeyStorage + Clone,
+    Gss: GroupStateStorage + Clone,
+    Ip: IdentityProvider + Clone,
+    Pr: MlsRules + Clone,
+    Cp: CryptoProvider + Clone,
+{
+    type Output = ConfigInner<Kpr, Ps, Gss, Ip, Pr, Cp>;
+
+    /// Borrow the wrapped `ConfigInner`, which carries the `ClientConfig` implementation.
+    fn get(&self) -> &Self::Output {
+        let Config(inner) = self;
+        inner
+    }
+}
+
+/// Helper trait to allow consuming crates to easily write a client type as `Client<impl MlsConfig>`
+///
+/// The trait is sealed: it is not meant to be implemented by consuming crates. A blanket
+/// implementation makes every `T: MlsConfig` also implement `ClientConfig`.
+pub trait MlsConfig
+where
+    Self: Clone + Send + Sync + Sealed,
+{
+    /// Type providing the actual `ClientConfig` implementation.
+    #[doc(hidden)]
+    type Output: ClientConfig;
+
+    /// Borrow the value providing the actual `ClientConfig` implementation.
+    #[doc(hidden)]
+    fn get(&self) -> &Self::Output;
+}
+
+/// Blanket implementation so that `T: MlsConfig` implies `T: ClientConfig`
+impl<T: MlsConfig> ClientConfig for T {
+ type KeyPackageRepository = <T::Output as ClientConfig>::KeyPackageRepository;
+ type PskStore = <T::Output as ClientConfig>::PskStore;
+ type GroupStateStorage = <T::Output as ClientConfig>::GroupStateStorage;
+ type IdentityProvider = <T::Output as ClientConfig>::IdentityProvider;
+ type MlsRules = <T::Output as ClientConfig>::MlsRules;
+ type CryptoProvider = <T::Output as ClientConfig>::CryptoProvider;
+
+ fn supported_extensions(&self) -> Vec<ExtensionType> {
+ self.get().supported_extensions()
+ }
+
+ fn supported_custom_proposals(&self) -> Vec<ProposalType> {
+ self.get().supported_custom_proposals()
+ }
+
+ fn supported_protocol_versions(&self) -> Vec<ProtocolVersion> {
+ self.get().supported_protocol_versions()
+ }
+
+ fn key_package_repo(&self) -> Self::KeyPackageRepository {
+ self.get().key_package_repo()
+ }
+
+ fn mls_rules(&self) -> Self::MlsRules {
+ self.get().mls_rules()
+ }
+
+ fn secret_store(&self) -> Self::PskStore {
+ self.get().secret_store()
+ }
+
+ fn group_state_storage(&self) -> Self::GroupStateStorage {
+ self.get().group_state_storage()
+ }
+
+ fn identity_provider(&self) -> Self::IdentityProvider {
+ self.get().identity_provider()
+ }
+
+ fn crypto_provider(&self) -> Self::CryptoProvider {
+ self.get().crypto_provider()
+ }
+
+ fn key_package_extensions(&self) -> ExtensionList {
+ self.get().key_package_extensions()
+ }
+
+ fn leaf_node_extensions(&self) -> ExtensionList {
+ self.get().leaf_node_extensions()
+ }
+
+ fn lifetime(&self) -> Lifetime {
+ self.get().lifetime()
+ }
+
+ fn capabilities(&self) -> Capabilities {
+ self.get().capabilities()
+ }
+
+ fn version_supported(&self, version: ProtocolVersion) -> bool {
+ self.get().version_supported(version)
+ }
+
+ fn supported_credential_types(&self) -> Vec<CredentialType> {
+ self.get().supported_credential_types()
+ }
+}
+
+/// Non-service settings collected by `ClientBuilder` and read by the `ClientConfig`
+/// implementation of `ConfigInner`.
+#[derive(Clone, Debug)]
+pub(crate) struct Settings {
+    // Returned by `supported_extensions`.
+    pub(crate) extension_types: Vec<ExtensionType>,
+    // Versions added explicitly; `build_config` fills in every version when this is empty.
+    pub(crate) protocol_versions: Vec<ProtocolVersion>,
+    // Returned by `supported_custom_proposals`.
+    pub(crate) custom_proposal_types: Vec<ProposalType>,
+    // Returned by `key_package_extensions`.
+    pub(crate) key_package_extensions: ExtensionList,
+    // Returned by `leaf_node_extensions`.
+    pub(crate) leaf_node_extensions: ExtensionList,
+    // Duration in seconds added to the start timestamp to compute `Lifetime::not_after`.
+    pub(crate) lifetime_in_s: u64,
+    // Optional start-of-lifetime override, set by `ClientBuilder::key_package_not_before`.
+    #[cfg(any(test, feature = "test_util"))]
+    pub(crate) key_package_not_before: Option<u64>,
+}
+
+impl Default for Settings {
+    /// Empty lists everywhere, no `not_before` override, and a lifetime of 365 days.
+    fn default() -> Self {
+        // 365 days, in seconds.
+        const ONE_YEAR_IN_S: u64 = 365 * 24 * 3600;
+
+        Settings {
+            extension_types: Vec::new(),
+            protocol_versions: Vec::new(),
+            custom_proposal_types: Vec::new(),
+            key_package_extensions: ExtensionList::default(),
+            leaf_node_extensions: ExtensionList::default(),
+            lifetime_in_s: ONE_YEAR_IN_S,
+            #[cfg(any(test, feature = "test_util"))]
+            key_package_not_before: None,
+        }
+    }
+}
+
+/// Rebuild a concrete `Config` from any `ClientConfig` implementation `c`, together with the
+/// signer, signing identity and protocol version (which `ClientBuilder::build` passes to
+/// `Client::new` separately from the configuration).
+///
+/// The lifetime duration is recovered from `c.lifetime()`; the test-only `not_before` override
+/// is not carried over.
+pub(crate) fn recreate_config<T: ClientConfig>(
+    c: T,
+    signer: Option<SignatureSecretKey>,
+    signing_identity: Option<(SigningIdentity, CipherSuite)>,
+    version: ProtocolVersion,
+) -> MakeConfig<T> {
+    Config(ConfigInner {
+        settings: Settings {
+            extension_types: c.supported_extensions(),
+            protocol_versions: c.supported_protocol_versions(),
+            custom_proposal_types: c.supported_custom_proposals(),
+            key_package_extensions: c.key_package_extensions(),
+            leaf_node_extensions: c.leaf_node_extensions(),
+            lifetime_in_s: {
+                let l = c.lifetime();
+                // A custom `ClientConfig` may return `not_after < not_before`; clamp to 0 instead
+                // of underflowing.
+                l.not_after.saturating_sub(l.not_before)
+            },
+            #[cfg(any(test, feature = "test_util"))]
+            key_package_not_before: None,
+        },
+        key_package_repo: c.key_package_repo(),
+        psk_store: c.secret_store(),
+        group_state_storage: c.group_state_storage(),
+        identity_provider: c.identity_provider(),
+        mls_rules: c.mls_rules(),
+        crypto_provider: c.crypto_provider(),
+        signer,
+        signing_identity,
+        version,
+    })
+}
+
+/// Definitions meant to be private that are inaccessible outside this crate. They need to be marked
+/// `pub` because they appear in public definitions.
+mod private {
+    use crate::client_builder::{IntoConfigOutput, Settings};
+    use mls_rs_core::crypto::{CipherSuite, SignatureSecretKey};
+    use mls_rs_core::identity::SigningIdentity;
+    use mls_rs_core::protocol_version::ProtocolVersion;
+
+    /// Client configuration assembled by `ClientBuilder`. The type parameters are, in order, the
+    /// key package repository, PSK store, group state storage, identity provider, MLS rules and
+    /// crypto provider.
+    #[derive(Clone, Debug)]
+    pub struct Config<Kpr, Ps, Gss, Ip, Pr, Cp>(pub(crate) ConfigInner<Kpr, Ps, Gss, Ip, Pr, Cp>);
+
+    /// Fields of a client configuration: settings, services, and the signer, signing identity and
+    /// protocol version handed to the client when it is built.
+    #[derive(Clone, Debug)]
+    pub struct ConfigInner<Kpr, Ps, Gss, Ip, Pr, Cp> {
+        pub(crate) settings: Settings,
+        pub(crate) key_package_repo: Kpr,
+        pub(crate) psk_store: Ps,
+        pub(crate) group_state_storage: Gss,
+        pub(crate) identity_provider: Ip,
+        pub(crate) mls_rules: Pr,
+        pub(crate) crypto_provider: Cp,
+        pub(crate) signer: Option<SignatureSecretKey>,
+        pub(crate) signing_identity: Option<(SigningIdentity, CipherSuite)>,
+        pub(crate) version: ProtocolVersion,
+    }
+
+    /// Conversion into a `Config`, exposing each service type as an associated type.
+    pub trait IntoConfig {
+        type KeyPackageRepository;
+        type PskStore;
+        type GroupStateStorage;
+        type IdentityProvider;
+        type MlsRules;
+        type CryptoProvider;
+
+        fn into_config(self) -> IntoConfigOutput<Self>;
+    }
+
+    // A `Config` already is the target type, so the conversion returns it unchanged.
+    impl<K, P, G, I, R, X> IntoConfig for Config<K, P, G, I, R, X> {
+        type KeyPackageRepository = K;
+        type PskStore = P;
+        type GroupStateStorage = G;
+        type IdentityProvider = I;
+        type MlsRules = R;
+        type CryptoProvider = X;
+
+        fn into_config(self) -> Self {
+            self
+        }
+    }
+}
+
+use mls_rs_core::{
+ crypto::{CryptoProvider, SignatureSecretKey},
+ extension::{ExtensionError, ExtensionList},
+ group::GroupStateStorage,
+ identity::IdentityProvider,
+ key_package::KeyPackageStorage,
+ psk::PreSharedKeyStorage,
+};
+use private::{Config, ConfigInner, IntoConfig};
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use super::WithCryptoProvider;
+    use crate::{
+        client_builder::{BaseConfig, ClientBuilder, WithIdentityProvider},
+        crypto::test_utils::TestCryptoProvider,
+        identity::{
+            basic::BasicIdentityProvider,
+            test_utils::{get_test_signing_identity, BasicWithCustomProvider},
+        },
+        CipherSuite,
+    };
+
+    /// Configuration used by tests: `BaseConfig` plus the test crypto provider and a basic
+    /// identity provider wrapped in `BasicWithCustomProvider`.
+    pub type TestClientConfig = WithIdentityProvider<
+        BasicWithCustomProvider,
+        WithCryptoProvider<TestCryptoProvider, BaseConfig>,
+    >;
+
+    /// Builder for [`TestClientConfig`] clients.
+    pub type TestClientBuilder = ClientBuilder<TestClientConfig>;
+
+    impl TestClientBuilder {
+        /// Builder preconfigured with the test crypto provider and identity provider.
+        pub fn new_for_test() -> Self {
+            let crypto = TestCryptoProvider::new();
+            let identity = BasicWithCustomProvider::new(BasicIdentityProvider::new());
+
+            ClientBuilder::new()
+                .crypto_provider(crypto)
+                .identity_provider(identity)
+        }
+
+        /// Set the signing identity and signer returned by `get_test_signing_identity` for
+        /// `identity` and `cipher_suite`.
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        pub async fn with_random_signing_identity(
+            self,
+            identity: &str,
+            cipher_suite: CipherSuite,
+        ) -> Self {
+            let name = identity.as_bytes();
+            let (signing_identity, secret_key) =
+                get_test_signing_identity(cipher_suite, name).await;
+
+            self.signing_identity(signing_identity, secret_key, cipher_suite)
+        }
+    }
+}
diff --git a/src/client_config.rs b/src/client_config.rs
new file mode 100644
index 0000000..339f335
--- /dev/null
+++ b/src/client_config.rs
@@ -0,0 +1,68 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ extension::ExtensionType,
+ group::{mls_rules::MlsRules, proposal::ProposalType},
+ identity::CredentialType,
+ protocol_version::ProtocolVersion,
+ tree_kem::{leaf_node::ConfigProperties, Capabilities, Lifetime},
+ ExtensionList,
+};
+use alloc::vec::Vec;
+use mls_rs_core::{
+ crypto::CryptoProvider, group::GroupStateStorage, identity::IdentityProvider,
+ key_package::KeyPackageStorage, psk::PreSharedKeyStorage,
+};
+
+pub trait ClientConfig: Send + Sync + Clone {
+    type KeyPackageRepository: KeyPackageStorage + Clone;
+    type PskStore: PreSharedKeyStorage + Clone;
+    type GroupStateStorage: GroupStateStorage + Clone;
+    type IdentityProvider: IdentityProvider + Clone;
+    type MlsRules: MlsRules + Clone;
+    type CryptoProvider: CryptoProvider + Clone;
+
+    fn supported_extensions(&self) -> Vec<ExtensionType>;
+    fn supported_custom_proposals(&self) -> Vec<ProposalType>;
+    fn supported_protocol_versions(&self) -> Vec<ProtocolVersion>;
+
+    fn key_package_repo(&self) -> Self::KeyPackageRepository;
+    fn mls_rules(&self) -> Self::MlsRules;
+    fn secret_store(&self) -> Self::PskStore;
+    fn group_state_storage(&self) -> Self::GroupStateStorage;
+    fn identity_provider(&self) -> Self::IdentityProvider;
+    fn crypto_provider(&self) -> Self::CryptoProvider;
+
+    fn key_package_extensions(&self) -> ExtensionList;
+    fn leaf_node_extensions(&self) -> ExtensionList;
+    fn lifetime(&self) -> Lifetime;
+
+    /// Capabilities of this client, assembled from the supported types and providers.
+    fn capabilities(&self) -> Capabilities {
+        Capabilities {
+            protocol_versions: self.supported_protocol_versions(),
+            cipher_suites: self.crypto_provider().supported_cipher_suites(),
+            extensions: self.supported_extensions(),
+            proposals: self.supported_custom_proposals(),
+            credentials: self.supported_credential_types(),
+        }
+    }
+
+    fn version_supported(&self, version: ProtocolVersion) -> bool {
+        self.supported_protocol_versions().contains(&version)
+    }
+
+    fn supported_credential_types(&self) -> Vec<CredentialType> {
+        self.identity_provider().supported_types()
+    }
+
+    /// Leaf node properties: [`Self::capabilities`] plus [`Self::leaf_node_extensions`].
+    fn leaf_properties(&self) -> ConfigProperties {
+        ConfigProperties {
+            capabilities: self.capabilities(),
+            extensions: self.leaf_node_extensions(),
+        }
+    }
+}
diff --git a/src/crypto.rs b/src/crypto.rs
new file mode 100644
index 0000000..795476a
--- /dev/null
+++ b/src/crypto.rs
@@ -0,0 +1,43 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+pub(crate) use mls_rs_core::crypto::CipherSuiteProvider;
+
+pub use mls_rs_core::crypto::{
+ HpkeCiphertext, HpkeContextR, HpkeContextS, HpkePublicKey, HpkeSecretKey, SignaturePublicKey,
+ SignatureSecretKey,
+};
+
+pub use mls_rs_core::secret::Secret;
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use crate::cipher_suite::CipherSuite;
+    use cfg_if::cfg_if;
+    use mls_rs_core::crypto::CryptoProvider;
+
+    cfg_if! {
+        if #[cfg(target_arch = "wasm32")] {
+            pub use mls_rs_crypto_webcrypto::WebCryptoProvider as TestCryptoProvider;
+        } else {
+            pub use mls_rs_crypto_openssl::OpensslCryptoProvider as TestCryptoProvider;
+        }
+    }
+
+    /// Provider for `cipher_suite` from [`TestCryptoProvider`]; panics if unsupported.
+    pub fn test_cipher_suite_provider(
+        cipher_suite: CipherSuite,
+    ) -> <TestCryptoProvider as CryptoProvider>::CipherSuiteProvider {
+        TestCryptoProvider::new()
+            .cipher_suite_provider(cipher_suite)
+            .expect("cipher suite not supported by TestCryptoProvider")
+    }
+
+    #[allow(unused)] // non-panicking variant keyed by a raw u16 suite id
+    pub fn try_test_cipher_suite_provider(
+        cipher_suite: u16,
+    ) -> Option<<TestCryptoProvider as CryptoProvider>::CipherSuiteProvider> {
+        TestCryptoProvider::new().cipher_suite_provider(CipherSuite::from(cipher_suite))
+    }
+}
diff --git a/src/extension.rs b/src/extension.rs
new file mode 100644
index 0000000..4cba416
--- /dev/null
+++ b/src/extension.rs
@@ -0,0 +1,52 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+pub use mls_rs_core::extension::{ExtensionType, MlsCodecExtension, MlsExtension};
+
+pub(crate) use built_in::*;
+
+/// Default extension types required by the MLS RFC.
+pub mod built_in;
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use super::*;
+    use alloc::vec::Vec;
+    use core::convert::Infallible;
+    use core::fmt::Debug;
+    use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+    use mls_rs_core::extension::MlsExtension;
+
+    pub const TEST_EXTENSION_TYPE: u16 = 42;
+
+    /// Minimal single-byte extension used to exercise extension handling in tests.
+    #[derive(MlsSize, MlsEncode, MlsDecode, Clone, Debug, PartialEq)]
+    pub(crate) struct TestExtension {
+        pub(crate) foo: u8,
+    }
+
+    impl From<u8> for TestExtension {
+        fn from(foo: u8) -> Self {
+            Self { foo }
+        }
+    }
+
+    impl MlsExtension for TestExtension {
+        type SerializationError = Infallible;
+        type DeserializationError = Infallible;
+
+        fn extension_type() -> ExtensionType {
+            TEST_EXTENSION_TYPE.into()
+        }
+
+        fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> {
+            Ok(Vec::from([self.foo]))
+        }
+
+        // Reads only the first byte; panics on empty input (the error type is Infallible).
+        fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError> {
+            Ok(Self::from(data[0]))
+        }
+    }
+}
diff --git a/src/extension/built_in.rs b/src/extension/built_in.rs
new file mode 100644
index 0000000..361a112
--- /dev/null
+++ b/src/extension/built_in.rs
@@ -0,0 +1,330 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::extension::{ExtensionType, MlsCodecExtension};
+
+use mls_rs_core::{group::ProposalType, identity::CredentialType};
+
+#[cfg(feature = "by_ref_proposal")]
+use mls_rs_core::{
+ extension::ExtensionList,
+ identity::{IdentityProvider, SigningIdentity},
+ time::MlsTime,
+};
+
+use crate::group::ExportedTree;
+
+use mls_rs_core::crypto::HpkePublicKey;
+
+/// Application-specific identifier.
+///
+/// A custom, application-level identifier that can optionally be stored
+/// within the `leaf_node_extensions` of a group [Member](crate::group::Member).
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+pub struct ApplicationIdExt {
+    /// Application level identifier presented by this extension.
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    pub identifier: Vec<u8>,
+}
+
+// Hand-written so the identifier is rendered through `pretty_bytes`.
+impl Debug for ApplicationIdExt {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let identifier = mls_rs_core::debug::pretty_bytes(&self.identifier);
+        f.debug_struct("ApplicationIdExt")
+            .field("identifier", &identifier)
+            .finish()
+    }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl ApplicationIdExt {
+    /// Create a new application level identifier extension.
+    pub fn new(identifier: Vec<u8>) -> Self {
+        Self { identifier }
+    }
+
+    /// Get the application level identifier presented by this extension.
+    #[cfg(feature = "ffi")]
+    pub fn identifier(&self) -> &[u8] {
+        &self.identifier
+    }
+}
+
+/// Carried under the [`ExtensionType::APPLICATION_ID`] extension type.
+impl MlsCodecExtension for ApplicationIdExt {
+    fn extension_type() -> ExtensionType {
+        ExtensionType::APPLICATION_ID
+    }
+}
+
+/// Representation of an MLS ratchet tree.
+///
+/// Used to provide new members with an in-band copy of the group's ratchet tree.
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+pub struct RatchetTreeExt {
+    /// The exported ratchet tree carried by this extension.
+    pub tree_data: ExportedTree<'static>,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl RatchetTreeExt {
+    /// Get the ratchet tree carried by this extension.
+    #[cfg(feature = "ffi")]
+    pub fn tree_data(&self) -> &ExportedTree<'static> {
+        &self.tree_data
+    }
+}
+
+impl MlsCodecExtension for RatchetTreeExt {
+    fn extension_type() -> ExtensionType {
+        ExtensionType::RATCHET_TREE
+    }
+}
+
+/// Require members to have certain capabilities.
+///
+/// Used within a [Group Context Extensions Proposal](crate::group::proposal::Proposal)
+/// to require that all current and future members of a group MUST support
+/// specific extensions, proposals, or credentials.
+///
+/// # Warning
+///
+/// Extension, proposal, and credential types defined by the MLS RFC and
+/// provided are considered required by default and should NOT be used
+/// within this extension.
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, Default)]
+pub struct RequiredCapabilitiesExt {
+    pub extensions: Vec<ExtensionType>,
+    pub proposals: Vec<ProposalType>,
+    pub credentials: Vec<CredentialType>,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl RequiredCapabilitiesExt {
+    /// Create a required capabilities extension from the given type lists.
+    pub fn new(
+        extensions: Vec<ExtensionType>,
+        proposals: Vec<ProposalType>,
+        credentials: Vec<CredentialType>,
+    ) -> Self {
+        Self {
+            extensions,
+            proposals,
+            credentials,
+        }
+    }
+
+    /// Required custom extension types.
+    #[cfg(feature = "ffi")]
+    pub fn extensions(&self) -> &[ExtensionType] {
+        self.extensions.as_slice()
+    }
+
+    /// Required custom proposal types.
+    #[cfg(feature = "ffi")]
+    pub fn proposals(&self) -> &[ProposalType] {
+        self.proposals.as_slice()
+    }
+
+    /// Required custom credential types.
+    #[cfg(feature = "ffi")]
+    pub fn credentials(&self) -> &[CredentialType] {
+        self.credentials.as_slice()
+    }
+}
+
+/// Carried under the [`ExtensionType::REQUIRED_CAPABILITIES`] extension type.
+impl MlsCodecExtension for RequiredCapabilitiesExt {
+    fn extension_type() -> ExtensionType {
+        ExtensionType::REQUIRED_CAPABILITIES
+    }
+}
+
+/// External public key used for [External Commits](crate::Client::commit_external).
+///
+/// This extension is optionally provided as part of a
+/// [Group Info](crate::group::Group::group_info_message).
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+pub struct ExternalPubExt {
+    /// Public key to be used for an external commit.
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    pub external_pub: HpkePublicKey,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl ExternalPubExt {
+    /// Get the public key to be used for an external commit.
+    #[cfg(feature = "ffi")]
+    pub fn external_pub(&self) -> &HpkePublicKey {
+        &self.external_pub
+    }
+}
+
+impl MlsCodecExtension for ExternalPubExt {
+    fn extension_type() -> ExtensionType {
+        ExtensionType::EXTERNAL_PUB
+    }
+}
+
+/// Enable proposals by an [ExternalClient](crate::external_client::ExternalClient).
+#[cfg(feature = "by_ref_proposal")]
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[non_exhaustive]
+pub struct ExternalSendersExt {
+    pub allowed_senders: Vec<SigningIdentity>,
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl ExternalSendersExt {
+    pub fn new(allowed_senders: Vec<SigningIdentity>) -> Self {
+        Self { allowed_senders }
+    }
+
+    #[cfg(feature = "ffi")]
+    pub fn allowed_senders(&self) -> &[SigningIdentity] {
+        self.allowed_senders.as_slice()
+    }
+
+    /// Validate each allowed sender with `provider`, returning the first error encountered.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn verify_all<I: IdentityProvider>(
+        &self,
+        provider: &I,
+        timestamp: Option<MlsTime>,
+        group_context_extensions: &ExtensionList,
+    ) -> Result<(), I::Error> {
+        for sender in &self.allowed_senders {
+            provider
+                .validate_external_sender(sender, timestamp, Some(group_context_extensions))
+                .await?;
+        }
+        Ok(())
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl MlsCodecExtension for ExternalSendersExt {
+    fn extension_type() -> ExtensionType {
+        ExtensionType::EXTERNAL_SENDERS
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use crate::tree_kem::node::NodeVec;
+    #[cfg(feature = "by_ref_proposal")]
+    use crate::{
+        client::test_utils::TEST_CIPHER_SUITE, identity::test_utils::get_test_signing_identity,
+    };
+
+    use mls_rs_core::extension::MlsExtension;
+
+    use mls_rs_core::identity::BasicCredential;
+
+    use alloc::vec;
+
+    #[cfg(target_arch = "wasm32")]
+    use wasm_bindgen_test::wasm_bindgen_test as test;
+
+    #[test]
+    fn test_application_id_extension() {
+        let id_bytes = vec![0u8; 32];
+        let original = ApplicationIdExt {
+            identifier: id_bytes.clone(),
+        };
+
+        let encoded_ext = original.into_extension().unwrap();
+
+        assert_eq!(encoded_ext.extension_type, ExtensionType::APPLICATION_ID);
+
+        let decoded = ApplicationIdExt::from_extension(&encoded_ext).unwrap();
+        assert_eq!(decoded.identifier, id_bytes);
+    }
+
+    #[test]
+    fn test_ratchet_tree() {
+        let original = RatchetTreeExt {
+            tree_data: ExportedTree::new(NodeVec::from(vec![None, None])),
+        };
+
+        let encoded_ext = original.clone().into_extension().unwrap();
+        assert_eq!(encoded_ext.extension_type, ExtensionType::RATCHET_TREE);
+
+        let decoded = RatchetTreeExt::from_extension(&encoded_ext).unwrap();
+        assert_eq!(original, decoded)
+    }
+
+    #[test]
+    fn test_required_capabilities() {
+        let original = RequiredCapabilitiesExt {
+            extensions: vec![0.into(), 1.into()],
+            proposals: vec![42.into(), 43.into()],
+            credentials: vec![BasicCredential::credential_type()],
+        };
+
+        let encoded_ext = original.clone().into_extension().unwrap();
+
+        assert_eq!(
+            encoded_ext.extension_type,
+            ExtensionType::REQUIRED_CAPABILITIES
+        );
+
+        let decoded = RequiredCapabilitiesExt::from_extension(&encoded_ext).unwrap();
+        assert_eq!(original, decoded)
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_external_senders() {
+        let sender = get_test_signing_identity(TEST_CIPHER_SUITE, &[1]).await.0;
+        let original = ExternalSendersExt::new(vec![sender]);
+
+        let encoded_ext = original.clone().into_extension().unwrap();
+
+        assert_eq!(encoded_ext.extension_type, ExtensionType::EXTERNAL_SENDERS);
+
+        let decoded = ExternalSendersExt::from_extension(&encoded_ext).unwrap();
+        assert_eq!(original, decoded)
+    }
+
+    #[test]
+    fn test_external_pub() {
+        let original = ExternalPubExt {
+            external_pub: vec![0, 1, 2, 3].into(),
+        };
+
+        let encoded_ext = original.clone().into_extension().unwrap();
+        assert_eq!(encoded_ext.extension_type, ExtensionType::EXTERNAL_PUB);
+
+        let decoded = ExternalPubExt::from_extension(&encoded_ext).unwrap();
+        assert_eq!(original, decoded)
+    }
+}
diff --git a/src/external_client.rs b/src/external_client.rs
new file mode 100644
index 0000000..0c882ac
--- /dev/null
+++ b/src/external_client.rs
@@ -0,0 +1,142 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::MlsError,
+ group::{framing::MlsMessage, message_processor::validate_key_package, ExportedTree},
+ KeyPackage,
+};
+
+pub mod builder;
+mod config;
+mod group;
+
+pub(crate) use config::ExternalClientConfig;
+use mls_rs_core::{
+ crypto::{CryptoProvider, SignatureSecretKey},
+ identity::SigningIdentity,
+};
+
+use builder::{ExternalBaseConfig, ExternalClientBuilder};
+
+pub use group::{ExternalGroup, ExternalReceivedMessage, ExternalSnapshot};
+
+/// A client that observes a group's state without the private keys needed to read content.
+///
+/// This structure is useful when an application sends plaintext control
+/// messages so that a central server can facilitate communication between users.
+///
+/// # Warning
+///
+/// This structure can only observe groups created by clients whose
+/// `encrypt_control_messages` option, as returned by
+/// [`MlsRules::encryption_options`](crate::MlsRules::encryption_options), is `false`.
+/// Any control message sent encrypted over the wire breaks this client's
+/// ability to track the resulting group state.
+pub struct ExternalClient<C> {
+    /// Configuration supplying the crypto and identity providers, among others.
+    config: C,
+    /// Optional signer and identity used when sending external proposals.
+    signing_data: Option<(SignatureSecretKey, SigningIdentity)>,
+}
+
+impl ExternalClient<()> {
+    /// Start configuring an external client; see [`ExternalClientBuilder`].
+    pub fn builder() -> ExternalClientBuilder<ExternalBaseConfig> {
+        ExternalClientBuilder::default()
+    }
+}
+
+impl<C> ExternalClient<C>
+where
+    C: ExternalClientConfig + Clone,
+{
+    pub(crate) fn new(
+        config: C,
+        signing_data: Option<(SignatureSecretKey, SigningIdentity)>,
+    ) -> Self {
+        Self {
+            config,
+            signing_data,
+        }
+    }
+
+    /// Begin observing a group based on a GroupInfo message created by
+    /// [Group::group_info_message](crate::group::Group::group_info_message).
+    ///
+    /// `tree_data` must be provided out of band if the client that created the
+    /// GroupInfo message did not use the `ratchet_tree_extension` according to
+    /// [`MlsRules::commit_options`](crate::MlsRules::commit_options) at the time
+    /// that message was created. `tree_data` can be exported from a group using the
+    /// [export tree function](crate::group::Group::export_tree).
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn observe_group(
+        &self,
+        group_info: MlsMessage,
+        tree_data: Option<ExportedTree<'_>>,
+    ) -> Result<ExternalGroup<C>, MlsError> {
+        ExternalGroup::join(
+            self.config.clone(),
+            self.signing_data.clone(),
+            group_info,
+            tree_data,
+        )
+        .await
+    }
+
+    /// Load an existing observed group from a snapshot generated by
+    /// [ExternalGroup::snapshot](self::ExternalGroup::snapshot).
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn load_group(
+        &self,
+        snapshot: ExternalSnapshot,
+    ) -> Result<ExternalGroup<C>, MlsError> {
+        ExternalGroup::from_snapshot(self.config.clone(), snapshot).await
+    }
+
+    /// Decode and validate a key package message, returning the [`KeyPackage`].
+    ///
+    /// Fails if the message is not a key package, its cipher suite is unsupported
+    /// by the crypto provider, or the key package fails validation.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn validate_key_package(
+        &self,
+        key_package: MlsMessage,
+    ) -> Result<KeyPackage, MlsError> {
+        let version = key_package.version; // taken from the wrapper before it is unwrapped
+        let key_package = key_package
+            .into_key_package()
+            .ok_or(MlsError::UnexpectedMessageType)?;
+
+        let cs = self
+            .config
+            .crypto_provider()
+            .cipher_suite_provider(key_package.cipher_suite)
+            .ok_or(MlsError::UnsupportedCipherSuite(key_package.cipher_suite))?;
+
+        let id = self.config.identity_provider();
+        validate_key_package(&key_package, version, &cs, &id).await?;
+
+        Ok(key_package)
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod tests_utils {
+    use crate::{
+        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+        key_package::test_utils::test_key_package_message,
+    };
+
+    pub use super::builder::test_utils::*;
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_client_can_validate_key_package() {
+        let msg = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "john").await;
+        let observer = TestExternalClientBuilder::new_for_test().build();
+        let validated = observer.validate_key_package(msg.clone()).await.unwrap();
+
+        assert_eq!(msg.into_key_package().unwrap(), validated);
+    }
+}
diff --git a/src/external_client/builder.rs b/src/external_client/builder.rs
new file mode 100644
index 0000000..04c9768
--- /dev/null
+++ b/src/external_client/builder.rs
@@ -0,0 +1,602 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+//! Definitions to build an [`ExternalClient`].
+//!
+//! See [`ExternalClientBuilder`].
+
+use crate::{
+ crypto::SignaturePublicKey,
+ extension::ExtensionType,
+ external_client::{ExternalClient, ExternalClientConfig},
+ group::{
+ mls_rules::{DefaultMlsRules, MlsRules},
+ proposal::ProposalType,
+ },
+ identity::CredentialType,
+ protocol_version::ProtocolVersion,
+ tree_kem::Capabilities,
+ CryptoProvider, Sealed,
+};
+use std::{
+ collections::HashMap,
+ fmt::{self, Debug},
+};
+
+/// Configuration a new [`ExternalClientBuilder`] starts from: default MLS rules, no providers.
+pub type ExternalBaseConfig = Config<Missing, DefaultMlsRules, Missing>;
+
+/// Builder for [`ExternalClient`].
+///
+/// This is returned by [`ExternalClient::builder`] and allows tweaking the settings the
+/// `ExternalClient` will use. At a minimum, the builder must be told the [`CryptoProvider`]
+/// and [`IdentityProvider`] to use. Other settings have default values. This
+/// means that the following methods must be called before [`ExternalClientBuilder::build`]:
+///
+/// - To specify the [`CryptoProvider`]: [`ExternalClientBuilder::crypto_provider`]
+/// - To specify the [`IdentityProvider`]: [`ExternalClientBuilder::identity_provider`]
+///
+/// # Example
+///
+/// ```
+/// use mls_rs::{
+///     external_client::ExternalClient,
+///     identity::basic::BasicIdentityProvider,
+/// };
+///
+/// use mls_rs_crypto_openssl::OpensslCryptoProvider;
+///
+/// let _client = ExternalClient::builder()
+///     .crypto_provider(OpensslCryptoProvider::default())
+///     .identity_provider(BasicIdentityProvider::new())
+///     .build();
+/// ```
+///
+/// # Spelling out an `ExternalClient` type
+///
+/// There are two main ways to spell out an `ExternalClient` type if needed (e.g. function return type).
+///
+/// The first option uses `impl MlsConfig`:
+/// ```
+/// use mls_rs::{
+///     external_client::{ExternalClient, builder::MlsConfig},
+///     identity::basic::BasicIdentityProvider,
+/// };
+///
+/// use mls_rs_crypto_openssl::OpensslCryptoProvider;
+///
+/// fn make_client() -> ExternalClient<impl MlsConfig> {
+///     ExternalClient::builder()
+///         .crypto_provider(OpensslCryptoProvider::default())
+///         .identity_provider(BasicIdentityProvider::new())
+///         .build()
+/// }
+/// ```
+///
+/// The second option is more verbose and consists of writing the full `ExternalClient` type:
+/// ```
+/// use mls_rs::{
+///     external_client::{ExternalClient, builder::{ExternalBaseConfig, WithIdentityProvider, WithCryptoProvider}},
+///     identity::basic::BasicIdentityProvider,
+/// };
+///
+/// use mls_rs_crypto_openssl::OpensslCryptoProvider;
+///
+/// type MlsClient = ExternalClient<WithIdentityProvider<
+///     BasicIdentityProvider,
+///     WithCryptoProvider<OpensslCryptoProvider, ExternalBaseConfig>,
+/// >>;
+///
+/// fn make_client_2() -> MlsClient {
+///     ExternalClient::builder()
+///         .crypto_provider(OpensslCryptoProvider::new())
+///         .identity_provider(BasicIdentityProvider::new())
+///         .build()
+/// }
+/// ```
+#[derive(Debug)]
+pub struct ExternalClientBuilder<C>(C);
+
+impl Default for ExternalClientBuilder<ExternalBaseConfig> {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ExternalClientBuilder<ExternalBaseConfig> {
+    /// Create a builder with default settings and MLS rules but no identity/crypto provider.
+    pub fn new() -> Self {
+        Self(Config(ConfigInner {
+            settings: Default::default(),
+            identity_provider: Missing,
+            mls_rules: DefaultMlsRules::new(),
+            crypto_provider: Missing,
+            signing_data: None,
+        }))
+    }
+}
+
+impl<C: IntoConfig> ExternalClientBuilder<C> {
+    /// Add a single extension type to the client's supported extension types.
+    pub fn extension_type(
+        self,
+        type_: ExtensionType,
+    ) -> ExternalClientBuilder<IntoConfigOutput<C>> {
+        self.extension_types(core::iter::once(type_))
+    }
+
+    /// Add every extension type in `types` to the client's supported extension types.
+    pub fn extension_types<I>(self, types: I) -> ExternalClientBuilder<IntoConfigOutput<C>>
+    where
+        I: IntoIterator<Item = ExtensionType>,
+    {
+        let mut config = self.0.into_config();
+        config.0.settings.extension_types.extend(types);
+        ExternalClientBuilder(config)
+    }
+
+    /// Add a single custom proposal type to the client's supported proposal types.
+    pub fn custom_proposal_type(
+        self,
+        type_: ProposalType,
+    ) -> ExternalClientBuilder<IntoConfigOutput<C>> {
+        self.custom_proposal_types(core::iter::once(type_))
+    }
+
+    /// Add every custom proposal type in `types` to the client's supported proposal types.
+    pub fn custom_proposal_types<I>(self, types: I) -> ExternalClientBuilder<IntoConfigOutput<C>>
+    where
+        I: IntoIterator<Item = ProposalType>,
+    {
+        let mut config = self.0.into_config();
+        config.0.settings.custom_proposal_types.extend(types);
+        ExternalClientBuilder(config)
+    }
+
+    /// Add a single protocol version to the client's supported protocol versions.
+    ///
+    /// If no protocol version is added explicitly, the client supports every protocol
+    /// version supported by this crate.
+    pub fn protocol_version(
+        self,
+        version: ProtocolVersion,
+    ) -> ExternalClientBuilder<IntoConfigOutput<C>> {
+        self.protocol_versions(core::iter::once(version))
+    }
+
+    /// Add every protocol version in `versions` to the client's supported protocol versions.
+    ///
+    /// If no protocol version is added explicitly, the client supports every protocol
+    /// version supported by this crate.
+    pub fn protocol_versions<I>(self, versions: I) -> ExternalClientBuilder<IntoConfigOutput<C>>
+    where
+        I: IntoIterator<Item = ProtocolVersion>,
+    {
+        let mut config = self.0.into_config();
+        config.0.settings.protocol_versions.extend(versions);
+        ExternalClientBuilder(config)
+    }
+
+    /// Register `key` as the external signing key for `id`, replacing any previous one.
+    pub fn external_signing_key(
+        self,
+        id: Vec<u8>,
+        key: SignaturePublicKey,
+    ) -> ExternalClientBuilder<IntoConfigOutput<C>> {
+        let mut config = self.0.into_config();
+        config.0.settings.external_signing_keys.insert(id, key);
+        ExternalClientBuilder(config)
+    }
+
+    /// Specify the number of epochs before the current one to keep.
+    ///
+    /// By default, all epochs are kept.
+    pub fn max_epoch_jitter(self, max_jitter: u64) -> ExternalClientBuilder<IntoConfigOutput<C>> {
+        let mut config = self.0.into_config();
+        config.0.settings.max_epoch_jitter = Some(max_jitter);
+        ExternalClientBuilder(config)
+    }
+
+    /// Specify whether the external group should cache processed proposals. If it does not,
+    /// they should be cached externally and inserted using
+    /// `ExternalGroup::insert_proposal` before processing the next commit.
+    pub fn cache_proposals(
+        self,
+        cache_proposals: bool,
+    ) -> ExternalClientBuilder<IntoConfigOutput<C>> {
+        let mut config = self.0.into_config();
+        config.0.settings.cache_proposals = cache_proposals;
+        ExternalClientBuilder(config)
+    }
+
+    /// Set the identity provider (identity validator) to be used by the client.
+    pub fn identity_provider<I>(
+        self,
+        identity_provider: I,
+    ) -> ExternalClientBuilder<WithIdentityProvider<I, C>>
+    where
+        I: IdentityProvider,
+    {
+        let Config(inner) = self.0.into_config();
+        ExternalClientBuilder(Config(ConfigInner {
+            settings: inner.settings,
+            identity_provider,
+            mls_rules: inner.mls_rules,
+            crypto_provider: inner.crypto_provider,
+            signing_data: inner.signing_data,
+        }))
+    }
+
+    /// Set the crypto provider to be used by the client.
+    ///
+    // TODO add a comment once we have a default provider
+    pub fn crypto_provider<Cp>(
+        self,
+        crypto_provider: Cp,
+    ) -> ExternalClientBuilder<WithCryptoProvider<Cp, C>>
+    where
+        Cp: CryptoProvider,
+    {
+        let Config(inner) = self.0.into_config();
+        ExternalClientBuilder(Config(ConfigInner {
+            settings: inner.settings,
+            identity_provider: inner.identity_provider,
+            mls_rules: inner.mls_rules,
+            crypto_provider,
+            signing_data: inner.signing_data,
+        }))
+    }
+
+    /// Set the user-defined proposal rules to be used by the client.
+    ///
+    /// User-defined rules are used when sending and receiving commits before
+    /// enforcing general MLS protocol rules. If the rule set returns an error when
+    /// receiving a commit, the entire commit is considered invalid. If the
+    /// rule set would return an error when sending a commit, individual proposals
+    /// may be filtered out to compensate.
+    pub fn mls_rules<Pr>(self, mls_rules: Pr) -> ExternalClientBuilder<WithMlsRules<Pr, C>>
+    where
+        Pr: MlsRules,
+    {
+        let Config(inner) = self.0.into_config();
+        ExternalClientBuilder(Config(ConfigInner {
+            settings: inner.settings,
+            identity_provider: inner.identity_provider,
+            mls_rules,
+            crypto_provider: inner.crypto_provider,
+            signing_data: inner.signing_data,
+        }))
+    }
+
+    /// Set the signing key and identity the client uses to send external proposals.
+    pub fn signer(
+        self,
+        signer: SignatureSecretKey,
+        signing_identity: SigningIdentity,
+    ) -> ExternalClientBuilder<IntoConfigOutput<C>> {
+        let mut config = self.0.into_config();
+        config.0.signing_data = Some((signer, signing_identity));
+        ExternalClientBuilder(config)
+    }
+}
+
+impl<C: IntoConfig> ExternalClientBuilder<C>
+where
+    C::IdentityProvider: IdentityProvider + Clone,
+    C::MlsRules: MlsRules + Clone,
+    C::CryptoProvider: CryptoProvider + Clone,
+{
+    /// Finalize the configuration, enabling `ProtocolVersion::all()` if none were added.
+    pub(crate) fn build_config(self) -> IntoConfigOutput<C> {
+        let mut config = self.0.into_config();
+        if config.0.settings.protocol_versions.is_empty() {
+            config.0.settings.protocol_versions = ProtocolVersion::all().collect();
+        }
+
+        config
+    }
+
+    /// Build an external client.
+    ///
+    /// See the [`ExternalClientBuilder`] docs for how to spell out the return type.
+    pub fn build(self) -> ExternalClient<IntoConfigOutput<C>> {
+        let mut config = self.build_config();
+        // Hand the signing data to `ExternalClient` directly, leaving `None` in the config.
+        let signing_data = config.0.signing_data.take();
+        ExternalClient::new(config, signing_data)
+    }
+}
+
+/// Placeholder for a required `ExternalClientBuilder` service that has not been set yet.
+#[derive(Debug)]
+pub struct Missing;
+
+/// Change the identity validator used by a client configuration.
+///
+/// See [`ExternalClientBuilder::identity_provider`].
+pub type WithIdentityProvider<I, C> =
+    Config<I, <C as IntoConfig>::MlsRules, <C as IntoConfig>::CryptoProvider>;
+
+/// Change the MLS rules used by a client configuration.
+///
+/// See [`ExternalClientBuilder::mls_rules`].
+pub type WithMlsRules<Pr, C> =
+    Config<<C as IntoConfig>::IdentityProvider, Pr, <C as IntoConfig>::CryptoProvider>;
+
+/// Change the crypto provider used by a client configuration.
+///
+/// See [`ExternalClientBuilder::crypto_provider`].
+pub type WithCryptoProvider<Cp, C> =
+    Config<<C as IntoConfig>::IdentityProvider, <C as IntoConfig>::MlsRules, Cp>;
+
+/// `Config` keeping the identity provider, MLS rules and crypto provider of `C`.
+pub type IntoConfigOutput<C> = Config<
+    <C as IntoConfig>::IdentityProvider,
+    <C as IntoConfig>::MlsRules,
+    <C as IntoConfig>::CryptoProvider,
+>;
+
+impl<Ip, Pr, Cp> ExternalClientConfig for ConfigInner<Ip, Pr, Cp>
+where
+    Ip: IdentityProvider + Clone,
+    Pr: MlsRules + Clone,
+    Cp: CryptoProvider + Clone,
+{
+    type IdentityProvider = Ip;
+    type MlsRules = Pr;
+    type CryptoProvider = Cp;
+
+    // Every getter returns an owned copy of the stored setting or service.
+    fn supported_extensions(&self) -> Vec<ExtensionType> {
+        self.settings.extension_types.clone()
+    }
+
+    fn supported_custom_proposals(&self) -> Vec<ProposalType> {
+        self.settings.custom_proposal_types.clone()
+    }
+
+    fn supported_protocol_versions(&self) -> Vec<ProtocolVersion> {
+        self.settings.protocol_versions.clone()
+    }
+
+    fn identity_provider(&self) -> Self::IdentityProvider {
+        self.identity_provider.clone()
+    }
+
+    fn crypto_provider(&self) -> Self::CryptoProvider {
+        self.crypto_provider.clone()
+    }
+
+    fn mls_rules(&self) -> Self::MlsRules {
+        self.mls_rules.clone()
+    }
+
+    // Keys are registered per id via `ExternalClientBuilder::external_signing_key`.
+    fn external_signing_key(&self, external_key_id: &[u8]) -> Option<SignaturePublicKey> {
+        let keys = &self.settings.external_signing_keys;
+        keys.get(external_key_id).cloned()
+    }
+
+    fn max_epoch_jitter(&self) -> Option<u64> {
+        self.settings.max_epoch_jitter
+    }
+
+    fn cache_proposals(&self) -> bool {
+        self.settings.cache_proposals
+    }
+}
+
+impl<Ip, Pr, Cp> Sealed for Config<Ip, Pr, Cp> {}
+
+impl<Ip, Pr, Cp> MlsConfig for Config<Ip, Pr, Cp>
+where
+    Ip: IdentityProvider + Clone,
+    Pr: MlsRules + Clone,
+    Cp: CryptoProvider + Clone,
+{
+    type Output = ConfigInner<Ip, Pr, Cp>;
+
+    fn get(&self) -> &ConfigInner<Ip, Pr, Cp> {
+        &self.0
+    }
+}
+
+/// Helper trait that lets consuming crates write an external client type as
+/// `ExternalClient<impl MlsConfig>`.
+///
+/// It is not meant to be implemented by consuming crates (see the `Sealed` supertrait).
+/// `T: MlsConfig` implies `T: ExternalClientConfig`.
+pub trait MlsConfig: Send + Sync + Clone + Sealed {
+    #[doc(hidden)]
+    type Output: ExternalClientConfig;
+
+    #[doc(hidden)]
+    fn get(&self) -> &Self::Output;
+}
+
+/// Every `MlsConfig` is also an `ExternalClientConfig`: each method delegates to the
+/// configuration returned by `MlsConfig::get`.
+impl<T: MlsConfig> ExternalClientConfig for T {
+    type IdentityProvider = <T::Output as ExternalClientConfig>::IdentityProvider;
+    type MlsRules = <T::Output as ExternalClientConfig>::MlsRules;
+    type CryptoProvider = <T::Output as ExternalClientConfig>::CryptoProvider;
+
+    fn supported_extensions(&self) -> Vec<ExtensionType> {
+        MlsConfig::get(self).supported_extensions()
+    }
+
+    fn supported_custom_proposals(&self) -> Vec<ProposalType> {
+        MlsConfig::get(self).supported_custom_proposals()
+    }
+
+    fn supported_protocol_versions(&self) -> Vec<ProtocolVersion> {
+        MlsConfig::get(self).supported_protocol_versions()
+    }
+
+    fn identity_provider(&self) -> Self::IdentityProvider {
+        MlsConfig::get(self).identity_provider()
+    }
+
+    fn crypto_provider(&self) -> Self::CryptoProvider {
+        MlsConfig::get(self).crypto_provider()
+    }
+
+    fn external_signing_key(&self, external_key_id: &[u8]) -> Option<SignaturePublicKey> {
+        MlsConfig::get(self).external_signing_key(external_key_id)
+    }
+
+    fn mls_rules(&self) -> Self::MlsRules {
+        MlsConfig::get(self).mls_rules()
+    }
+
+    fn cache_proposals(&self) -> bool {
+        MlsConfig::get(self).cache_proposals()
+    }
+
+    fn max_epoch_jitter(&self) -> Option<u64> {
+        MlsConfig::get(self).max_epoch_jitter()
+    }
+
+    fn capabilities(&self) -> Capabilities {
+        MlsConfig::get(self).capabilities()
+    }
+
+    fn version_supported(&self, version: ProtocolVersion) -> bool {
+        MlsConfig::get(self).version_supported(version)
+    }
+
+    fn supported_credentials(&self) -> Vec<CredentialType> {
+        MlsConfig::get(self).supported_credentials()
+    }
+}
+
+/// Plain settings held by an external client configuration.
+#[derive(Clone)]
+pub(crate) struct Settings {
+    /// Protocol versions reported by `supported_protocol_versions`.
+    pub(crate) protocol_versions: Vec<ProtocolVersion>,
+    /// Extension types reported by `supported_extensions`.
+    pub(crate) extension_types: Vec<ExtensionType>,
+    /// Custom proposal types reported by `supported_custom_proposals`.
+    pub(crate) custom_proposal_types: Vec<ProposalType>,
+    /// Public signing keys looked up by `external_signing_key`, keyed by key id.
+    pub(crate) external_signing_keys: HashMap<Vec<u8>, SignaturePublicKey>,
+    /// Value reported by `max_epoch_jitter`.
+    pub(crate) max_epoch_jitter: Option<u64>,
+    /// Value reported by `cache_proposals`.
+    pub(crate) cache_proposals: bool,
+}
+
+impl Debug for Settings {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Print key ids as pretty bytes instead of the default `Vec<u8>` rendering.
+        let signing_keys = mls_rs_core::debug::pretty_with(|out| {
+            let entries = self
+                .external_signing_keys
+                .iter()
+                .map(|(key_id, key)| (mls_rs_core::debug::pretty_bytes(key_id), key));
+
+            out.debug_map().entries(entries).finish()
+        });
+
+        let mut builder = f.debug_struct("Settings");
+        builder.field("extension_types", &self.extension_types);
+        builder.field("custom_proposal_types", &self.custom_proposal_types);
+        builder.field("protocol_versions", &self.protocol_versions);
+        builder.field("external_signing_keys", &signing_keys);
+        builder.field("max_epoch_jitter", &self.max_epoch_jitter);
+        builder.field("cache_proposals", &self.cache_proposals);
+        builder.finish()
+    }
+}
+
+impl Default for Settings {
+    /// Empty capability lists, no external signing keys, no epoch jitter bound,
+    /// and proposal caching enabled.
+    fn default() -> Self {
+        Self {
+            protocol_versions: Vec::new(),
+            extension_types: Vec::new(),
+            custom_proposal_types: Vec::new(),
+            external_signing_keys: Default::default(),
+            max_epoch_jitter: None,
+            cache_proposals: true,
+        }
+    }
+}
+
+/// Items that appear in public signatures (so they must be `pub`) but live in a private
+/// module so they cannot be named from outside this crate.
+mod private {
+    use mls_rs_core::crypto::SignatureSecretKey;
+    use mls_rs_core::identity::SigningIdentity;
+
+    use super::{IntoConfigOutput, Settings};
+
+    /// Newtype handle around `ConfigInner`.
+    #[derive(Clone, Debug)]
+    pub struct Config<Ip, Pr, Cp>(pub(crate) ConfigInner<Ip, Pr, Cp>);
+
+    /// The settings, components and optional signing data backing a `Config`.
+    #[derive(Clone, Debug)]
+    pub struct ConfigInner<Ip, Pr, Cp> {
+        pub(crate) settings: Settings,
+        pub(crate) identity_provider: Ip,
+        pub(crate) mls_rules: Pr,
+        pub(crate) crypto_provider: Cp,
+        pub(crate) signing_data: Option<(SignatureSecretKey, SigningIdentity)>,
+    }
+
+    /// Conversion into a fully-typed `Config`.
+    pub trait IntoConfig {
+        type IdentityProvider;
+        type MlsRules;
+        type CryptoProvider;
+
+        fn into_config(self) -> IntoConfigOutput<Self>;
+    }
+
+    /// A `Config` already has its final shape, so converting it is the identity.
+    impl<Ip, Pr, Cp> IntoConfig for Config<Ip, Pr, Cp> {
+        type IdentityProvider = Ip;
+        type MlsRules = Pr;
+        type CryptoProvider = Cp;
+
+        fn into_config(self) -> Self {
+            self
+        }
+    }
+}
+
+use mls_rs_core::{
+ crypto::SignatureSecretKey,
+ identity::{IdentityProvider, SigningIdentity},
+};
+use private::{Config, ConfigInner, IntoConfig};
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use crate::{
+        cipher_suite::CipherSuite, crypto::test_utils::TestCryptoProvider,
+        identity::basic::BasicIdentityProvider,
+    };
+
+    use super::{
+        ExternalBaseConfig, ExternalClientBuilder, WithCryptoProvider, WithIdentityProvider,
+    };
+
+    /// Configuration produced by the test builder: basic identities on the test crypto provider.
+    pub type TestExternalClientConfig = WithIdentityProvider<
+        BasicIdentityProvider,
+        WithCryptoProvider<TestCryptoProvider, ExternalBaseConfig>,
+    >;
+
+    /// External client builder used by tests.
+    pub type TestExternalClientBuilder = ExternalClientBuilder<TestExternalClientConfig>;
+
+    /// Common construction path: install `crypto_provider`, then a basic identity provider.
+    fn builder_with(crypto_provider: TestCryptoProvider) -> TestExternalClientBuilder {
+        ExternalClientBuilder::new()
+            .crypto_provider(crypto_provider)
+            .identity_provider(BasicIdentityProvider::new())
+    }
+
+    impl TestExternalClientBuilder {
+        /// Builder backed by the default test crypto provider.
+        pub fn new_for_test() -> Self {
+            builder_with(TestCryptoProvider::default())
+        }
+
+        /// Builder whose crypto provider enables every supported cipher suite except
+        /// `cipher_suite`.
+        pub fn new_for_test_disabling_cipher_suite(cipher_suite: CipherSuite) -> Self {
+            let enabled = TestCryptoProvider::all_supported_cipher_suites()
+                .into_iter()
+                .filter(|candidate| *candidate != cipher_suite)
+                .collect();
+
+            builder_with(TestCryptoProvider::with_enabled_cipher_suites(enabled))
+        }
+    }
+}
diff --git a/src/external_client/config.rs b/src/external_client/config.rs
new file mode 100644
index 0000000..649be99
--- /dev/null
+++ b/src/external_client/config.rs
@@ -0,0 +1,54 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs_core::identity::IdentityProvider;
+
+use crate::{
+ crypto::SignaturePublicKey,
+ extension::ExtensionType,
+ group::{mls_rules::MlsRules, proposal::ProposalType},
+ identity::CredentialType,
+ protocol_version::ProtocolVersion,
+ tree_kem::Capabilities,
+ CryptoProvider,
+};
+
+/// Configuration consumed by an external client and the external groups it observes.
+///
+/// Implementors provide the pluggable components (identity provider, MLS rules, crypto
+/// provider) and the advertised capabilities; the provided methods derive the rest.
+pub trait ExternalClientConfig: Send + Sync + Clone {
+    /// Identity provider type used for member validation and identity lookups.
+    type IdentityProvider: IdentityProvider + Clone;
+    /// MLS rules type applied when processing messages.
+    type MlsRules: MlsRules + Clone;
+    /// Crypto provider type from which cipher suite providers are obtained.
+    type CryptoProvider: CryptoProvider;
+
+    /// Extension types this client supports.
+    fn supported_extensions(&self) -> Vec<ExtensionType>;
+    /// Custom proposal types this client supports.
+    fn supported_custom_proposals(&self) -> Vec<ProposalType>;
+    /// Protocol versions this client supports.
+    fn supported_protocol_versions(&self) -> Vec<ProtocolVersion>;
+    /// Identity provider instance to use.
+    fn identity_provider(&self) -> Self::IdentityProvider;
+    /// Crypto provider instance to use.
+    fn crypto_provider(&self) -> Self::CryptoProvider;
+    /// Public signing key registered under `external_key_id`, if any.
+    fn external_signing_key(&self, external_key_id: &[u8]) -> Option<SignaturePublicKey>;
+
+    /// MLS rules instance to use.
+    fn mls_rules(&self) -> Self::MlsRules;
+
+    /// Whether proposals are cached while processing incoming messages.
+    fn cache_proposals(&self) -> bool;
+
+    /// Optional epoch distance used to compute the oldest epoch still available for
+    /// message processing. Defaults to `None` (no bound).
+    fn max_epoch_jitter(&self) -> Option<u64> {
+        None
+    }
+
+    /// Capabilities built from the supported protocol versions, the crypto provider's
+    /// cipher suites, extensions, custom proposals and credential types.
+    fn capabilities(&self) -> Capabilities {
+        let protocol_versions = self.supported_protocol_versions();
+        let cipher_suites = self.crypto_provider().supported_cipher_suites();
+        let extensions = self.supported_extensions();
+        let proposals = self.supported_custom_proposals();
+        let credentials = self.supported_credentials();
+
+        Capabilities {
+            protocol_versions,
+            cipher_suites,
+            extensions,
+            proposals,
+            credentials,
+        }
+    }
+
+    /// `true` when `version` is listed by `supported_protocol_versions`.
+    fn version_supported(&self, version: ProtocolVersion) -> bool {
+        self.supported_protocol_versions()
+            .iter()
+            .any(|supported| *supported == version)
+    }
+
+    /// Credential types supported by the configured identity provider.
+    fn supported_credentials(&self) -> Vec<CredentialType> {
+        self.identity_provider().supported_types()
+    }
+}
diff --git a/src/external_client/group.rs b/src/external_client/group.rs
new file mode 100644
index 0000000..8939948
--- /dev/null
+++ b/src/external_client/group.rs
@@ -0,0 +1,1354 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{
+ crypto::SignatureSecretKey, error::IntoAnyError, extension::ExtensionList, group::Member,
+ identity::IdentityProvider,
+};
+
+use crate::{
+ cipher_suite::CipherSuite,
+ client::MlsError,
+ external_client::ExternalClientConfig,
+ group::{
+ cipher_suite_provider,
+ confirmation_tag::ConfirmationTag,
+ framing::PublicMessage,
+ member_from_leaf_node,
+ message_processor::{
+ ApplicationMessageDescription, CommitMessageDescription, EventOrContent,
+ MessageProcessor, ProposalMessageDescription, ProvisionalState,
+ },
+ snapshot::RawGroupState,
+ state::GroupState,
+ transcript_hash::InterimTranscriptHash,
+ validate_group_info_joiner, ContentType, ExportedTree, GroupContext, GroupInfo, Roster,
+ Welcome,
+ },
+ identity::SigningIdentity,
+ protocol_version::ProtocolVersion,
+ psk::AlwaysFoundPskStorage,
+ tree_kem::{node::LeafIndex, path_secret::PathSecret, TreeKemPrivate},
+ CryptoProvider, KeyPackage, MlsMessage,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::{
+ group::{
+ framing::{Content, MlsMessagePayload},
+ message_processor::CachedProposal,
+ message_signature::AuthenticatedContent,
+ proposal::Proposal,
+ proposal_ref::ProposalRef,
+ Sender,
+ },
+ WireFormat,
+};
+
+#[cfg(all(feature = "by_ref_proposal", feature = "custom_proposal"))]
+use crate::group::proposal::CustomProposal;
+
+#[cfg(feature = "by_ref_proposal")]
+use mls_rs_core::{crypto::CipherSuiteProvider, psk::ExternalPskId};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::{
+ extension::ExternalSendersExt,
+ group::proposal::{AddProposal, ReInitProposal, RemoveProposal},
+};
+
+#[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+use crate::{
+ group::proposal::PreSharedKeyProposal,
+ psk::{
+ JustPreSharedKeyID, PreSharedKeyID, PskGroupId, PskNonce, ResumptionPSKUsage, ResumptionPsk,
+ },
+};
+
+#[cfg(feature = "private_message")]
+use crate::group::framing::PrivateMessage;
+
+use alloc::boxed::Box;
+
+/// Outcome of [`ExternalGroup::process_incoming_message`].
+#[derive(Clone, Debug)]
+// Variant payloads differ a lot in size; presumably boxing them was judged not worthwhile.
+#[allow(clippy::large_enum_variant)]
+pub enum ExternalReceivedMessage {
+    /// A commit was processed and the observed group state advanced.
+    Commit(CommitMessageDescription),
+    /// A proposal was received, described together with its unique identifier.
+    Proposal(ProposalMessageDescription),
+    /// An encrypted message the observer cannot read; only its content type is exposed.
+    Ciphertext(ContentType),
+    /// A `GroupInfo` that passed validation.
+    GroupInfo(GroupInfo),
+    /// A `Welcome` that passed validation (its contents are not kept).
+    Welcome,
+    /// A key package that passed validation.
+    KeyPackage(KeyPackage),
+}
+
+/// Observer-side handle to an MLS group.
+///
+/// It follows plaintext control messages (proposals and commits) to keep a public view of
+/// the group state up to date; path secrets delivered with commits are discarded.
+#[derive(Clone)]
+pub struct ExternalGroup<C: ExternalClientConfig> {
+    /// Configuration the group was created with.
+    pub(crate) config: C,
+    /// Provider for the group's cipher suite, obtained from the config's crypto provider.
+    pub(crate) cipher_suite_provider: <C::CryptoProvider as CryptoProvider>::CipherSuiteProvider,
+    /// Public group state (context, tree, transcript hash, confirmation tag, proposals).
+    pub(crate) state: GroupState,
+    /// Secret key and identity used to sign external proposals, if configured.
+    pub(crate) signing_data: Option<(SignatureSecretKey, SigningIdentity)>,
+}
+
+impl<C: ExternalClientConfig + Clone> ExternalGroup<C> {
+    /// Start observing a group from a `GroupInfo` message.
+    ///
+    /// The protocol version must be supported by `config`, and `group_info` must carry a
+    /// `GroupInfo`. The group info is validated (using `tree_data` when supplied) before
+    /// the public group state is initialized from it.
+    ///
+    /// # Errors
+    ///
+    /// * `MlsError::UnsupportedProtocolVersion` if the message version is not supported.
+    /// * `MlsError::UnexpectedMessageType` if the message is not a `GroupInfo`.
+    /// * Any error from cipher suite lookup, group info validation or transcript hashing.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn join(
+        config: C,
+        signing_data: Option<(SignatureSecretKey, SigningIdentity)>,
+        group_info: MlsMessage,
+        tree_data: Option<ExportedTree<'_>>,
+    ) -> Result<Self, MlsError> {
+        let protocol_version = group_info.version;
+
+        if !config.version_supported(protocol_version) {
+            return Err(MlsError::UnsupportedProtocolVersion(protocol_version));
+        }
+
+        let group_info = group_info
+            .into_group_info()
+            .ok_or(MlsError::UnexpectedMessageType)?;
+
+        let cipher_suite_provider = cipher_suite_provider(
+            config.crypto_provider(),
+            group_info.group_context.cipher_suite,
+        )?;
+
+        let public_tree = validate_group_info_joiner(
+            protocol_version,
+            &group_info,
+            tree_data,
+            &config.identity_provider(),
+            &cipher_suite_provider,
+        )
+        .await?;
+
+        let interim_transcript_hash = InterimTranscriptHash::create(
+            &cipher_suite_provider,
+            &group_info.group_context.confirmed_transcript_hash,
+            &group_info.confirmation_tag,
+        )
+        .await?;
+
+        Ok(Self {
+            config,
+            signing_data,
+            state: GroupState::new(
+                group_info.group_context,
+                public_tree,
+                interim_transcript_hash,
+                group_info.confirmation_tag,
+            ),
+            cipher_suite_provider,
+        })
+    }
+
+    /// Process a message that was sent to the group.
+    ///
+    /// * Proposals will be stored in the group state and processed by the
+    ///   same rules as a standard group.
+    ///
+    /// * Commits will result in the same outcome as a standard group.
+    ///   However, the integrity of the resulting group state can only be partially
+    ///   verified, since the external group does not have access to the group
+    ///   secrets required to do a complete check.
+    ///
+    /// * Application messages are always encrypted so they result in a no-op
+    ///   that returns [ExternalReceivedMessage::Ciphertext]
+    ///
+    /// # Warning
+    ///
+    /// Processing an encrypted commit or proposal message has the same result
+    /// as processing an encrypted application message. Proper tracking of
+    /// the group state requires that all proposal and commit messages are
+    /// readable.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn process_incoming_message(
+        &mut self,
+        message: MlsMessage,
+    ) -> Result<ExternalReceivedMessage, MlsError> {
+        MessageProcessor::process_incoming_message(
+            self,
+            message,
+            #[cfg(feature = "by_ref_proposal")]
+            self.config.cache_proposals(),
+        )
+        .await
+    }
+
+    /// Replay a proposal message into the group skipping all validation steps.
+    ///
+    /// # Errors
+    ///
+    /// `MlsError::UnexpectedMessageType` if `message` is not a plaintext message carrying
+    /// a proposal.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn insert_proposal_from_message(
+        &mut self,
+        message: MlsMessage,
+    ) -> Result<(), MlsError> {
+        let ptxt = match message.payload {
+            MlsMessagePayload::Plain(p) => Ok(p),
+            _ => Err(MlsError::UnexpectedMessageType),
+        }?;
+
+        let auth_content: AuthenticatedContent = ptxt.into();
+
+        // The reference is computed over the full authenticated content, before it is
+        // taken apart below.
+        let proposal_ref =
+            ProposalRef::from_content(&self.cipher_suite_provider, &auth_content).await?;
+
+        let sender = auth_content.content.sender;
+
+        let proposal = match auth_content.content.content {
+            Content::Proposal(p) => Ok(*p),
+            _ => Err(MlsError::UnexpectedMessageType),
+        }?;
+
+        self.group_state_mut()
+            .proposals
+            .insert(proposal_ref, proposal, sender);
+
+        Ok(())
+    }
+
+    /// Force insert a proposal directly into the internal state of the group
+    /// with no validation.
+    #[cfg(feature = "by_ref_proposal")]
+    pub fn insert_proposal(&mut self, proposal: CachedProposal) {
+        self.group_state_mut().proposals.insert(
+            proposal.proposal_ref,
+            proposal.proposal,
+            proposal.sender,
+        )
+    }
+
+    /// Create an external proposal to request that a group add a new member
+    ///
+    /// # Warning
+    ///
+    /// In order for the proposal generated by this function to be successfully
+    /// committed, the group needs to have `signing_identity` as an entry
+    /// within an [ExternalSendersExt](crate::extension::built_in::ExternalSendersExt)
+    /// as part of its group context extensions.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_add(
+        &mut self,
+        key_package: MlsMessage,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let key_package = key_package
+            .into_key_package()
+            .ok_or(MlsError::UnexpectedMessageType)?;
+
+        self.propose(
+            Proposal::Add(Box::new(AddProposal { key_package })),
+            authenticated_data,
+        )
+        .await
+    }
+
+    /// Create an external proposal to request that a group remove an existing member
+    ///
+    /// # Warning
+    ///
+    /// In order for the proposal generated by this function to be successfully
+    /// committed, the group needs to have `signing_identity` as an entry
+    /// within an [ExternalSendersExt](crate::extension::built_in::ExternalSendersExt)
+    /// as part of its group context extensions.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_remove(
+        &mut self,
+        index: u32,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let to_remove = LeafIndex(index);
+
+        // Verify that this leaf is actually in the tree
+        self.group_state().public_tree.get_leaf_node(to_remove)?;
+
+        self.propose(
+            Proposal::Remove(RemoveProposal { to_remove }),
+            authenticated_data,
+        )
+        .await
+    }
+
+    /// Create an external proposal to request that a group inserts an external
+    /// pre shared key into its state.
+    ///
+    /// # Warning
+    ///
+    /// In order for the proposal generated by this function to be successfully
+    /// committed, the group needs to have `signing_identity` as an entry
+    /// within an [ExternalSendersExt](crate::extension::built_in::ExternalSendersExt)
+    /// as part of its group context extensions.
+    #[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_external_psk(
+        &mut self,
+        psk: ExternalPskId,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let proposal = self.psk_proposal(JustPreSharedKeyID::External(psk))?;
+        self.propose(proposal, authenticated_data).await
+    }
+
+    /// Create an external proposal to request that a group adds a pre shared key
+    /// from a previous epoch to the current group state.
+    ///
+    /// # Warning
+    ///
+    /// In order for the proposal generated by this function to be successfully
+    /// committed, the group needs to have `signing_identity` as an entry
+    /// within an [ExternalSendersExt](crate::extension::built_in::ExternalSendersExt)
+    /// as part of its group context extensions.
+    #[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_resumption_psk(
+        &mut self,
+        psk_epoch: u64,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let key_id = ResumptionPsk {
+            psk_epoch,
+            usage: ResumptionPSKUsage::Application,
+            psk_group_id: PskGroupId(self.group_context().group_id().to_vec()),
+        };
+
+        let proposal = self.psk_proposal(JustPreSharedKeyID::Resumption(key_id))?;
+        self.propose(proposal, authenticated_data).await
+    }
+
+    /// Wrap `key_id` in a PSK proposal with a freshly generated random nonce.
+    ///
+    /// # Errors
+    ///
+    /// `MlsError::CryptoProviderError` if nonce generation fails.
+    #[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+    fn psk_proposal(&self, key_id: JustPreSharedKeyID) -> Result<Proposal, MlsError> {
+        Ok(Proposal::Psk(PreSharedKeyProposal {
+            psk: PreSharedKeyID {
+                key_id,
+                psk_nonce: PskNonce::random(&self.cipher_suite_provider)
+                    .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?,
+            },
+        }))
+    }
+
+    /// Create an external proposal to request that a group sets extensions stored in the group
+    /// state.
+    ///
+    /// # Warning
+    ///
+    /// In order for the proposal generated by this function to be successfully
+    /// committed, the group needs to have `signing_identity` as an entry
+    /// within an [ExternalSendersExt](crate::extension::built_in::ExternalSendersExt)
+    /// as part of its group context extensions.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_group_context_extensions(
+        &mut self,
+        extensions: ExtensionList,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let proposal = Proposal::GroupContextExtensions(extensions);
+        self.propose(proposal, authenticated_data).await
+    }
+
+    /// Create an external proposal to request that a group is reinitialized.
+    ///
+    /// When `group_id` is `None`, a random id of `kdf_extract_size()` bytes is generated.
+    ///
+    /// # Warning
+    ///
+    /// In order for the proposal generated by this function to be successfully
+    /// committed, the group needs to have `signing_identity` as an entry
+    /// within an [ExternalSendersExt](crate::extension::built_in::ExternalSendersExt)
+    /// as part of its group context extensions.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_reinit(
+        &mut self,
+        group_id: Option<Vec<u8>>,
+        version: ProtocolVersion,
+        cipher_suite: CipherSuite,
+        extensions: ExtensionList,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let group_id = match group_id {
+            Some(group_id) => group_id,
+            None => self
+                .cipher_suite_provider
+                .random_bytes_vec(self.cipher_suite_provider.kdf_extract_size())
+                .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?,
+        };
+
+        let proposal = Proposal::ReInit(ReInitProposal {
+            group_id,
+            version,
+            cipher_suite,
+            extensions,
+        });
+
+        self.propose(proposal, authenticated_data).await
+    }
+
+    /// Create a custom proposal message.
+    ///
+    /// # Warning
+    ///
+    /// In order for the proposal generated by this function to be successfully
+    /// committed, the group needs to have `signing_identity` as an entry
+    /// within an [ExternalSendersExt](crate::extension::built_in::ExternalSendersExt)
+    /// as part of its group context extensions.
+    #[cfg(all(feature = "by_ref_proposal", feature = "custom_proposal"))]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_custom(
+        &mut self,
+        proposal: CustomProposal,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        self.propose(Proposal::Custom(proposal), authenticated_data)
+            .await
+    }
+
+    /// Sign `proposal` as an external sender, cache it in the group state and return it
+    /// wrapped in a `PublicMessage` with no membership tag.
+    ///
+    /// # Errors
+    ///
+    /// * `MlsError::SignerNotFound` if this group has no signing data.
+    /// * `MlsError::ExternalProposalsDisabled` if the group context has no
+    ///   `ExternalSendersExt`.
+    /// * `MlsError::InvalidExternalSigningIdentity` if the signing identity is not one of the
+    ///   allowed external senders.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn propose(
+        &mut self,
+        proposal: Proposal,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let (signer, signing_identity) =
+            self.signing_data.as_ref().ok_or(MlsError::SignerNotFound)?;
+
+        let external_senders_ext = self
+            .state
+            .context
+            .extensions
+            .get_as::<ExternalSendersExt>()?
+            .ok_or(MlsError::ExternalProposalsDisabled)?;
+
+        // An external sender is identified by its position in the allowed-senders list.
+        let sender_index = external_senders_ext
+            .allowed_senders
+            .iter()
+            .position(|allowed_signer| signing_identity == allowed_signer)
+            .ok_or(MlsError::InvalidExternalSigningIdentity)?;
+
+        let sender = Sender::External(sender_index as u32);
+
+        let auth_content = AuthenticatedContent::new_signed(
+            &self.cipher_suite_provider,
+            &self.state.context,
+            sender,
+            Content::Proposal(Box::new(proposal.clone())),
+            signer,
+            WireFormat::PublicMessage,
+            authenticated_data,
+        )
+        .await?;
+
+        self.state.proposals.insert(
+            ProposalRef::from_content(&self.cipher_suite_provider, &auth_content).await?,
+            proposal,
+            sender,
+        );
+
+        let plaintext = PublicMessage {
+            content: auth_content.content,
+            auth: auth_content.auth,
+            membership_tag: None,
+        };
+
+        Ok(MlsMessage::new(
+            self.group_context().version(),
+            MlsMessagePayload::Plain(plaintext),
+        ))
+    }
+
+    /// Delete all sent and received proposals cached for commit.
+    #[cfg(feature = "by_ref_proposal")]
+    pub fn clear_proposal_cache(&mut self) {
+        self.state.proposals.clear()
+    }
+
+    /// Borrow the public group state tracked by this observer.
+    #[inline(always)]
+    pub(crate) fn group_state(&self) -> &GroupState {
+        &self.state
+    }
+
+    /// Get the current group context summarizing various information about the group.
+    #[inline(always)]
+    pub fn group_context(&self) -> &GroupContext {
+        &self.group_state().context
+    }
+
+    /// Export the current ratchet tree used within the group, as MLS-encoded tree nodes.
+    pub fn export_tree(&self) -> Result<Vec<u8>, MlsError> {
+        self.group_state()
+            .public_tree
+            .nodes
+            .mls_encode_to_vec()
+            .map_err(Into::into)
+    }
+
+    /// Get the current roster of the group.
+    #[inline(always)]
+    pub fn roster(&self) -> Roster {
+        self.group_state().public_tree.roster()
+    }
+
+    /// Get the
+    /// [transcript hash](https://www.rfc-editor.org/rfc/rfc9420.html#name-transcript-hashes)
+    /// for the current epoch that the group is in.
+    #[inline(always)]
+    pub fn transcript_hash(&self) -> &Vec<u8> {
+        &self.group_state().context.confirmed_transcript_hash
+    }
+
+    /// Get the
+    /// [tree hash](https://www.rfc-editor.org/rfc/rfc9420.html#name-tree-hashes)
+    /// for the current epoch that the group is in.
+    #[inline(always)]
+    pub fn tree_hash(&self) -> &[u8] {
+        &self.group_state().context.tree_hash
+    }
+
+    /// Find a member based on their identity.
+    ///
+    /// Identities are matched based on the
+    /// [IdentityProvider](crate::IdentityProvider)
+    /// that this group was configured with.
+    ///
+    /// # Errors
+    ///
+    /// * `MlsError::IdentityProviderError` if the identity cannot be derived.
+    /// * `MlsError::MemberNotFound` if no leaf matches the identity.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn get_member_with_identity(
+        &self,
+        identity_id: &SigningIdentity,
+    ) -> Result<Member, MlsError> {
+        let identity = self
+            .identity_provider()
+            .identity(identity_id, self.group_context().extensions())
+            .await
+            .map_err(|error| MlsError::IdentityProviderError(error.into_any_error()))?;
+
+        let tree = &self.group_state().public_tree;
+
+        // With `tree_index` the tree keeps an identity index; otherwise leaves are
+        // searched with the help of the identity provider.
+        #[cfg(feature = "tree_index")]
+        let index = tree.get_leaf_node_with_identity(&identity);
+
+        #[cfg(not(feature = "tree_index"))]
+        let index = tree
+            .get_leaf_node_with_identity(
+                &identity,
+                &self.identity_provider(),
+                self.group_context().extensions(),
+            )
+            .await?;
+
+        let index = index.ok_or(MlsError::MemberNotFound)?;
+        let node = self.group_state().public_tree.get_leaf_node(index)?;
+
+        Ok(member_from_leaf_node(node, index))
+    }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
+#[cfg_attr(
+    all(not(target_arch = "wasm32"), mls_build_async),
+    maybe_async::must_be_async
+)]
+impl<C> MessageProcessor for ExternalGroup<C>
+where
+    C: ExternalClientConfig + Clone,
+{
+    type MlsRules = C::MlsRules;
+    type IdentityProvider = C::IdentityProvider;
+    type PreSharedKeyStorage = AlwaysFoundPskStorage;
+    type OutputType = ExternalReceivedMessage;
+    type CipherSuiteProvider = <C::CryptoProvider as CryptoProvider>::CipherSuiteProvider;
+
+    /// An external group occupies no leaf in the tree.
+    #[cfg(feature = "private_message")]
+    fn self_index(&self) -> Option<LeafIndex> {
+        None
+    }
+
+    fn mls_rules(&self) -> Self::MlsRules {
+        self.config.mls_rules()
+    }
+
+    /// Verify a `PublicMessage` against the tracked public state.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn verify_plaintext_authentication(
+        &self,
+        message: PublicMessage,
+    ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+        // NOTE(review): the two `None`s presumably stand for member-only inputs that an
+        // external observer does not have — confirm against `message_verifier`.
+        let auth_content = crate::group::message_verifier::verify_plaintext_authentication(
+            &self.cipher_suite_provider,
+            message,
+            None,
+            None,
+            &self.state,
+        )
+        .await?;
+
+        Ok(EventOrContent::Content(auth_content))
+    }
+
+    /// Encrypted content cannot be opened by an observer; report only its content type.
+    #[cfg(feature = "private_message")]
+    async fn process_ciphertext(
+        &mut self,
+        cipher_text: &PrivateMessage,
+    ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+        Ok(EventOrContent::Event(ExternalReceivedMessage::Ciphertext(
+            cipher_text.content_type,
+        )))
+    }
+
+    /// Adopt the provisional public state produced by a commit.
+    ///
+    /// Any private secrets are ignored, and cached proposals are cleared when
+    /// `by_ref_proposal` is enabled.
+    async fn update_key_schedule(
+        &mut self,
+        _secrets: Option<(TreeKemPrivate, PathSecret)>,
+        interim_transcript_hash: InterimTranscriptHash,
+        confirmation_tag: &ConfirmationTag,
+        provisional_public_state: ProvisionalState,
+    ) -> Result<(), MlsError> {
+        self.state.context = provisional_public_state.group_context;
+        #[cfg(feature = "by_ref_proposal")]
+        self.state.proposals.clear();
+        self.state.interim_transcript_hash = interim_transcript_hash;
+        self.state.public_tree = provisional_public_state.public_tree;
+        self.state.confirmation_tag = confirmation_tag.clone();
+
+        Ok(())
+    }
+
+    fn identity_provider(&self) -> Self::IdentityProvider {
+        self.config.identity_provider()
+    }
+
+    fn psk_storage(&self) -> Self::PreSharedKeyStorage {
+        AlwaysFoundPskStorage
+    }
+
+    fn group_state(&self) -> &GroupState {
+        &self.state
+    }
+
+    fn group_state_mut(&mut self) -> &mut GroupState {
+        &mut self.state
+    }
+
+    fn can_continue_processing(&self, _provisional_state: &ProvisionalState) -> bool {
+        true
+    }
+
+    /// Oldest epoch still available: the current epoch minus the configured jitter,
+    /// clamped at 0.
+    #[cfg(feature = "private_message")]
+    fn min_epoch_available(&self) -> Option<u64> {
+        // Saturate instead of subtracting directly: a jitter larger than the current
+        // epoch would otherwise panic in debug builds and wrap in release builds.
+        self.config
+            .max_epoch_jitter()
+            .map(|jitter| self.state.context.epoch.saturating_sub(jitter))
+    }
+
+    fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider {
+        &self.cipher_suite_provider
+    }
+}
+
+/// Serializable snapshot of an [ExternalGroup](ExternalGroup) state.
+///
+/// Produced by `ExternalGroup::snapshot`; converted to and from bytes with the MLS codec
+/// via `to_bytes` / `from_bytes`.
+///
+/// NOTE(review): the codec traits are derived, so field order presumably defines the
+/// serialized layout. Do not reorder fields without a format version bump.
+/// When `signing_data` is present the snapshot contains a secret signing key, so
+/// treat serialized snapshots as sensitive.
+#[derive(Debug, MlsEncode, MlsSize, MlsDecode, PartialEq, Clone)]
+pub struct ExternalSnapshot {
+ /// Snapshot format version; `ExternalGroup::snapshot` writes `1`.
+ version: u16,
+ /// Exported public group state.
+ state: RawGroupState,
+ /// Secret key and identity used to sign external proposals, if configured.
+ signing_data: Option<(SignatureSecretKey, SigningIdentity)>,
+}
+
+impl ExternalSnapshot {
+    /// Encode this snapshot with the MLS codec.
+    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
+        self.mls_encode_to_vec().map_err(MlsError::from)
+    }
+
+    /// Decode a snapshot from `bytes` with the MLS codec.
+    ///
+    /// Bytes left over after decoding are not checked here.
+    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
+        let mut reader = bytes;
+        Self::mls_decode(&mut reader).map_err(MlsError::from)
+    }
+}
+
+impl<C> ExternalGroup<C>
+where
+    C: ExternalClientConfig + Clone,
+{
+    /// Capture the current public group state, together with any signing data, as an
+    /// [`ExternalSnapshot`].
+    pub fn snapshot(&self) -> ExternalSnapshot {
+        let state = RawGroupState::export(self.group_state());
+        let signing_data = self.signing_data.clone();
+
+        ExternalSnapshot {
+            version: 1,
+            state,
+            signing_data,
+        }
+    }
+
+    /// Rebuild an external group from `snapshot` using `config`.
+    ///
+    /// The cipher suite provider is resolved from the snapshot's group context first.
+    /// The raw state is then imported, with the identity provider passed in when
+    /// `tree_index` is enabled.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn from_snapshot(
+        config: C,
+        snapshot: ExternalSnapshot,
+    ) -> Result<Self, MlsError> {
+        #[cfg(feature = "tree_index")]
+        let identity_provider = config.identity_provider();
+
+        let cipher_suite_provider = cipher_suite_provider(
+            config.crypto_provider(),
+            snapshot.state.context.cipher_suite,
+        )?;
+
+        let ExternalSnapshot {
+            state: raw_state,
+            signing_data,
+            ..
+        } = snapshot;
+
+        let state = raw_state
+            .import(
+                #[cfg(feature = "tree_index")]
+                &identity_provider,
+            )
+            .await?;
+
+        Ok(ExternalGroup {
+            config,
+            signing_data,
+            state,
+            cipher_suite_provider,
+        })
+    }
+}
+
+/// A processed commit surfaces as [`ExternalReceivedMessage::Commit`].
+impl From<CommitMessageDescription> for ExternalReceivedMessage {
+    fn from(description: CommitMessageDescription) -> Self {
+        Self::Commit(description)
+    }
+}
+
+/// An external group never yields application messages; this conversion always fails
+/// with `MlsError::UnencryptedApplicationMessage`.
+impl TryFrom<ApplicationMessageDescription> for ExternalReceivedMessage {
+    type Error = MlsError;
+
+    fn try_from(_description: ApplicationMessageDescription) -> Result<Self, Self::Error> {
+        Err(MlsError::UnencryptedApplicationMessage)
+    }
+}
+
+/// A received proposal surfaces as [`ExternalReceivedMessage::Proposal`].
+impl From<ProposalMessageDescription> for ExternalReceivedMessage {
+    fn from(description: ProposalMessageDescription) -> Self {
+        Self::Proposal(description)
+    }
+}
+
+/// A validated group info surfaces as [`ExternalReceivedMessage::GroupInfo`].
+impl From<GroupInfo> for ExternalReceivedMessage {
+    fn from(group_info: GroupInfo) -> Self {
+        Self::GroupInfo(group_info)
+    }
+}
+
+/// A validated welcome surfaces as [`ExternalReceivedMessage::Welcome`]; its payload is
+/// dropped.
+impl From<Welcome> for ExternalReceivedMessage {
+    fn from(_welcome: Welcome) -> Self {
+        Self::Welcome
+    }
+}
+
+/// A validated key package surfaces as [`ExternalReceivedMessage::KeyPackage`].
+impl From<KeyPackage> for ExternalReceivedMessage {
+    fn from(key_package: KeyPackage) -> Self {
+        Self::KeyPackage(key_package)
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use crate::{
+        external_client::tests_utils::{TestExternalClientBuilder, TestExternalClientConfig},
+        group::test_utils::TestGroup,
+    };
+
+    use super::ExternalGroup;
+
+    /// Observe `group` using the default test external-client configuration.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn make_external_group(
+        group: &TestGroup,
+    ) -> ExternalGroup<TestExternalClientConfig> {
+        let config = TestExternalClientBuilder::new_for_test().build_config();
+        make_external_group_with_config(group, config).await
+    }
+
+    /// Observe `group` with `config`, joining from a group info that allows external
+    /// commits. Panics on any failure.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn make_external_group_with_config(
+        group: &TestGroup,
+        config: TestExternalClientConfig,
+    ) -> ExternalGroup<TestExternalClientConfig> {
+        let group_info = group
+            .group
+            .group_info_message_allowing_ext_commit(true)
+            .await
+            .unwrap();
+
+        ExternalGroup::join(config, None, group_info, None)
+            .await
+            .unwrap()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::test_utils::make_external_group;
+ use crate::{
+ cipher_suite::CipherSuite,
+ client::{
+ test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ MlsError,
+ },
+ crypto::{test_utils::TestCryptoProvider, SignatureSecretKey},
+ extension::ExternalSendersExt,
+ external_client::{
+ group::test_utils::make_external_group_with_config,
+ tests_utils::{TestExternalClientBuilder, TestExternalClientConfig},
+ ExternalGroup, ExternalReceivedMessage, ExternalSnapshot,
+ },
+ group::{
+ framing::{Content, MlsMessagePayload},
+ proposal::{AddProposal, Proposal, ProposalOrRef},
+ proposal_ref::ProposalRef,
+ test_utils::{test_group, TestGroup},
+ ProposalMessageDescription,
+ },
+ identity::{test_utils::get_test_signing_identity, SigningIdentity},
+ key_package::test_utils::{test_key_package, test_key_package_message},
+ protocol_version::ProtocolVersion,
+ ExtensionList, MlsMessage,
+ };
+ use assert_matches::assert_matches;
+ use mls_rs_codec::{MlsDecode, MlsEncode};
+
+    /// Creates a test group and advances it by one empty commit, which is
+    /// merged locally before the group is returned.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn test_group_with_one_commit(v: ProtocolVersion, cs: CipherSuite) -> TestGroup {
+        let mut alice = test_group(v, cs).await;
+
+        // No proposals and no authenticated data: a bare epoch change.
+        alice.group.commit(vec![]).await.unwrap();
+        alice.process_pending_commit().await.unwrap();
+
+        alice
+    }
+
+    /// Creates a group containing the creator plus "bob", added in a single commit.
+    ///
+    /// With the `by_ref_proposal` feature, a present `ext_identity` is set in
+    /// that same commit as the only allowed sender of the group's
+    /// `ExternalSendersExt` group-context extension.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn test_group_two_members(
+        v: ProtocolVersion,
+        cs: CipherSuite,
+        #[cfg(feature = "by_ref_proposal")] ext_identity: Option<SigningIdentity>,
+    ) -> TestGroup {
+        let mut alice = test_group_with_one_commit(v, cs).await;
+        let bob = test_key_package_message(v, cs, "bob").await;
+
+        let builder = alice.group.commit_builder().add_member(bob).unwrap();
+
+        #[cfg(feature = "by_ref_proposal")]
+        let builder = match ext_identity {
+            Some(ext_signer) => {
+                let mut context_extensions = ExtensionList::new();
+
+                context_extensions
+                    .set_from(ExternalSendersExt {
+                        allowed_senders: vec![ext_signer],
+                    })
+                    .unwrap();
+
+                builder.set_group_context_ext(context_extensions).unwrap()
+            }
+            None => builder,
+        };
+
+        builder.build().await.unwrap();
+        alice.process_pending_commit().await.unwrap();
+
+        alice
+    }
+
+    /// An external group can be built for every protocol version / cipher suite pair.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_can_be_created() {
+        for v in ProtocolVersion::all() {
+            for cs in TestCryptoProvider::all_supported_cipher_suites() {
+                let group = test_group_with_one_commit(v, cs).await;
+                make_external_group(&group).await;
+            }
+        }
+    }
+
+    /// After processing a member's commit, the observer's state equals the member's.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_can_process_commit() {
+        let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let mut server = make_external_group(&alice).await;
+
+        let commit = alice.group.commit(vec![]).await.unwrap().commit_message;
+        alice.group.apply_pending_commit().await.unwrap();
+
+        server.process_incoming_message(commit).await.unwrap();
+
+        assert_eq!(alice.group.state, server.state);
+    }
+
+    /// The observer reports a by-reference Add proposal, then applies the commit
+    /// that references it and ends up in the member's state.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_can_process_proposals_by_reference() {
+        let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let mut server = make_external_group(&alice).await;
+
+        let key_package = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+        let add_proposal = Proposal::Add(Box::new(AddProposal { key_package }));
+
+        // First the observer sees the proposal on its own...
+        let proposal_message = alice.propose(add_proposal.clone()).await;
+        let received = server
+            .process_incoming_message(proposal_message)
+            .await
+            .unwrap();
+
+        assert_matches!(
+            received,
+            ExternalReceivedMessage::Proposal(ProposalMessageDescription { ref proposal, .. })
+                if proposal == &add_proposal
+        );
+
+        // ...then the commit that refers to it.
+        let commit = alice.group.commit(vec![]).await.unwrap().commit_message;
+        alice.group.apply_pending_commit().await.unwrap();
+
+        let commit_result = server.process_incoming_message(commit).await.unwrap();
+
+        #[cfg(feature = "state_update")]
+        assert_matches!(
+            commit_result,
+            ExternalReceivedMessage::Commit(commit_description)
+                if commit_description.state_update.roster_update.added().iter().any(|added| added.index == 1)
+        );
+
+        #[cfg(not(feature = "state_update"))]
+        assert_matches!(commit_result, ExternalReceivedMessage::Commit(_));
+
+        assert_eq!(alice.group.state, server.state);
+    }
+
+    /// The observer applies a commit that adds a member, then tracks two leaves
+    /// and matches the committer's state.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_can_process_commit_adding_member() {
+        let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let mut server = make_external_group(&alice).await;
+        let (_, commit) = alice.join("bob").await;
+
+        let processed = server.process_incoming_message(commit).await.unwrap();
+
+        // Assert on the whole message in each feature configuration (as the
+        // by-reference proposal test does), so no binding goes unused and
+        // triggers `unused_variables` when `state_update` is disabled.
+        #[cfg(feature = "state_update")]
+        assert_matches!(
+            processed,
+            ExternalReceivedMessage::Commit(commit_description)
+                if commit_description.state_update.roster_update.added().len() == 1
+        );
+
+        #[cfg(not(feature = "state_update"))]
+        assert_matches!(processed, ExternalReceivedMessage::Commit(_));
+
+        assert_eq!(server.state.public_tree.get_leaf_nodes().len(), 2);
+
+        assert_eq!(alice.group.state, server.state);
+    }
+
+    /// A commit stamped with a stale epoch is refused with `InvalidEpoch`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_rejects_commit_not_for_current_epoch() {
+        let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let mut server = make_external_group(&alice).await;
+
+        let mut commit = alice.group.commit(vec![]).await.unwrap().commit_message;
+
+        // Rewind the plaintext commit's epoch to 0.
+        match &mut commit.payload {
+            MlsMessagePayload::Plain(plain) => plain.content.epoch = 0,
+            _ => panic!("Unexpected non-plaintext data"),
+        }
+
+        let res = server.process_incoming_message(commit).await;
+
+        assert_matches!(res, Err(MlsError::InvalidEpoch));
+    }
+
+    /// A commit whose signature has been blanked out is refused with `InvalidSignature`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_can_reject_message_with_invalid_signature() {
+        let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let config = TestExternalClientBuilder::new_for_test().build_config();
+        let mut server = make_external_group_with_config(&alice, config).await;
+
+        let mut commit = alice.group.commit(vec![]).await.unwrap().commit_message;
+
+        if let MlsMessagePayload::Plain(plain) = &mut commit.payload {
+            plain.auth.signature = Vec::new().into();
+        } else {
+            panic!("Unexpected non-plaintext data");
+        }
+
+        let res = server.process_incoming_message(commit).await;
+
+        assert_matches!(res, Err(MlsError::InvalidSignature));
+    }
+
+    /// Application data sent in a plaintext frame is refused with
+    /// `UnencryptedApplicationMessage`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_rejects_unencrypted_application_message() {
+        let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let mut server = make_external_group(&alice).await;
+
+        let application_data = Content::Application(b"hello".to_vec().into());
+        let plaintext = alice.make_plaintext(application_data).await;
+
+        let res = server.process_incoming_message(plaintext).await;
+
+        assert_matches!(res, Err(MlsError::UnencryptedApplicationMessage));
+    }
+
+    /// Joining fails with `UnsupportedCipherSuite` when the external client has
+    /// the group's cipher suite disabled.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_will_reject_unsupported_cipher_suites() {
+        let alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let config =
+            TestExternalClientBuilder::new_for_test_disabling_cipher_suite(TEST_CIPHER_SUITE)
+                .build_config();
+
+        let group_info = alice
+            .group
+            .group_info_message_allowing_ext_commit(true)
+            .await
+            .unwrap();
+
+        let res = ExternalGroup::join(config, None, group_info, None)
+            .await
+            .map(|_| ());
+
+        assert_matches!(
+            res,
+            Err(MlsError::UnsupportedCipherSuite(TEST_CIPHER_SUITE))
+        );
+    }
+
+    /// Joining from group info that claims an unknown protocol version fails
+    /// with `UnsupportedProtocolVersion`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_will_reject_unsupported_protocol_versions() {
+        let alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let config = TestExternalClientBuilder::new_for_test().build_config();
+        let unsupported_version = ProtocolVersion::from(64);
+
+        let mut group_info = alice
+            .group
+            .group_info_message_allowing_ext_commit(true)
+            .await
+            .unwrap();
+
+        group_info.version = unsupported_version;
+
+        let res = ExternalGroup::join(config, None, group_info, None)
+            .await
+            .map(|_| ());
+
+        assert_matches!(
+            res,
+            Err(MlsError::UnsupportedProtocolVersion(v)) if v == unsupported_version
+        );
+    }
+
+    /// Returns the "server" signing identity and key together with a two-member
+    /// group. The server is listed as an allowed external sender only when
+    /// `extern_proposals_allowed` is true.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn setup_extern_proposal_test(
+        extern_proposals_allowed: bool,
+    ) -> (SigningIdentity, SignatureSecretKey, TestGroup) {
+        let (server_identity, server_key) =
+            get_test_signing_identity(TEST_CIPHER_SUITE, b"server").await;
+
+        let allowed_sender = if extern_proposals_allowed {
+            Some(server_identity.clone())
+        } else {
+            None
+        };
+
+        let alice =
+            test_group_two_members(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, allowed_sender).await;
+
+        (server_identity, server_key, alice)
+    }
+
+    /// Delivers `external_proposal` to alice and has her commit it. Checks that
+    /// the commit references the proposal and that `server` matches alice's
+    /// state after processing the commit.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn test_external_proposal(
+        server: &mut ExternalGroup<TestExternalClientConfig>,
+        alice: &mut TestGroup,
+        external_proposal: MlsMessage,
+    ) {
+        // The reference under which the proposal should appear in the commit.
+        let auth_content = external_proposal.clone().into_plaintext().unwrap().into();
+
+        let expected_ref = ProposalRef::from_content(&server.cipher_suite_provider, &auth_content)
+            .await
+            .unwrap();
+
+        alice.process_message(external_proposal).await.unwrap();
+        let commit_output = alice.group.commit(vec![]).await.unwrap();
+
+        let commit_plaintext = commit_output
+            .commit_message
+            .clone()
+            .into_plaintext()
+            .unwrap();
+
+        let commit = match commit_plaintext.content.content {
+            Content::Commit(commit) => commit,
+            _ => panic!("not a commit"),
+        };
+
+        assert!(commit
+            .proposals
+            .contains(&ProposalOrRef::Reference(expected_ref)));
+
+        alice.process_pending_commit().await.unwrap();
+
+        server
+            .process_incoming_message(commit_output.commit_message)
+            .await
+            .unwrap();
+
+        assert_eq!(alice.group.state, server.state);
+    }
+
+    /// An allowed external sender can propose adding a member.
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_can_propose_add() {
+        let (server_identity, server_key, mut alice) = setup_extern_proposal_test(true).await;
+
+        let mut server = make_external_group(&alice).await;
+        server.signing_data = Some((server_key, server_identity));
+
+        let charlie =
+            test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "charlie").await;
+
+        let proposal = server.propose_add(charlie, vec![]).await.unwrap();
+
+        test_external_proposal(&mut server, &mut alice, proposal).await
+    }
+
+    /// An allowed external sender can propose removing the member at leaf 1.
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_can_propose_remove() {
+        let (server_identity, server_key, mut alice) = setup_extern_proposal_test(true).await;
+
+        let mut server = make_external_group(&alice).await;
+        server.signing_data = Some((server_key, server_identity));
+
+        let proposal = server.propose_remove(1, vec![]).await.unwrap();
+
+        test_external_proposal(&mut server, &mut alice, proposal).await
+    }
+
+    /// Proposing fails with `ExternalProposalsDisabled` when the group lists no
+    /// external senders.
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_external_proposal_not_allowed() {
+        let (signing_id, secret_key, alice) = setup_extern_proposal_test(false).await;
+
+        let mut server = make_external_group(&alice).await;
+        server.signing_data = Some((secret_key, signing_id));
+
+        let charlie =
+            test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "charlie").await;
+
+        let res = server.propose_add(charlie, vec![]).await;
+
+        assert_matches!(res, Err(MlsError::ExternalProposalsDisabled));
+    }
+
+    /// Proposing fails with `InvalidExternalSigningIdentity` when the server
+    /// signs as an identity other than the group's configured external sender.
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_external_signing_identity_invalid() {
+        let (server_identity, server_key) =
+            get_test_signing_identity(TEST_CIPHER_SUITE, b"server").await;
+
+        // The group trusts "not server" rather than the server's identity.
+        let (trusted_identity, _) =
+            get_test_signing_identity(TEST_CIPHER_SUITE, b"not server").await;
+
+        let alice = test_group_two_members(
+            TEST_PROTOCOL_VERSION,
+            TEST_CIPHER_SUITE,
+            Some(trusted_identity),
+        )
+        .await;
+
+        let mut server = make_external_group(&alice).await;
+        server.signing_data = Some((server_key, server_identity));
+
+        let res = server.propose_remove(1, vec![]).await;
+
+        assert_matches!(res, Err(MlsError::InvalidExternalSigningIdentity));
+    }
+
+    /// With a maximum epoch jitter of 0, an application message encrypted before
+    /// the latest commit is refused with `InvalidEpoch`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_errors_on_old_epoch() {
+        let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let config = TestExternalClientBuilder::new_for_test()
+            .max_epoch_jitter(0)
+            .build_config();
+
+        let mut server = make_external_group_with_config(&alice, config).await;
+
+        let old_application_msg = alice
+            .group
+            .encrypt_application_message(&[], vec![])
+            .await
+            .unwrap();
+
+        let commit = alice.group.commit(vec![]).await.unwrap().commit_message;
+        server.process_incoming_message(commit).await.unwrap();
+
+        let res = server.process_incoming_message(old_application_msg).await;
+
+        assert_matches!(res, Err(MlsError::InvalidEpoch));
+    }
+
+    /// With `cache_proposals(false)`, a proposal inserted explicitly via
+    /// `insert_proposal_from_message` lets the referencing commit be processed.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn proposals_can_be_cached_externally() {
+        let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let config = TestExternalClientBuilder::new_for_test()
+            .cache_proposals(false)
+            .build_config();
+
+        let mut server = make_external_group_with_config(&alice, config).await;
+
+        let proposal = alice.group.propose_update(vec![]).await.unwrap();
+        let commit = alice.group.commit(vec![]).await.unwrap().commit_message;
+
+        server
+            .process_incoming_message(proposal.clone())
+            .await
+            .unwrap();
+
+        server.insert_proposal_from_message(proposal).await.unwrap();
+
+        server.process_incoming_message(commit).await.unwrap();
+    }
+
+    /// An observer that joins a freshly created group can follow its next commits.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_can_observe_since_creation() {
+        let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let group_info = alice
+            .group
+            .group_info_message_allowing_ext_commit(true)
+            .await
+            .unwrap();
+
+        let config = TestExternalClientBuilder::new_for_test().build_config();
+        let mut server = ExternalGroup::join(config, None, group_info, None)
+            .await
+            .unwrap();
+
+        for _round in 0..2 {
+            let commit_output = alice.group.commit(vec![]).await.unwrap();
+            alice.process_pending_commit().await.unwrap();
+
+            server
+                .process_incoming_message(commit_output.commit_message)
+                .await
+                .unwrap();
+        }
+    }
+
+    /// An observer's snapshot survives an MLS encode/decode round trip and
+    /// restores the same group state.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_can_be_serialized_to_tls_encoding() {
+        let alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let server = make_external_group(&alice).await;
+
+        let encoded = server.snapshot().mls_encode_to_vec().unwrap();
+        let decoded = ExternalSnapshot::mls_decode(&mut &*encoded).unwrap();
+
+        let server_restored = ExternalGroup::from_snapshot(server.config.clone(), decoded)
+            .await
+            .unwrap();
+
+        assert_eq!(server.group_state(), server_restored.group_state());
+    }
+
+    /// A group info message is reported back as `ExternalReceivedMessage::GroupInfo`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_can_validate_info() {
+        let alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let mut server = make_external_group(&alice).await;
+
+        let info_message = alice
+            .group
+            .group_info_message_allowing_ext_commit(false)
+            .await
+            .unwrap();
+
+        let update = server
+            .process_incoming_message(info_message.clone())
+            .await
+            .unwrap();
+
+        let expected_info = info_message.into_group_info().unwrap();
+
+        assert_matches!(
+            update,
+            ExternalReceivedMessage::GroupInfo(update_info) if update_info == expected_info
+        );
+    }
+
+    /// A key package message is reported back as `ExternalReceivedMessage::KeyPackage`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_can_validate_key_package() {
+        let alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let mut server = make_external_group(&alice).await;
+
+        let kp_message =
+            test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "john").await;
+
+        let update = server
+            .process_incoming_message(kp_message.clone())
+            .await
+            .unwrap();
+
+        let expected_kp = kp_message.into_key_package().unwrap();
+
+        assert_matches!(
+            update,
+            ExternalReceivedMessage::KeyPackage(update_kp) if update_kp == expected_kp
+        );
+    }
+
+    /// A welcome message is reported back as `ExternalReceivedMessage::Welcome`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_group_can_validate_welcome() {
+        let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let mut server = make_external_group(&alice).await;
+
+        let john = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "john").await;
+
+        let commit_output = alice
+            .group
+            .commit_builder()
+            .add_member(john)
+            .unwrap()
+            .build()
+            .await
+            .unwrap();
+
+        // Exactly one welcome message is expected for the single added member.
+        let [welcome] = commit_output.welcome_messages.try_into().unwrap();
+
+        let update = server.process_incoming_message(welcome).await.unwrap();
+
+        assert_matches!(update, ExternalReceivedMessage::Welcome);
+    }
+}
diff --git a/src/grease.rs b/src/grease.rs
new file mode 100644
index 0000000..cd4f208
--- /dev/null
+++ b/src/grease.rs
@@ -0,0 +1,227 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs_core::{crypto::CipherSuiteProvider, extension::ExtensionList, group::Capabilities};
+
+use crate::{
+ client::MlsError,
+ group::{GroupInfo, NewMemberInfo},
+ key_package::KeyPackage,
+ tree_kem::leaf_node::LeafNode,
+};
+
+impl LeafNode {
+    /// Returns a copy of this leaf's capabilities with every GREASE value removed
+    /// from the cipher suite, extension, proposal and credential lists.
+    pub fn ungreased_capabilities(&self) -> Capabilities {
+        let mut capabilities = self.capabilities.clone();
+
+        grease_functions::ungrease(&mut capabilities.cipher_suites);
+        grease_functions::ungrease(&mut capabilities.extensions);
+        grease_functions::ungrease(&mut capabilities.proposals);
+        grease_functions::ungrease(&mut capabilities.credentials);
+
+        capabilities
+    }
+
+    /// Returns a copy of this leaf's extensions with GREASE extensions removed.
+    pub fn ungreased_extensions(&self) -> ExtensionList {
+        let mut ungreased = self.extensions.clone();
+        grease_functions::ungrease_extensions(&mut ungreased);
+        ungreased
+    }
+
+    /// Adds a GREASE value to the cipher suite, proposal and credential
+    /// capabilities. Also inserts a GREASE extension and advertises its type
+    /// in the extension capabilities.
+    ///
+    /// Randomness comes from `cs`. When the `grease` feature is disabled,
+    /// nothing is added.
+    pub fn grease<P: CipherSuiteProvider>(&mut self, cs: &P) -> Result<(), MlsError> {
+        let capabilities = &mut self.capabilities;
+
+        grease_functions::grease(&mut capabilities.cipher_suites, cs)?;
+        grease_functions::grease(&mut capabilities.proposals, cs)?;
+        grease_functions::grease(&mut capabilities.credentials, cs)?;
+
+        let added_types = grease_functions::grease_extensions(&mut self.extensions, cs)?;
+        capabilities.extensions.extend(added_types);
+
+        Ok(())
+    }
+}
+
+impl KeyPackage {
+    /// Adds a GREASE extension to the key package's own extension list.
+    pub fn grease<P: CipherSuiteProvider>(&mut self, cs: &P) -> Result<(), MlsError> {
+        // The returned extension types are not needed here; only the leaf
+        // node advertises them in its capabilities.
+        grease_functions::grease_extensions(&mut self.extensions, cs)?;
+        Ok(())
+    }
+
+    /// Returns a copy of the key package's extensions with GREASE extensions removed.
+    pub fn ungreased_extensions(&self) -> ExtensionList {
+        let mut ungreased = self.extensions.clone();
+        grease_functions::ungrease_extensions(&mut ungreased);
+        ungreased
+    }
+}
+
+impl GroupInfo {
+    /// Adds a GREASE extension to the group info's extension list.
+    pub fn grease<P: CipherSuiteProvider>(&mut self, cs: &P) -> Result<(), MlsError> {
+        grease_functions::grease_extensions(&mut self.extensions, cs)?;
+        Ok(())
+    }
+}
+
+impl NewMemberInfo {
+    /// Removes GREASE extensions from `group_info_extensions` in place.
+    pub fn ungrease(&mut self) {
+        let extensions = &mut self.group_info_extensions;
+        grease_functions::ungrease_extensions(extensions);
+    }
+}
+
+#[cfg(feature = "grease")]
+mod grease_functions {
+    use core::ops::Deref;
+
+    // Imported explicitly, as in the non-`grease` variant, so this module also
+    // builds without `std`, where neither `Vec` nor `vec!` is in the prelude.
+    use alloc::{vec, vec::Vec};
+
+    use mls_rs_core::{
+        crypto::CipherSuiteProvider,
+        error::IntoAnyError,
+        extension::{Extension, ExtensionList, ExtensionType},
+    };
+
+    use super::MlsError;
+
+    /// Values reserved for GREASE by RFC 9420 (Section 13.5).
+    pub const GREASE_VALUES: &[u16] = &[
+        0x0A0A, 0x1A1A, 0x2A2A, 0x3A3A, 0x4A4A, 0x5A5A, 0x6A6A, 0x7A7A, 0x8A8A, 0x9A9A, 0xAAAA,
+        0xBABA, 0xCACA, 0xDADA, 0xEAEA,
+    ];
+
+    /// Appends one randomly selected GREASE value to `array`.
+    ///
+    /// # Errors
+    /// Returns `MlsError::CryptoProviderError` if `cs` fails to produce random bytes.
+    pub fn grease<T: From<u16>, P: CipherSuiteProvider>(
+        array: &mut Vec<T>,
+        cs: &P,
+    ) -> Result<(), MlsError> {
+        array.push(random_grease_value(cs)?.into());
+        Ok(())
+    }
+
+    /// Sets an empty extension whose type is a random GREASE value, and returns
+    /// that type so the caller can also list it in its capabilities.
+    ///
+    /// # Errors
+    /// Returns `MlsError::CryptoProviderError` if `cs` fails to produce random bytes.
+    pub fn grease_extensions<P: CipherSuiteProvider>(
+        extensions: &mut ExtensionList,
+        cs: &P,
+    ) -> Result<Vec<ExtensionType>, MlsError> {
+        let grease_value = random_grease_value(cs)?;
+        extensions.set(Extension::new(grease_value.into(), vec![]));
+        Ok(vec![grease_value.into()])
+    }
+
+    /// Picks a GREASE value using one random byte from `cs`.
+    fn random_grease_value<P: CipherSuiteProvider>(cs: &P) -> Result<u16, MlsError> {
+        let index = cs
+            .random_bytes_vec(1)
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?[0];
+
+        // 256 is not a multiple of 15, so the first value is chosen marginally
+        // more often than the others. GREASE does not need a uniform choice.
+        Ok(GREASE_VALUES[index as usize % GREASE_VALUES.len()])
+    }
+
+    /// Removes every GREASE value from `array`, preserving the order of the rest.
+    pub fn ungrease<T: Deref<Target = u16>>(array: &mut Vec<T>) {
+        array.retain(|x| !GREASE_VALUES.contains(&**x));
+    }
+
+    /// Removes any extension whose type is one of the GREASE values.
+    pub fn ungrease_extensions(extensions: &mut ExtensionList) {
+        for e in GREASE_VALUES {
+            extensions.remove((*e).into())
+        }
+    }
+}
+
+#[cfg(not(feature = "grease"))]
+mod grease_functions {
+    //! No-op stand-ins used when the `grease` feature is disabled: nothing is
+    //! ever added, so there is nothing to strip either.
+
+    use core::ops::Deref;
+
+    use alloc::vec::Vec;
+
+    use mls_rs_core::{
+        crypto::CipherSuiteProvider,
+        extension::{ExtensionList, ExtensionType},
+    };
+
+    use super::MlsError;
+
+    /// Leaves `_values` unchanged.
+    pub fn grease<T: From<u16>, P: CipherSuiteProvider>(
+        _values: &mut [T],
+        _cs: &P,
+    ) -> Result<(), MlsError> {
+        Ok(())
+    }
+
+    /// Adds no extension and reports that no extension types were added.
+    pub fn grease_extensions<P: CipherSuiteProvider>(
+        _extensions: &mut ExtensionList,
+        _cs: &P,
+    ) -> Result<Vec<ExtensionType>, MlsError> {
+        Ok(Vec::default())
+    }
+
+    /// Leaves `_values` unchanged.
+    pub fn ungrease<T: Deref<Target = u16>>(_values: &mut [T]) {}
+
+    /// Leaves `_extensions` unchanged.
+    pub fn ungrease_extensions(_extensions: &mut ExtensionList) {}
+}
+
+#[cfg(all(test, feature = "grease"))]
+mod tests {
+    #[cfg(target_arch = "wasm32")]
+    use wasm_bindgen_test::wasm_bindgen_test as test;
+
+    use std::ops::Deref;
+
+    use mls_rs_core::extension::ExtensionList;
+
+    use crate::{
+        client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+        group::test_utils::test_group,
+    };
+
+    use super::grease_functions::GREASE_VALUES;
+
+    /// A test client's key package is greased in its extensions, its leaf's
+    /// extensions and all greasable capability lists, but not in protocol versions.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn key_package_is_greased() {
+        let key_pkg = test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice")
+            .await
+            .1
+            .into_key_package()
+            .unwrap();
+
+        let leaf = &key_pkg.leaf_node;
+        let capabilities = &leaf.capabilities;
+
+        assert!(is_ext_greased(&key_pkg.extensions));
+        assert!(is_ext_greased(&leaf.extensions));
+
+        assert!(is_greased(&capabilities.cipher_suites));
+        assert!(is_greased(&capabilities.extensions));
+        assert!(is_greased(&capabilities.proposals));
+        assert!(is_greased(&capabilities.credentials));
+
+        assert!(!is_greased(&capabilities.protocol_versions));
+    }
+
+    /// Exported group info carries a GREASE extension.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn group_info_is_greased() {
+        let alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let group_info = alice
+            .group
+            .group_info_message_allowing_ext_commit(false)
+            .await
+            .unwrap()
+            .into_group_info()
+            .unwrap();
+
+        assert!(is_ext_greased(&group_info.extensions));
+    }
+
+    /// Roster members are reported without any GREASE values.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn public_api_is_not_greased() {
+        let alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let member = alice.group.roster().member_with_index(0).unwrap();
+        let capabilities = member.capabilities();
+
+        assert!(!is_ext_greased(member.extensions()));
+        assert!(!is_greased(capabilities.protocol_versions()));
+        assert!(!is_greased(capabilities.cipher_suites()));
+        assert!(!is_greased(capabilities.extensions()));
+        assert!(!is_greased(capabilities.proposals()));
+        assert!(!is_greased(capabilities.credentials()));
+    }
+
+    /// True when at least one entry of `list` is a reserved GREASE value.
+    fn is_greased<T: Deref<Target = u16>>(list: &[T]) -> bool {
+        list.iter()
+            .map(|entry| **entry)
+            .any(|raw| GREASE_VALUES.contains(&raw))
+    }
+
+    /// True when `extensions` holds an extension whose type is a GREASE value.
+    fn is_ext_greased(extensions: &ExtensionList) -> bool {
+        extensions
+            .iter()
+            .any(|ext| GREASE_VALUES.contains(&*ext.extension_type()))
+    }
+}
diff --git a/src/group/ciphertext_processor.rs b/src/group/ciphertext_processor.rs
new file mode 100644
index 0000000..bf70f5d
--- /dev/null
+++ b/src/group/ciphertext_processor.rs
@@ -0,0 +1,410 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use self::{
+ message_key::MessageKey,
+ reuse_guard::ReuseGuard,
+ sender_data_key::{SenderData, SenderDataAAD, SenderDataKey},
+};
+
+use super::{
+ epoch::EpochSecrets,
+ framing::{ContentType, FramedContent, Sender, WireFormat},
+ message_signature::AuthenticatedContent,
+ padding::PaddingMode,
+ secret_tree::{KeyType, MessageKeyData},
+ GroupContext,
+};
+use crate::{
+ client::MlsError,
+ tree_kem::node::{LeafIndex, NodeIndex},
+};
+use mls_rs_codec::MlsEncode;
+use mls_rs_core::{crypto::CipherSuiteProvider, error::IntoAnyError};
+use zeroize::Zeroizing;
+
+mod message_key;
+mod reuse_guard;
+mod sender_data_key;
+
+#[cfg(feature = "private_message")]
+use super::framing::{PrivateContentAAD, PrivateMessage, PrivateMessageContent};
+
+#[cfg(test)]
+pub use sender_data_key::test_utils::*;
+
+/// Access to the parts of group state that `CiphertextProcessor` reads and mutates.
+pub(crate) trait GroupStateProvider {
+    /// Group context whose group id and epoch are written into outgoing messages.
+    fn group_context(&self) -> &GroupContext;
+    /// Leaf index of the local member.
+    fn self_index(&self) -> LeafIndex;
+    /// Read-only access to the current epoch's secrets.
+    fn epoch_secrets(&self) -> &EpochSecrets;
+    /// Mutable access to the current epoch's secrets (needed to draw message keys).
+    fn epoch_secrets_mut(&mut self) -> &mut EpochSecrets;
+}
+
+/// Seals content into `PrivateMessage`s and opens incoming ones, taking message
+/// keys from the borrowed group state's secret tree.
+pub(crate) struct CiphertextProcessor<'a, GS: GroupStateProvider, CP: CipherSuiteProvider> {
+    /// Borrowed mutably so message keys can be drawn from its secret tree.
+    group_state: &'a mut GS,
+    /// Crypto primitives used for every encryption and key derivation.
+    cipher_suite_provider: CP,
+}
+
+impl<'a, GS, CP> CiphertextProcessor<'a, GS, CP>
+where
+ GS: GroupStateProvider,
+ CP: CipherSuiteProvider,
+{
+ /// Creates a processor over `group_state` that uses `cipher_suite_provider`
+ /// for all cryptographic operations.
+ pub fn new(
+ group_state: &'a mut GS,
+ cipher_suite_provider: CP,
+ ) -> CiphertextProcessor<'a, GS, CP> {
+ Self {
+ group_state,
+ cipher_suite_provider,
+ }
+ }
+
+ /// Returns the next message key of `key_type` for this member's own leaf,
+ /// taken from the current epoch's secret tree.
+ ///
+ /// The secret tree is borrowed mutably; `next_message_key` presumably
+ /// advances it, so successive calls yield successive generations.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn next_encryption_key(
+ &mut self,
+ key_type: KeyType,
+ ) -> Result<MessageKeyData, MlsError> {
+ let self_index = NodeIndex::from(self.group_state.self_index());
+
+ self.group_state
+ .epoch_secrets_mut()
+ .secret_tree
+ .next_message_key(&self.cipher_suite_provider, self_index, key_type)
+ .await
+ }
+
+ /// Returns the message key of `key_type` used by leaf `sender` at
+ /// `generation`, taken from the current epoch's secret tree.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn decryption_key(
+ &mut self,
+ sender: LeafIndex,
+ key_type: KeyType,
+ generation: u32,
+ ) -> Result<MessageKeyData, MlsError> {
+ let sender = NodeIndex::from(sender);
+
+ self.group_state
+ .epoch_secrets_mut()
+ .secret_tree
+ .message_key_generation(&self.cipher_suite_provider, sender, key_type, generation)
+ .await
+ }
+
+ /// Encrypts `auth_content` into a `PrivateMessage`.
+ ///
+ /// The content must be authored by this member (`Sender::Member` with our
+ /// own leaf index); otherwise `MlsError::InvalidSender` is returned. The
+ /// serialized `PrivateMessageContent` is padded according to `padding` and
+ /// encrypted under the next key of the matching type (application or
+ /// handshake) together with a random reuse guard. The sender data is then
+ /// encrypted under a key derived from the epoch's `sender_data_secret` and
+ /// the resulting ciphertext.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn seal(
+ &mut self,
+ auth_content: AuthenticatedContent,
+ padding: PaddingMode,
+ ) -> Result<PrivateMessage, MlsError> {
+ // Only content we authored ourselves can be sealed with our keys.
+ if Sender::Member(*self.group_state.self_index()) != auth_content.content.sender {
+ return Err(MlsError::InvalidSender);
+ }
+
+ let content_type = ContentType::from(&auth_content.content.content);
+ let authenticated_data = auth_content.content.authenticated_data;
+
+ // Build a ciphertext content using the plaintext content and signature
+ let private_content = PrivateMessageContent {
+ content: auth_content.content.content,
+ auth: auth_content.auth,
+ };
+
+ // Build ciphertext aad using the plaintext message
+ let aad = PrivateContentAAD {
+ group_id: auth_content.content.group_id,
+ epoch: auth_content.content.epoch,
+ content_type,
+ authenticated_data: authenticated_data.clone(),
+ };
+
+ // Generate a 4 byte reuse guard
+ let reuse_guard = ReuseGuard::random(&self.cipher_suite_provider)
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ // Grab an encryption key from the current epoch's key schedule
+ let key_type = match &content_type {
+ ContentType::Application => KeyType::Application,
+ _ => KeyType::Handshake,
+ };
+
+ let mut serialized_private_content = private_content.mls_encode_to_vec()?;
+
+ // Apply padding to private content based on the current padding mode.
+ serialized_private_content.resize(padding.padded_size(serialized_private_content.len()), 0);
+
+ // Zeroize the padded plaintext when it is dropped.
+ // NOTE(review): `resize` above may reallocate, and the pre-padding buffer
+ // it frees is not zeroized; consider reserving the padded capacity first.
+ let serialized_private_content = Zeroizing::new(serialized_private_content);
+
+ // Encrypt the ciphertext content using the encryption key and a nonce that is
+ // reuse safe: the reuse guard is XORed into the nonce's first 4 bytes.
+ let self_index = self.group_state.self_index();
+
+ let key_data = self.next_encryption_key(key_type).await?;
+ let generation = key_data.generation;
+
+ let ciphertext = MessageKey::new(key_data)
+ .encrypt(
+ &self.cipher_suite_provider,
+ &serialized_private_content,
+ &aad.mls_encode_to_vec()?,
+ &reuse_guard,
+ )
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ // Construct an mls sender data struct using the plaintext sender info, the generation
+ // of the key schedule encryption key, and the reuse guard used to encrypt ciphertext
+ let sender_data = SenderData {
+ sender: self_index,
+ generation,
+ reuse_guard,
+ };
+
+ let sender_data_aad = SenderDataAAD {
+ group_id: self.group_state.group_context().group_id.clone(),
+ epoch: self.group_state.group_context().epoch,
+ content_type,
+ };
+
+ // Encrypt the sender data with the derived sender_key and sender_nonce from the current
+ // epoch's key schedule. The key derivation also takes the content ciphertext as input.
+ let sender_data_key = SenderDataKey::new(
+ &self.group_state.epoch_secrets().sender_data_secret,
+ &ciphertext,
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ let encrypted_sender_data = sender_data_key.seal(&sender_data, &sender_data_aad).await?;
+
+ Ok(PrivateMessage {
+ group_id: self.group_state.group_context().group_id.clone(),
+ epoch: self.group_state.group_context().epoch,
+ content_type,
+ authenticated_data,
+ encrypted_sender_data,
+ ciphertext,
+ })
+ }
+
+ /// Decrypts `ciphertext` into `AuthenticatedContent`.
+ ///
+ /// The sender data is opened first to learn the sender's leaf, key
+ /// generation and reuse guard. Messages from our own leaf are rejected
+ /// with `MlsError::CantProcessMessageFromSelf`. The content is then
+ /// decrypted with that sender's key for that generation. No signature
+ /// verification happens in this method.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn open(
+ &mut self,
+ ciphertext: &PrivateMessage,
+ ) -> Result<AuthenticatedContent, MlsError> {
+ // Decrypt the sender data with the derived sender_key and sender_nonce from the message
+ // epoch's key schedule
+ // NOTE(review): the AAD uses this processor's group context (group id and
+ // epoch), not the values carried in `ciphertext`; callers presumably pick
+ // the state of the message's epoch before calling.
+ let sender_data_aad = SenderDataAAD {
+ group_id: self.group_state.group_context().group_id.clone(),
+ epoch: self.group_state.group_context().epoch,
+ content_type: ciphertext.content_type,
+ };
+
+ let sender_data_key = SenderDataKey::new(
+ &self.group_state.epoch_secrets().sender_data_secret,
+ &ciphertext.ciphertext,
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ let sender_data = sender_data_key
+ .open(&ciphertext.encrypted_sender_data, &sender_data_aad)
+ .await?;
+
+ // Our own messages cannot be opened here.
+ if self.group_state.self_index() == sender_data.sender {
+ return Err(MlsError::CantProcessMessageFromSelf);
+ }
+
+ // Grab a decryption key from the message epoch's key schedule
+ let key_type = match &ciphertext.content_type {
+ ContentType::Application => KeyType::Application,
+ _ => KeyType::Handshake,
+ };
+
+ // Decrypt the content of the message using the grabbed key
+ let key = self
+ .decryption_key(sender_data.sender, key_type, sender_data.generation)
+ .await?;
+
+ let sender = Sender::Member(*sender_data.sender);
+
+ let decrypted_content = MessageKey::new(key)
+ .decrypt(
+ &self.cipher_suite_provider,
+ &ciphertext.ciphertext,
+ &PrivateContentAAD::from(ciphertext).mls_encode_to_vec()?,
+ &sender_data.reuse_guard,
+ )
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ // The content type from the outer message tells the decoder how to parse the body.
+ let ciphertext_content =
+ PrivateMessageContent::mls_decode(&mut &**decrypted_content, ciphertext.content_type)?;
+
+ // Build the MLS plaintext object and process it
+ let auth_content = AuthenticatedContent {
+ wire_format: WireFormat::PrivateMessage,
+ content: FramedContent {
+ group_id: ciphertext.group_id.clone(),
+ epoch: ciphertext.epoch,
+ sender,
+ authenticated_data: ciphertext.authenticated_data.clone(),
+ content: ciphertext_content.content,
+ },
+ auth: ciphertext_content.auth,
+ };
+
+ Ok(auth_content)
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use crate::{
+ cipher_suite::CipherSuite,
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ crypto::{
+ test_utils::{test_cipher_suite_provider, TestCryptoProvider},
+ CipherSuiteProvider,
+ },
+ group::{
+ framing::{ApplicationData, Content, Sender, WireFormat},
+ message_signature::AuthenticatedContent,
+ padding::PaddingMode,
+ test_utils::{random_bytes, test_group, TestGroup},
+ },
+ tree_kem::node::LeafIndex,
+ };
+
+ use super::{CiphertextProcessor, GroupStateProvider, MlsError};
+
+ use alloc::vec;
+ use assert_matches::assert_matches;
+
+    /// Fixture produced by `test_data`: a group plus content signed by its
+    /// member at leaf 0.
+    struct TestData {
+        content: AuthenticatedContent,
+        group: TestGroup,
+    }
+
+ fn test_processor(
+ group: &mut TestGroup,
+ cipher_suite: CipherSuite,
+ ) -> CiphertextProcessor<'_, impl GroupStateProvider, impl CipherSuiteProvider> {
+ CiphertextProcessor::new(&mut group.group, test_cipher_suite_provider(cipher_suite))
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_data(cipher_suite: CipherSuite) -> TestData {
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let group = test_group(TEST_PROTOCOL_VERSION, cipher_suite).await;
+
+ let content = AuthenticatedContent::new_signed(
+ &provider,
+ group.group.context(),
+ Sender::Member(0),
+ Content::Application(ApplicationData::from(b"test".to_vec())),
+ &group.group.signer,
+ WireFormat::PrivateMessage,
+ vec![],
+ )
+ .await
+ .unwrap();
+
+ TestData { group, content }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_encrypt_decrypt() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let mut test_data = test_data(cipher_suite).await;
+ let mut receiver_group = test_data.group.clone();
+
+ let mut ciphertext_processor = test_processor(&mut test_data.group, cipher_suite);
+
+ let ciphertext = ciphertext_processor
+ .seal(test_data.content.clone(), PaddingMode::StepFunction)
+ .await
+ .unwrap();
+
+ receiver_group.group.private_tree.self_index = LeafIndex::new(1);
+
+ let mut receiver_processor = test_processor(&mut receiver_group, cipher_suite);
+
+ let decrypted = receiver_processor.open(&ciphertext).await.unwrap();
+
+ assert_eq!(decrypted, test_data.content);
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_padding_use() {
+ let mut test_data = test_data(TEST_CIPHER_SUITE).await;
+ let mut ciphertext_processor = test_processor(&mut test_data.group, TEST_CIPHER_SUITE);
+
+ let ciphertext_step = ciphertext_processor
+ .seal(test_data.content.clone(), PaddingMode::StepFunction)
+ .await
+ .unwrap();
+
+ let ciphertext_no_pad = ciphertext_processor
+ .seal(test_data.content.clone(), PaddingMode::None)
+ .await
+ .unwrap();
+
+ assert!(ciphertext_step.ciphertext.len() > ciphertext_no_pad.ciphertext.len());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_invalid_sender() {
+ let mut test_data = test_data(TEST_CIPHER_SUITE).await;
+ test_data.content.content.sender = Sender::Member(3);
+
+ let mut ciphertext_processor = test_processor(&mut test_data.group, TEST_CIPHER_SUITE);
+
+ let res = ciphertext_processor
+ .seal(test_data.content, PaddingMode::None)
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidSender))
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_cant_process_from_self() {
+ let mut test_data = test_data(TEST_CIPHER_SUITE).await;
+
+ let mut ciphertext_processor = test_processor(&mut test_data.group, TEST_CIPHER_SUITE);
+
+ let ciphertext = ciphertext_processor
+ .seal(test_data.content, PaddingMode::None)
+ .await
+ .unwrap();
+
+ let res = ciphertext_processor.open(&ciphertext).await;
+
+ assert_matches!(res, Err(MlsError::CantProcessMessageFromSelf))
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_decryption_error() {
+ let mut test_data = test_data(TEST_CIPHER_SUITE).await;
+ let mut receiver_group = test_data.group.clone();
+ let mut ciphertext_processor = test_processor(&mut test_data.group, TEST_CIPHER_SUITE);
+
+ let mut ciphertext = ciphertext_processor
+ .seal(test_data.content.clone(), PaddingMode::StepFunction)
+ .await
+ .unwrap();
+
+ ciphertext.ciphertext = random_bytes(ciphertext.ciphertext.len());
+ receiver_group.group.private_tree.self_index = LeafIndex::new(1);
+
+ let res = ciphertext_processor.open(&ciphertext).await;
+
+ assert!(res.is_err());
+ }
+}
diff --git a/src/group/ciphertext_processor/message_key.rs b/src/group/ciphertext_processor/message_key.rs
new file mode 100644
index 0000000..256db7d
--- /dev/null
+++ b/src/group/ciphertext_processor/message_key.rs
@@ -0,0 +1,57 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use zeroize::Zeroizing;
+
+use crate::{crypto::CipherSuiteProvider, group::secret_tree::MessageKeyData};
+
+use super::reuse_guard::ReuseGuard;
+
+#[derive(Debug, PartialEq, Eq)]
+pub struct MessageKey(MessageKeyData); // per-message AEAD key and nonce used by `encrypt`/`decrypt`
+
+impl MessageKey {
+    /// Wraps per-message key data (AEAD key and nonce) taken from the secret tree.
+    pub(crate) fn new(key: MessageKeyData) -> MessageKey {
+        Self(key)
+    }
+
+    /// Seals `data` under this message key, authenticating `aad`.
+    ///
+    /// The AEAD nonce is this key's nonce with `reuse_guard` XORed into its leading bytes.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn encrypt<P: CipherSuiteProvider>(
+        &self,
+        provider: &P,
+        data: &[u8],
+        aad: &[u8],
+        reuse_guard: &ReuseGuard,
+    ) -> Result<Vec<u8>, P::Error> {
+        let nonce = reuse_guard.apply(&self.0.nonce);
+        let key = &self.0.key;
+
+        provider.aead_seal(key, data, Some(aad), &nonce).await
+    }
+
+    /// Opens `data` under this message key, verifying `aad`.
+    ///
+    /// Uses the same guarded nonce as [`MessageKey::encrypt`]; the plaintext is
+    /// returned in a [`Zeroizing`] buffer.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn decrypt<P: CipherSuiteProvider>(
+        &self,
+        provider: &P,
+        data: &[u8],
+        aad: &[u8],
+        reuse_guard: &ReuseGuard,
+    ) -> Result<Zeroizing<Vec<u8>>, P::Error> {
+        let nonce = reuse_guard.apply(&self.0.nonce);
+        let key = &self.0.key;
+
+        provider.aead_open(key, data, Some(aad), &nonce).await
+    }
+}
+
+// TODO: Write test vectors
diff --git a/src/group/ciphertext_processor/reuse_guard.rs b/src/group/ciphertext_processor/reuse_guard.rs
new file mode 100644
index 0000000..10e1db1
--- /dev/null
+++ b/src/group/ciphertext_processor/reuse_guard.rs
@@ -0,0 +1,133 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+use crate::CipherSuiteProvider;
+
+const REUSE_GUARD_SIZE: usize = 4; // length in bytes of the guard XORed into the nonce prefix
+
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+pub(crate) struct ReuseGuard([u8; REUSE_GUARD_SIZE]); // per-message mask for the AEAD nonce (see `apply`)
+
+impl From<[u8; REUSE_GUARD_SIZE]> for ReuseGuard {
+    fn from(bytes: [u8; REUSE_GUARD_SIZE]) -> Self {
+        Self(bytes)
+    }
+}
+
+impl From<ReuseGuard> for [u8; REUSE_GUARD_SIZE] {
+    fn from(ReuseGuard(bytes): ReuseGuard) -> Self {
+        bytes
+    }
+}
+
+impl AsRef<[u8]> for ReuseGuard {
+    fn as_ref(&self) -> &[u8] {
+        self.0.as_slice()
+    }
+}
+
+impl ReuseGuard {
+    /// Draws a fresh guard from the provider's random source.
+    pub(crate) fn random<P: CipherSuiteProvider>(provider: &P) -> Result<Self, P::Error> {
+        let mut bytes = [0u8; REUSE_GUARD_SIZE];
+        provider.random_bytes(&mut bytes).map(|_| Self(bytes))
+    }
+
+    /// Returns a copy of `nonce` whose leading bytes are XORed with the guard; bytes past
+    /// the guard's length are left unchanged and the result has the same length as `nonce`.
+    pub(crate) fn apply(&self, nonce: &[u8]) -> Vec<u8> {
+        let mut guarded = Vec::from(nonce);
+        for (byte, mask) in guarded.iter_mut().zip(self.0.iter()) {
+            *byte ^= mask;
+        }
+        guarded
+    }
+}
+
+#[cfg(test)]
+mod test_utils {
+    use alloc::vec::Vec;
+    use super::{ReuseGuard, REUSE_GUARD_SIZE};
+
+    impl ReuseGuard {
+        /// Builds a guard from exactly `REUSE_GUARD_SIZE` bytes; panics otherwise.
+        pub fn new(guard: Vec<u8>) -> Self {
+            let mut bytes = [0u8; REUSE_GUARD_SIZE];
+            bytes.copy_from_slice(guard.as_slice());
+            Self(bytes)
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use alloc::vec::Vec;
+    use mls_rs_core::crypto::CipherSuiteProvider;
+
+    use crate::{
+        client::test_utils::TEST_CIPHER_SUITE, crypto::test_utils::test_cipher_suite_provider,
+    };
+
+    use super::{ReuseGuard, REUSE_GUARD_SIZE};
+
+    #[test]
+    fn test_random_generation() {
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let test_guard = ReuseGuard::random(&provider).unwrap();
+        // 1000 draws of a 32-bit guard collide with `test_guard` with probability ~2^-22
+        (0..1000).for_each(|_| {
+            let next = ReuseGuard::random(&provider).unwrap();
+            assert_ne!(next, test_guard);
+        })
+    }
+
+    #[derive(Debug, serde::Serialize, serde::Deserialize)]
+    struct TestCase {
+        nonce: Vec<u8>,
+        guard: [u8; REUSE_GUARD_SIZE],
+        result: Vec<u8>,
+    }
+
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn generate_reuse_guard_test_cases() -> Vec<TestCase> {
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        [16, 32]
+            .into_iter()
+            .map(
+                #[cfg_attr(coverage_nightly, coverage(off))]
+                |len| {
+                    let nonce = provider.random_bytes_vec(len).unwrap();
+                    let guard = ReuseGuard::random(&provider).unwrap();
+
+                    let result = guard.apply(&nonce);
+
+                    TestCase {
+                        nonce,
+                        guard: guard.into(),
+                        result,
+                    }
+                },
+            )
+            .collect()
+    }
+
+    fn load_test_cases() -> Vec<TestCase> {
+        load_test_case_json!(reuse_guard, generate_reuse_guard_test_cases())
+    }
+
+    #[test]
+    fn test_reuse_guard() {
+        let test_cases = load_test_cases();
+
+        for case in test_cases {
+            let guard = ReuseGuard::from(case.guard);
+            let result = guard.apply(&case.nonce);
+            assert_eq!(result, case.result);
+        }
+    }
+}
diff --git a/src/group/ciphertext_processor/sender_data_key.rs b/src/group/ciphertext_processor/sender_data_key.rs
new file mode 100644
index 0000000..983920a
--- /dev/null
+++ b/src/group/ciphertext_processor/sender_data_key.rs
@@ -0,0 +1,360 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+use zeroize::Zeroizing;
+
+use crate::{
+ client::MlsError,
+ crypto::CipherSuiteProvider,
+ group::{epoch::SenderDataSecret, framing::ContentType, key_schedule::kdf_expand_with_label},
+ tree_kem::node::LeafIndex,
+};
+
+use super::ReuseGuard;
+
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+pub(crate) struct SenderData {
+    pub sender: LeafIndex,       // leaf index of the member that sent the message
+    pub generation: u32,         // key generation used to select the per-message key
+    pub reuse_guard: ReuseGuard, // XORed into the message key's nonce
+}
+
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+pub(crate) struct SenderDataAAD { // encoded and used as AEAD associated data for sender data
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    pub group_id: Vec<u8>,
+    pub epoch: u64,
+    pub content_type: ContentType,
+}
+
+impl Debug for SenderDataAAD {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Group ids go through the shared pretty-printer instead of raw byte output
+        let group_id = mls_rs_core::debug::pretty_group_id(&self.group_id);
+        let mut out = f.debug_struct("SenderDataAAD");
+
+        out.field("group_id", &group_id);
+        out.field("epoch", &self.epoch);
+        out.field("content_type", &self.content_type);
+        out.finish()
+    }
+}
+
+pub(crate) struct SenderDataKey<'a, CP: CipherSuiteProvider> {
+    pub(crate) key: Zeroizing<Vec<u8>>,   // AEAD key for sender data (derived with label "key")
+    pub(crate) nonce: Zeroizing<Vec<u8>>, // AEAD nonce for sender data (label "nonce")
+    cipher_suite_provider: &'a CP,
+}
+
+impl<CP: CipherSuiteProvider + Debug> Debug for SenderDataKey<'_, CP> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // NOTE(review): key/nonce are secret
+        f.debug_struct("SenderDataKey")
+            .field("key", &mls_rs_core::debug::pretty_bytes(&self.key)) // confirm pretty_bytes redacts
+            .field("nonce", &mls_rs_core::debug::pretty_bytes(&self.nonce))
+            .field("cipher_suite_provider", self.cipher_suite_provider)
+            .finish()
+    }
+}
+
+impl<'a, CP: CipherSuiteProvider> SenderDataKey<'a, CP> {
+    /// Derives the sender data key and nonce from the epoch's `sender_data_secret`
+    /// and a sample taken from the start of the message `ciphertext`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(super) async fn new(
+        sender_data_secret: &SenderDataSecret,
+        ciphertext: &[u8],
+        cipher_suite_provider: &'a CP,
+    ) -> Result<SenderDataKey<'a, CP>, MlsError> {
+        // The sample is the first `kdf_extract_size` bytes of the ciphertext, or the
+        // whole ciphertext when it is shorter than that
+        let sample_len = cipher_suite_provider.kdf_extract_size();
+        let ciphertext_sample = ciphertext.get(..sample_len).unwrap_or(ciphertext);
+
+        let key = kdf_expand_with_label(
+            cipher_suite_provider,
+            sender_data_secret,
+            b"key",
+            ciphertext_sample,
+            Some(cipher_suite_provider.aead_key_size()),
+        )
+        .await?;
+
+        let nonce = kdf_expand_with_label(
+            cipher_suite_provider,
+            sender_data_secret,
+            b"nonce",
+            ciphertext_sample,
+            Some(cipher_suite_provider.aead_nonce_size()),
+        )
+        .await?;
+
+        Ok(SenderDataKey {
+            key,
+            nonce,
+            cipher_suite_provider,
+        })
+    }
+
+    /// Encrypts the encoded `sender_data` under this key and nonce, with the
+    /// encoded `aad` as associated data.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn seal(
+        &self,
+        sender_data: &SenderData,
+        aad: &SenderDataAAD,
+    ) -> Result<Vec<u8>, MlsError> {
+        let plaintext = sender_data.mls_encode_to_vec()?;
+        let aad = aad.mls_encode_to_vec()?;
+
+        self.cipher_suite_provider
+            .aead_seal(&self.key, &plaintext, Some(&aad), &self.nonce)
+            .await
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+    }
+
+    /// Decrypts `sender_data` (authenticating the encoded `aad`) and decodes the
+    /// plaintext into a [`SenderData`].
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn open(
+        &self,
+        sender_data: &[u8],
+        aad: &SenderDataAAD,
+    ) -> Result<SenderData, MlsError> {
+        let aad = aad.mls_encode_to_vec()?;
+        let plaintext = self
+            .cipher_suite_provider
+            .aead_open(&self.key, sender_data, Some(&aad), &self.nonce)
+            .await
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+        Ok(SenderData::mls_decode(&mut &**plaintext)?)
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use alloc::vec::Vec;
+    use mls_rs_core::crypto::CipherSuiteProvider;
+
+    use super::SenderDataKey;
+
+    #[derive(Debug, serde::Serialize, serde::Deserialize)]
+    pub struct InteropSenderData { // interop vector: secret + ciphertext and expected key/nonce
+        #[serde(with = "hex::serde")]
+        pub sender_data_secret: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        pub ciphertext: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        pub key: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        pub nonce: Vec<u8>,
+    }
+
+    impl InteropSenderData {
+        #[cfg(not(mls_build_async))]
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        pub(crate) fn new<P: CipherSuiteProvider>(cs: &P) -> Self { // random vector generation
+            let secret = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap().into();
+            let ciphertext = cs.random_bytes_vec(77).unwrap(); // arbitrary length; prefix is sampled
+            let key = SenderDataKey::new(&secret, &ciphertext, cs).unwrap();
+            let secret = (*secret).clone(); // raw secret bytes for serialization
+
+            Self {
+                ciphertext,
+                key: key.key.to_vec(),
+                nonce: key.nonce.to_vec(),
+                sender_data_secret: secret,
+            }
+        }
+
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) { // panics on any mismatch
+            let secret = self.sender_data_secret.clone().into();
+
+            let key = SenderDataKey::new(&secret, &self.ciphertext, cs)
+                .await
+                .unwrap();
+
+            assert_eq!(key.key.to_vec(), self.key, "sender data key mismatch");
+            assert_eq!(key.nonce.to_vec(), self.nonce, "sender data nonce mismatch");
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    // Known-answer tests for sender data key derivation, sealing and opening
+    use alloc::vec::Vec;
+    #[cfg(target_arch = "wasm32")]
+    use wasm_bindgen_test::wasm_bindgen_test as test;
+
+    use crate::{
+        crypto::test_utils::try_test_cipher_suite_provider,
+        group::{ciphertext_processor::reuse_guard::ReuseGuard, framing::ContentType},
+        tree_kem::node::LeafIndex,
+    };
+
+    use super::{SenderData, SenderDataAAD, SenderDataKey};
+
+    #[cfg(not(mls_build_async))]
+    use crate::{
+        cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider,
+        group::test_utils::random_bytes, CipherSuiteProvider,
+    };
+
+    #[derive(serde::Deserialize, serde::Serialize)]
+    struct TestCase {
+        cipher_suite: u16,
+        #[serde(with = "hex::serde")]
+        secret: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        ciphertext_bytes: Vec<u8>, // only its prefix is sampled for key derivation
+        #[serde(with = "hex::serde")]
+        expected_key: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        expected_nonce: Vec<u8>,
+        sender_data: TestSenderData,
+        sender_data_aad: TestSenderDataAAD,
+        #[serde(with = "hex::serde")]
+        expected_ciphertext: Vec<u8>, // `sender_data` sealed with `sender_data_aad`
+    }
+
+    #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
+    struct TestSenderData {
+        sender: u32,
+        generation: u32,
+        #[serde(with = "hex::serde")]
+        reuse_guard: Vec<u8>, // must be exactly 4 bytes (see `ReuseGuard::new`)
+    }
+
+    impl From<TestSenderData> for SenderData {
+        fn from(value: TestSenderData) -> Self {
+            let reuse_guard = ReuseGuard::new(value.reuse_guard);
+
+            Self {
+                sender: LeafIndex(value.sender),
+                generation: value.generation,
+                reuse_guard,
+            }
+        }
+    }
+
+    #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
+    struct TestSenderDataAAD {
+        epoch: u64,
+        #[serde(with = "hex::serde")]
+        group_id: Vec<u8>,
+    }
+
+    impl From<TestSenderDataAAD> for SenderDataAAD {
+        fn from(value: TestSenderDataAAD) -> Self {
+            Self {
+                epoch: value.epoch,
+                group_id: value.group_id,
+                content_type: ContentType::Application, // not stored in the vectors; fixed here
+            }
+        }
+    }
+
+    #[cfg(not(mls_build_async))]
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn generate_test_vector() -> Vec<TestCase> {
+        let test_cases = CipherSuite::all().map(test_cipher_suite_provider).map(
+            #[cfg_attr(coverage_nightly, coverage(off))]
+            |provider| {
+                let ext_size = provider.kdf_extract_size();
+                let secret = random_bytes(ext_size).into();
+                let ciphertext_sizes = [ext_size - 5, ext_size, ext_size + 5]; // below/at/above sample
+
+                let sender_data = TestSenderData {
+                    sender: 0,
+                    generation: 13,
+                    reuse_guard: random_bytes(4),
+                };
+
+                let sender_data_aad = TestSenderDataAAD {
+                    group_id: b"group".to_vec(),
+                    epoch: 42,
+                };
+
+                ciphertext_sizes.into_iter().map(
+                    #[cfg_attr(coverage_nightly, coverage(off))]
+                    move |ciphertext_size| {
+                        let ciphertext_bytes = random_bytes(ciphertext_size);
+
+                        let sender_data_key =
+                            SenderDataKey::new(&secret, &ciphertext_bytes, &provider).unwrap();
+
+                        let expected_ciphertext = sender_data_key
+                            .seal(&sender_data.clone().into(), &sender_data_aad.clone().into())
+                            .unwrap();
+
+                        TestCase {
+                            cipher_suite: provider.cipher_suite().into(),
+                            secret: secret.to_vec(),
+                            ciphertext_bytes,
+                            expected_key: sender_data_key.key.to_vec(),
+                            expected_nonce: sender_data_key.nonce.to_vec(),
+                            sender_data: sender_data.clone(),
+                            sender_data_aad: sender_data_aad.clone(),
+                            expected_ciphertext,
+                        }
+                    },
+                )
+            },
+        );
+
+        test_cases.flatten().collect()
+    }
+
+    #[cfg(mls_build_async)]
+    fn generate_test_vector() -> Vec<TestCase> {
+        panic!("Tests cannot be generated in async mode");
+    }
+
+    fn load_test_cases() -> Vec<TestCase> {
+        load_test_case_json!(sender_data_key_test_vector, generate_test_vector())
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sender_data_key_test_vector() {
+        for test_case in load_test_cases() {
+            let Some(provider) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
+                continue; // presumably a suite the test provider lacks — confirm
+            };
+
+            let sender_data_key = SenderDataKey::new(
+                &test_case.secret.into(),
+                &test_case.ciphertext_bytes,
+                &provider,
+            )
+            .await
+            .unwrap();
+
+            assert_eq!(sender_data_key.key.to_vec(), test_case.expected_key);
+            assert_eq!(sender_data_key.nonce.to_vec(), test_case.expected_nonce);
+
+            let sender_data = test_case.sender_data.into();
+            let sender_data_aad = test_case.sender_data_aad.into();
+
+            let ciphertext = sender_data_key
+                .seal(&sender_data, &sender_data_aad)
+                .await
+                .unwrap();
+
+            assert_eq!(ciphertext, test_case.expected_ciphertext);
+
+            let plaintext = sender_data_key
+                .open(&ciphertext, &sender_data_aad)
+                .await
+                .unwrap();
+
+            assert_eq!(plaintext, sender_data);
+        }
+    }
+}
diff --git a/src/group/commit.rs b/src/group/commit.rs
new file mode 100644
index 0000000..c201057
--- /dev/null
+++ b/src/group/commit.rs
@@ -0,0 +1,1601 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{
+ crypto::{CipherSuiteProvider, SignatureSecretKey},
+ error::IntoAnyError,
+};
+
+use crate::{
+ cipher_suite::CipherSuite,
+ client::MlsError,
+ client_config::ClientConfig,
+ extension::RatchetTreeExt,
+ identity::SigningIdentity,
+ protocol_version::ProtocolVersion,
+ signer::Signable,
+ tree_kem::{
+ kem::TreeKem, node::LeafIndex, path_secret::PathSecret, TreeKemPrivate, UpdatePath,
+ },
+ ExtensionList, MlsRules,
+};
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
+
+use crate::tree_kem::leaf_node::LeafNode;
+
+#[cfg(not(feature = "private_message"))]
+use crate::WireFormat;
+
+#[cfg(feature = "psk")]
+use crate::{
+ group::{JustPreSharedKeyID, PskGroupId, ResumptionPSKUsage, ResumptionPsk},
+ psk::ExternalPskId,
+};
+
+use super::{
+ confirmation_tag::ConfirmationTag,
+ framing::{Content, MlsMessage, MlsMessagePayload, Sender},
+ key_schedule::{KeySchedule, WelcomeSecret},
+ message_processor::{path_update_required, MessageProcessor},
+ message_signature::AuthenticatedContent,
+ mls_rules::CommitDirection,
+ proposal::{Proposal, ProposalOrRef},
+ ConfirmedTranscriptHash, EncryptedGroupSecrets, ExportedTree, Group, GroupContext, GroupInfo,
+ Welcome,
+};
+
+#[cfg(not(feature = "by_ref_proposal"))]
+use super::proposal_cache::prepare_commit;
+
+#[cfg(feature = "custom_proposal")]
+use super::proposal::CustomProposal;
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(mls_rs_core::arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct Commit {
+    pub proposals: Vec<ProposalOrRef>, // proposals carried by value or by reference
+    pub path: Option<UpdatePath>,      // update path, when the commit includes one
+}
+
+#[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(super) struct CommitGeneration { // NOTE(review): appears to be a local commit awaiting application
+    pub content: AuthenticatedContent,
+    pub pending_private_tree: TreeKemPrivate,
+    pub pending_commit_secret: PathSecret,
+    pub commit_message_hash: CommitHash, // see `CommitHash::compute`
+}
+
+#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct CommitHash( // cipher-suite hash of an encoded commit message
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+    Vec<u8>,
+);
+
+impl Debug for CommitHash {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Shown as pretty-printed bytes under the "CommitHash" label
+        let pretty = mls_rs_core::debug::pretty_bytes(&self.0);
+        Debug::fmt(&pretty.named("CommitHash"), f)
+    }
+}
+
+impl CommitHash {
+    /// Hashes the MLS encoding of `commit` with the cipher suite's hash function.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn compute<CS: CipherSuiteProvider>(
+        cs: &CS,
+        commit: &MlsMessage,
+    ) -> Result<Self, MlsError> {
+        let encoded = commit.mls_encode_to_vec()?;
+        let hashed = cs.hash(&encoded).await.map(Self);
+        hashed.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+    }
+}
+
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Debug)]
+#[non_exhaustive]
+/// Result of an MLS commit operation performed with
+/// [`Group::commit`](crate::group::Group::commit) or
+/// [`CommitBuilder::build`](CommitBuilder::build).
+pub struct CommitOutput {
+    /// Commit message to send to other group members.
+    pub commit_message: MlsMessage,
+    /// Welcome messages to send to new group members. If the commit does not add members,
+    /// this list is empty. Otherwise, if [`MlsRules::commit_options`] returns `single_welcome_message`
+    /// set to true, then this list contains a single message for all new members. Else, the list
+    /// contains one message for each added member. Recipients of each message can be identified using
+    /// [`MlsMessage::key_package_reference`] of their key packages and
+    /// [`MlsMessage::welcome_key_package_references`].
+    pub welcome_messages: Vec<MlsMessage>,
+    /// Ratchet tree that can be sent out of band if
+    /// `ratchet_tree_extension` is not used according to
+    /// [`MlsRules::commit_options`].
+    pub ratchet_tree: Option<ExportedTree<'static>>,
+    /// A group info that can be provided to new members in order to enable external commit
+    /// functionality. This value is set if [`MlsRules::commit_options`] returns
+    /// `allow_external_commit` set to true.
+    pub external_commit_group_info: Option<MlsMessage>,
+    /// Proposals that were received in the prior epoch but not included in the following commit.
+    #[cfg(feature = "by_ref_proposal")]
+    pub unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl CommitOutput {
+    /// Commit message to send to other group members.
+    #[cfg(feature = "ffi")]
+    pub fn commit_message(&self) -> &MlsMessage {
+        &self.commit_message
+    }
+
+    /// Welcome messages to send to new group members (empty if no members are added).
+    #[cfg(feature = "ffi")]
+    pub fn welcome_messages(&self) -> &[MlsMessage] {
+        &self.welcome_messages
+    }
+
+    /// Ratchet tree that can be sent out of band if
+    /// `ratchet_tree_extension` is not used according to
+    /// [`MlsRules::commit_options`].
+    #[cfg(feature = "ffi")]
+    pub fn ratchet_tree(&self) -> Option<&ExportedTree<'static>> {
+        self.ratchet_tree.as_ref()
+    }
+
+    /// A group info that can be provided to new members in order to enable external commit
+    /// functionality. This value is set if [`MlsRules::commit_options`] returns
+    /// `allow_external_commit` set to true.
+    #[cfg(feature = "ffi")]
+    pub fn external_commit_group_info(&self) -> Option<&MlsMessage> {
+        self.external_commit_group_info.as_ref()
+    }
+
+    /// Proposals that were received in the prior epoch but not included in the following commit.
+    #[cfg(all(feature = "ffi", feature = "by_ref_proposal"))]
+    pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
+        &self.unused_proposals
+    }
+}
+
+/// Builder for a commit that can carry multiple proposals by value.
+///
+/// Proposals within a commit can be by-value or by-reference.
+/// Proposals received during the current epoch will be added to the resulting
+/// commit by-reference automatically so long as they pass the rules defined
+/// in the current
+/// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
+pub struct CommitBuilder<'a, C>
+where
+    C: ClientConfig + Clone,
+{
+    group: &'a mut Group<C>,                        // group the commit is created for
+    pub(super) proposals: Vec<Proposal>,            // by-value proposals added through the builder
+    authenticated_data: Vec<u8>,                    // sent unencrypted with the commit
+    group_info_extensions: ExtensionList,           // see `set_group_info_ext`
+    new_signer: Option<SignatureSecretKey>,         // see `set_new_signing_identity`
+    new_signing_identity: Option<SigningIdentity>, // see `set_new_signing_identity`
+}
+
+impl<'a, C> CommitBuilder<'a, C>
+where
+    C: ClientConfig + Clone,
+{
+    /// Insert an [`AddProposal`](crate::group::proposal::AddProposal) into
+    /// the current commit that is being built. Fails if the group rejects `key_package`.
+    pub fn add_member(mut self, key_package: MlsMessage) -> Result<CommitBuilder<'a, C>, MlsError> {
+        let proposal = self.group.add_proposal(key_package)?;
+        self.proposals.push(proposal);
+        Ok(self)
+    }
+
+    /// Set group info extensions that will be inserted into the resulting
+    /// [welcome messages](CommitOutput::welcome_messages) for new members.
+    ///
+    /// Group info extensions that are transmitted as part of a welcome message
+    /// are encrypted along with other private values.
+    ///
+    /// These extensions can be retrieved as part of
+    /// [`NewMemberInfo`](crate::group::NewMemberInfo) that is returned
+    /// by joining the group via
+    /// [`Client::join_group`](crate::Client::join_group).
+    pub fn set_group_info_ext(self, extensions: ExtensionList) -> Self {
+        Self {
+            group_info_extensions: extensions,
+            ..self
+        }
+    }
+
+    /// Insert a [`RemoveProposal`](crate::group::proposal::RemoveProposal) into
+    /// the current commit that is being built. Fails if the group rejects `index`.
+    pub fn remove_member(mut self, index: u32) -> Result<Self, MlsError> {
+        let proposal = self.group.remove_proposal(index)?;
+        self.proposals.push(proposal);
+        Ok(self)
+    }
+
+    /// Insert a
+    /// [`GroupContextExtensions`](crate::group::proposal::Proposal::GroupContextExtensions)
+    /// into the current commit that is being built. Currently never returns an error.
+    pub fn set_group_context_ext(mut self, extensions: ExtensionList) -> Result<Self, MlsError> {
+        let proposal = self.group.group_context_extensions_proposal(extensions);
+        self.proposals.push(proposal);
+        Ok(self)
+    }
+
+    /// Insert a
+    /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with
+    /// an external PSK into the current commit that is being built.
+    #[cfg(feature = "psk")]
+    pub fn add_external_psk(mut self, psk_id: ExternalPskId) -> Result<Self, MlsError> {
+        let key_id = JustPreSharedKeyID::External(psk_id);
+        let proposal = self.group.psk_proposal(key_id)?;
+        self.proposals.push(proposal);
+        Ok(self)
+    }
+
+    /// Insert a
+    /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with
+    /// a resumption PSK into the current commit that is being built.
+    #[cfg(feature = "psk")]
+    pub fn add_resumption_psk(mut self, psk_epoch: u64) -> Result<Self, MlsError> {
+        let psk_id = ResumptionPsk {
+            psk_epoch,
+            usage: ResumptionPSKUsage::Application,
+            psk_group_id: PskGroupId(self.group.group_id().to_vec()),
+        };
+
+        let key_id = JustPreSharedKeyID::Resumption(psk_id);
+        let proposal = self.group.psk_proposal(key_id)?;
+        self.proposals.push(proposal);
+        Ok(self)
+    }
+
+    /// Insert a [`ReInitProposal`](crate::group::proposal::ReInitProposal) into
+    /// the current commit that is being built.
+    pub fn reinit(
+        mut self,
+        group_id: Option<Vec<u8>>,
+        version: ProtocolVersion,
+        cipher_suite: CipherSuite,
+        extensions: ExtensionList,
+    ) -> Result<Self, MlsError> {
+        let proposal = self
+            .group
+            .reinit_proposal(group_id, version, cipher_suite, extensions)?;
+
+        self.proposals.push(proposal);
+        Ok(self)
+    }
+
+    /// Insert a [`CustomProposal`](crate::group::proposal::CustomProposal) into
+    /// the current commit that is being built.
+    #[cfg(feature = "custom_proposal")]
+    pub fn custom_proposal(mut self, proposal: CustomProposal) -> Self {
+        self.proposals.push(Proposal::Custom(proposal));
+        self
+    }
+
+    /// Insert a proposal that was previously constructed, such as a
+    /// proposal returned from
+    /// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals).
+    pub fn raw_proposal(mut self, proposal: Proposal) -> Self {
+        self.proposals.push(proposal);
+        self
+    }
+
+    /// Insert proposals that were previously constructed, such as
+    /// proposals returned from
+    /// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals).
+    pub fn raw_proposals(mut self, mut proposals: Vec<Proposal>) -> Self {
+        self.proposals.append(&mut proposals);
+        self
+    }
+
+    /// Set additional authenticated data for the commit, replacing any previous value.
+    ///
+    /// # Warning
+    ///
+    /// The data provided here is always sent unencrypted.
+    pub fn authenticated_data(self, authenticated_data: Vec<u8>) -> Self {
+        Self {
+            authenticated_data,
+            ..self
+        }
+    }
+
+    /// Change the committer's signing identity as part of making this commit.
+    /// This will only succeed if the [`IdentityProvider`](crate::IdentityProvider)
+    /// in use by the group considers the credential inside this `signing_identity`
+    /// [valid](crate::IdentityProvider::validate_member)
+    /// and results in the same
+    /// [identity](crate::IdentityProvider::identity)
+    /// being used.
+    pub fn set_new_signing_identity(
+        self,
+        signer: SignatureSecretKey,
+        signing_identity: SigningIdentity,
+    ) -> Self {
+        Self {
+            new_signer: Some(signer),
+            new_signing_identity: Some(signing_identity),
+            ..self
+        }
+    }
+
+    /// Finalize the commit to send.
+    ///
+    /// # Errors
+    ///
+    /// This function will return an error if any of the proposals provided
+    /// are not contextually valid according to the rules defined by the
+    /// MLS RFC, or if they do not pass the custom rules defined by the current
+    /// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn build(self) -> Result<CommitOutput, MlsError> {
+        self.group
+            .commit_internal(
+                self.proposals,
+                None,
+                self.authenticated_data,
+                self.group_info_extensions,
+                self.new_signer,
+                self.new_signing_identity,
+            )
+            .await
+    }
+}
+
+impl<C> Group<C>
+where
+ C: ClientConfig + Clone,
+{
+ /// Perform a commit of received proposals.
+ ///
+ /// This function is the equivalent of [`Group::commit_builder`] immediately
+ /// followed by [`CommitBuilder::build`]. Any received proposals since the
+ /// last commit will be included in the resulting message by-reference.
+ ///
+ /// Data provided in the `authenticated_data` field will be placed into
+ /// the resulting commit message unencrypted.
+ ///
+ /// # Pending Commits
+ ///
+ /// When a commit is created, it is not applied immediately in order to
+ /// allow for the resolution of conflicts when multiple members of a group
+ /// attempt to make commits at the same time. For example, a central relay
+ /// can be used to decide which commit should be accepted by the group by
+ /// determining a consistent view of commit packet order for all clients.
+ ///
+ /// Pending commits are stored internally as part of the group's state
+ /// so they do not need to be tracked outside of this library. Any commit
+ /// message that is processed before calling [Group::apply_pending_commit]
+ /// will clear the currently pending commit.
+ ///
+ /// # Empty Commits
+ ///
+ /// Sending a commit that contains no proposals is a valid operation
+ /// within the MLS protocol. It is useful for providing stronger forward
+ /// secrecy and post-compromise security, especially for long running
+ /// groups when group membership does not change often.
+ ///
+ /// # Path Updates
+ ///
+ /// Path updates provide forward secrecy and post-compromise security
+ /// within the MLS protocol.
+ /// The `path_required` option returned by [`MlsRules::commit_options`](crate::MlsRules::commit_options)
+ /// controls the ability of a group to send a commit without a path update.
+ /// An update path will automatically be sent if there are no proposals
+ /// in the commit, or if any proposal other than
+ /// [`Add`](crate::group::proposal::Proposal::Add),
+ /// [`Psk`](crate::group::proposal::Proposal::Psk),
+ /// or [`ReInit`](crate::group::proposal::Proposal::ReInit) is part of the commit.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError> {
+ self.commit_internal(
+ vec![],
+ None,
+ authenticated_data,
+ Default::default(),
+ None,
+ None,
+ )
+ .await
+ }
+
+ /// Create a new [`CommitBuilder`] that can include proposals
+ /// by value.
+ pub fn commit_builder(&mut self) -> CommitBuilder<C> {
+ CommitBuilder {
+ group: self,
+ proposals: Default::default(),
+ authenticated_data: Default::default(),
+ group_info_extensions: Default::default(),
+ new_signer: Default::default(),
+ new_signing_identity: Default::default(),
+ }
+ }
+
+ /// Shared commit-creation path: builds the commit, stores it as the pending commit,
+ /// and returns it together with Welcome messages for any newly added members.
+ #[allow(clippy::too_many_arguments)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn commit_internal(
+ &mut self,
+ proposals: Vec<Proposal>,
+ external_leaf: Option<&LeafNode>,
+ authenticated_data: Vec<u8>,
+ mut welcome_group_info_extensions: ExtensionList,
+ new_signer: Option<SignatureSecretKey>,
+ new_signing_identity: Option<SigningIdentity>,
+ ) -> Result<CommitOutput, MlsError> {
+ if self.pending_commit.is_some() {
+ return Err(MlsError::ExistingPendingCommit);
+ }
+
+ if self.state.pending_reinit.is_some() {
+ return Err(MlsError::GroupUsedAfterReInit);
+ }
+
+ let mls_rules = self.config.mls_rules();
+
+ let is_external = external_leaf.is_some();
+
+ // Construct an initial Commit object with the proposals field populated from Proposals
+ // received during the current epoch, and an empty path field. Add passed-in proposals
+ // by value.
+ let sender = if is_external {
+ Sender::NewMemberCommit
+ } else {
+ Sender::Member(*self.private_tree.self_index)
+ };
+ // Content is signed with the current key; a new key (if any) is used for encap and GroupInfos.
+ let new_signer_ref = new_signer.as_ref().unwrap_or(&self.signer);
+ let old_signer = &self.signer;
+
+ #[cfg(feature = "std")]
+ let time = Some(crate::time::MlsTime::now());
+
+ #[cfg(not(feature = "std"))]
+ let time = None;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let proposals = self.state.proposals.prepare_commit(sender, proposals);
+
+ #[cfg(not(feature = "by_ref_proposal"))]
+ let proposals = prepare_commit(sender, proposals);
+
+ let mut provisional_state = self
+ .state
+ .apply_resolved(
+ sender,
+ proposals,
+ external_leaf,
+ &self.config.identity_provider(),
+ &self.cipher_suite_provider,
+ &self.config.secret_store(),
+ &mls_rules,
+ time,
+ CommitDirection::Send,
+ )
+ .await?;
+
+ let (mut provisional_private_tree, _) =
+ self.provisional_private_tree(&provisional_state)?;
+ // For an external commit, the committer's leaf index comes from the applied proposals.
+ if is_external {
+ provisional_private_tree.self_index = provisional_state
+ .external_init_index
+ .ok_or(MlsError::ExternalCommitMissingExternalInit)?;
+
+ self.private_tree.self_index = provisional_private_tree.self_index;
+ }
+
+ let mut provisional_group_context = provisional_state.group_context;
+
+ // Decide whether to populate the path field: If the path field is required based on the
+ // proposals that are in the commit (see above), then it MUST be populated. Otherwise, the
+ // sender MAY omit the path field at its discretion.
+ let commit_options = mls_rules
+ .commit_options(
+ &provisional_state.public_tree.roster(),
+ &provisional_group_context.extensions,
+ &provisional_state.applied_proposals,
+ )
+ .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
+ // A path is sent if the MLS rules require it or if any applied proposal demands one.
+ let perform_path_update = commit_options.path_required
+ || path_update_required(&provisional_state.applied_proposals);
+
+ let (update_path, path_secrets, commit_secret) = if perform_path_update {
+ // If populating the path field: Create an UpdatePath using the new tree. Any new
+ // member (from an add proposal) MUST be excluded from the resolution during the
+ // computation of the UpdatePath. The GroupContext for this operation uses the
+ // group_id, epoch, tree_hash, and confirmed_transcript_hash values in the initial
+ // GroupContext object. The leaf_key_package for this UpdatePath must have a
+ // parent_hash extension.
+ let encap_gen = TreeKem::new(
+ &mut provisional_state.public_tree,
+ &mut provisional_private_tree,
+ )
+ .encap(
+ &mut provisional_group_context,
+ &provisional_state.indexes_of_added_kpkgs,
+ new_signer_ref,
+ self.config.leaf_properties(),
+ new_signing_identity,
+ &self.cipher_suite_provider,
+ #[cfg(test)]
+ &self.commit_modifiers,
+ )
+ .await?;
+
+ (
+ Some(encap_gen.update_path),
+ Some(encap_gen.path_secrets),
+ encap_gen.commit_secret,
+ )
+ } else {
+ // Update the tree hash, since it was not updated by encap.
+ provisional_state
+ .public_tree
+ .update_hashes(
+ &[provisional_private_tree.self_index],
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ provisional_group_context.tree_hash = provisional_state
+ .public_tree
+ .tree_hash(&self.cipher_suite_provider)
+ .await?;
+
+ (None, None, PathSecret::empty(&self.cipher_suite_provider))
+ };
+
+ #[cfg(feature = "psk")]
+ let (psk_secret, psks) = self
+ .get_psk(&provisional_state.applied_proposals.psks)
+ .await?;
+
+ #[cfg(not(feature = "psk"))]
+ let psk_secret = self.get_psk();
+ // Copy the added key packages out before `applied_proposals` is moved into the commit.
+ let added_key_pkgs: Vec<_> = provisional_state
+ .applied_proposals
+ .additions
+ .iter()
+ .map(|info| info.proposal.key_package.clone())
+ .collect();
+
+ let commit = Commit {
+ proposals: provisional_state.applied_proposals.into_proposals_or_refs(),
+ path: update_path,
+ };
+
+ let mut auth_content = AuthenticatedContent::new_signed(
+ &self.cipher_suite_provider,
+ self.context(),
+ sender,
+ Content::Commit(alloc::boxed::Box::new(commit)),
+ old_signer,
+ #[cfg(feature = "private_message")]
+ self.encryption_options()?.control_wire_format(sender),
+ #[cfg(not(feature = "private_message"))]
+ WireFormat::PublicMessage,
+ authenticated_data,
+ )
+ .await?;
+
+ // Use the signature, the commit_secret and the psk_secret to advance the key schedule and
+ // compute the confirmation_tag value in the MlsPlaintext.
+ let confirmed_transcript_hash = ConfirmedTranscriptHash::create(
+ self.cipher_suite_provider(),
+ &self.state.interim_transcript_hash,
+ &auth_content,
+ )
+ .await?;
+
+ provisional_group_context.confirmed_transcript_hash = confirmed_transcript_hash;
+
+ let key_schedule_result = KeySchedule::from_key_schedule(
+ &self.key_schedule,
+ &commit_secret,
+ &provisional_group_context,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ self.state.public_tree.total_leaf_count(),
+ &psk_secret,
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ let confirmation_tag = ConfirmationTag::create(
+ &key_schedule_result.confirmation_key,
+ &provisional_group_context.confirmed_transcript_hash,
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ auth_content.auth.confirmation_tag = Some(confirmation_tag.clone());
+ // Present only when the MLS rules ask for the tree in a GroupInfo extension.
+ let ratchet_tree_ext = commit_options
+ .ratchet_tree_extension
+ .then(|| RatchetTreeExt {
+ tree_data: ExportedTree::new(provisional_state.public_tree.nodes.clone()),
+ });
+
+ // Generate external commit group info if required by commit_options
+ let external_commit_group_info = match commit_options.allow_external_commit {
+ true => {
+ let mut extensions = ExtensionList::new();
+
+ extensions.set_from({
+ key_schedule_result
+ .key_schedule
+ .get_external_key_pair_ext(&self.cipher_suite_provider)
+ .await?
+ })?;
+
+ if let Some(ref ratchet_tree_ext) = ratchet_tree_ext {
+ extensions.set_from(ratchet_tree_ext.clone())?;
+ }
+
+ let info = self
+ .make_group_info(
+ &provisional_group_context,
+ extensions,
+ &confirmation_tag,
+ new_signer_ref,
+ )
+ .await?;
+
+ let msg =
+ MlsMessage::new(self.protocol_version(), MlsMessagePayload::GroupInfo(info));
+
+ Some(msg)
+ }
+ false => None,
+ };
+
+ // Build the group info that will be placed into the welcome messages.
+ // Add the ratchet tree extension if necessary
+ if let Some(ratchet_tree_ext) = ratchet_tree_ext {
+ welcome_group_info_extensions.set_from(ratchet_tree_ext)?;
+ }
+
+ let welcome_group_info = self
+ .make_group_info(
+ &provisional_group_context,
+ welcome_group_info_extensions,
+ &confirmation_tag,
+ new_signer_ref,
+ )
+ .await?;
+
+ // Encrypt the GroupInfo using the key and nonce derived from the joiner_secret for
+ // the new epoch
+ let welcome_secret = WelcomeSecret::from_joiner_secret(
+ &self.cipher_suite_provider,
+ &key_schedule_result.joiner_secret,
+ &psk_secret,
+ )
+ .await?;
+
+ let encrypted_group_info = welcome_secret
+ .encrypt(&welcome_group_info.mls_encode_to_vec()?)
+ .await?;
+
+ // Encrypt path secrets and joiner secret to new members
+ let path_secrets = path_secrets.as_ref();
+
+ #[cfg(not(any(mls_build_async, not(feature = "rayon"))))]
+ let encrypted_path_secrets: Vec<_> = added_key_pkgs
+ .into_par_iter()
+ .zip(provisional_state.indexes_of_added_kpkgs)
+ .map(|(key_package, leaf_index)| {
+ self.encrypt_group_secrets(
+ &key_package,
+ leaf_index,
+ &key_schedule_result.joiner_secret,
+ path_secrets,
+ #[cfg(feature = "psk")]
+ psks.clone(),
+ &encrypted_group_info,
+ )
+ })
+ .try_collect()?;
+
+ #[cfg(any(mls_build_async, not(feature = "rayon")))]
+ let encrypted_path_secrets = {
+ let mut secrets = Vec::new();
+
+ for (key_package, leaf_index) in added_key_pkgs
+ .into_iter()
+ .zip(provisional_state.indexes_of_added_kpkgs)
+ {
+ secrets.push(
+ self.encrypt_group_secrets(
+ &key_package,
+ leaf_index,
+ &key_schedule_result.joiner_secret,
+ path_secrets,
+ #[cfg(feature = "psk")]
+ psks.clone(),
+ &encrypted_group_info,
+ )
+ .await?,
+ );
+ }
+
+ secrets
+ };
+ // One Welcome for all joiners if `single_welcome_message` is set, else one per joiner.
+ let welcome_messages =
+ if commit_options.single_welcome_message && !encrypted_path_secrets.is_empty() {
+ vec![self.make_welcome_message(encrypted_path_secrets, encrypted_group_info)]
+ } else {
+ encrypted_path_secrets
+ .into_iter()
+ .map(|s| self.make_welcome_message(vec![s], encrypted_group_info.clone()))
+ .collect()
+ };
+
+ let commit_message = self.format_for_wire(auth_content.clone()).await?;
+
+ let pending_commit = CommitGeneration {
+ content: auth_content,
+ pending_private_tree: provisional_private_tree,
+ pending_commit_secret: commit_secret,
+ commit_message_hash: CommitHash::compute(&self.cipher_suite_provider, &commit_message)
+ .await?,
+ };
+
+ self.pending_commit = Some(pending_commit);
+ // If the tree is not carried in a GroupInfo extension, return it to the caller instead.
+ let ratchet_tree = (!commit_options.ratchet_tree_extension)
+ .then(|| ExportedTree::new(provisional_state.public_tree.nodes));
+ // Switch to the new signing key (if any) now that the commit has been generated.
+ if let Some(signer) = new_signer {
+ self.signer = signer;
+ }
+
+ Ok(CommitOutput {
+ commit_message,
+ welcome_messages,
+ ratchet_tree,
+ external_commit_group_info,
+ #[cfg(feature = "by_ref_proposal")]
+ unused_proposals: provisional_state.unused_proposals,
+ })
+ }
+
+ /// Build a GroupInfo from `group_context`, `extensions` and `confirmation_tag`,
+ /// greased and signed by this member (current leaf index) with `signer`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn make_group_info(
+ &self,
+ group_context: &GroupContext,
+ extensions: ExtensionList,
+ confirmation_tag: &ConfirmationTag,
+ signer: &SignatureSecretKey,
+ ) -> Result<GroupInfo, MlsError> {
+ let mut group_info = GroupInfo {
+ group_context: group_context.clone(),
+ extensions,
+ confirmation_tag: confirmation_tag.clone(), // The confirmation_tag from the MlsPlaintext object
+ signer: LeafIndex(self.current_member_index()),
+ signature: vec![],
+ };
+
+ group_info.grease(self.cipher_suite_provider())?;
+
+ // Sign the GroupInfo using the member's private signing key
+ group_info
+ .sign(&self.cipher_suite_provider, signer, &())
+ .await?;
+
+ Ok(group_info)
+ }
+ /// Wrap encrypted group secrets and the encrypted GroupInfo in a Welcome `MlsMessage`.
+ fn make_welcome_message(
+ &self,
+ secrets: Vec<EncryptedGroupSecrets>,
+ encrypted_group_info: Vec<u8>,
+ ) -> MlsMessage {
+ MlsMessage::new(
+ self.context().protocol_version,
+ MlsMessagePayload::Welcome(Welcome {
+ cipher_suite: self.context().cipher_suite,
+ secrets,
+ encrypted_group_info,
+ }),
+ )
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use alloc::vec::Vec;
+ use crate::{
+ crypto::SignatureSecretKey,
+ tree_kem::{leaf_node::LeafNode, TreeKemPublic, UpdatePathNode},
+ };
+
+ /// Test-only hooks for altering the leaf, tree and update path during commit generation.
+ #[derive(Copy, Clone, Debug)]
+ pub struct CommitModifiers {
+ pub modify_leaf: fn(&mut LeafNode, &SignatureSecretKey) -> Option<SignatureSecretKey>,
+ pub modify_tree: fn(&mut TreeKemPublic),
+ pub modify_path: fn(Vec<UpdatePathNode>) -> Vec<UpdatePathNode>,
+ }
+ impl Default for CommitModifiers {
+ /// Hooks that leave everything unchanged (`modify_leaf` returns `None`).
+ fn default() -> Self {
+ Self {
+ modify_leaf: |_leaf, _signer| None,
+ modify_tree: |_tree| {},
+ modify_path: |path| path,
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use alloc::boxed::Box;
+
+ use mls_rs_core::{
+ error::IntoAnyError,
+ extension::ExtensionType,
+ identity::{CredentialType, IdentityProvider},
+ time::MlsTime,
+ };
+
+ use crate::{
+ crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
+ group::{mls_rules::DefaultMlsRules, test_utils::test_group_custom},
+ mls_rules::CommitOptions,
+ Client,
+ };
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::extension::ExternalSendersExt;
+
+ use crate::{
+ client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ client_builder::{
+ test_utils::TestClientConfig, BaseConfig, ClientBuilder, WithCryptoProvider,
+ WithIdentityProvider,
+ },
+ client_config::ClientConfig,
+ extension::test_utils::{TestExtension, TEST_EXTENSION_TYPE},
+ group::{
+ proposal::ProposalType,
+ test_utils::{test_group_custom_config, test_n_member_group},
+ },
+ identity::test_utils::get_test_signing_identity,
+ identity::{basic::BasicIdentityProvider, test_utils::get_test_basic_credential},
+ key_package::test_utils::test_key_package_message,
+ };
+
+ use crate::extension::RequiredCapabilitiesExt;
+
+ #[cfg(feature = "psk")]
+ use crate::{
+ group::proposal::PreSharedKeyProposal,
+ psk::{JustPreSharedKeyID, PreSharedKey, PreSharedKeyID},
+ };
+
+ use super::*;
+ /// Group whose config supports custom proposal type 42 and the test extension type.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_commit_builder_group() -> Group<TestClientConfig> {
+ test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
+ b.custom_proposal_type(ProposalType::from(42))
+ .extension_type(TEST_EXTENSION_TYPE.into())
+ })
+ .await
+ .group
+ }
+ /// Check the commit holds exactly `expected` (by value) and `welcome_count` welcome secrets.
+ fn assert_commit_builder_output<C: ClientConfig>(
+ group: Group<C>,
+ mut commit_output: CommitOutput,
+ expected: Vec<Proposal>,
+ welcome_count: usize,
+ ) {
+ let plaintext = commit_output.commit_message.into_plaintext().unwrap();
+
+ let commit_data = match plaintext.content.content {
+ Content::Commit(commit) => commit,
+ #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
+ _ => panic!("Found non-commit data"),
+ };
+
+ assert_eq!(commit_data.proposals.len(), expected.len());
+
+ commit_data.proposals.into_iter().for_each(|proposal| {
+ let proposal = match proposal {
+ ProposalOrRef::Proposal(p) => p,
+ #[cfg(feature = "by_ref_proposal")]
+ ProposalOrRef::Reference(_) => panic!("found proposal reference"),
+ };
+
+ #[cfg(feature = "psk")]
+ if let Some(psk_id) = match proposal.as_ref() {
+ Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(psk_id), .. },}) => Some(psk_id),
+ _ => None,
+ } {
+ let found = expected.iter().any(|item| matches!(item, Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(id), .. }}) if id == psk_id));
+
+ assert!(found)
+ } else {
+ assert!(expected.contains(&proposal));
+ }
+
+ #[cfg(not(feature = "psk"))]
+ assert!(expected.contains(&proposal));
+ });
+
+ if welcome_count > 0 {
+ let welcome_msg = commit_output.welcome_messages.pop().unwrap();
+
+ assert_eq!(welcome_msg.version, group.state.context.protocol_version);
+
+ let welcome_msg = welcome_msg.into_welcome().unwrap();
+
+ assert_eq!(welcome_msg.cipher_suite, group.state.context.cipher_suite);
+ assert_eq!(welcome_msg.secrets.len(), welcome_count);
+ } else {
+ assert!(commit_output.welcome_messages.is_empty());
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_add() {
+ let mut group = test_commit_builder_group().await;
+
+ let test_key_package =
+ test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
+
+ let commit_output = group
+ .commit_builder()
+ .add_member(test_key_package.clone())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ let expected_add = group.add_proposal(test_key_package).unwrap();
+
+ assert_commit_builder_output(group, commit_output, vec![expected_add], 1)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_add_with_ext() {
+ let mut group = test_commit_builder_group().await;
+
+ let (bob_client, bob_key_package) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let ext = TestExtension { foo: 42 };
+ let mut extension_list = ExtensionList::default();
+ extension_list.set_from(ext.clone()).unwrap();
+
+ let welcome_message = group
+ .commit_builder()
+ .add_member(bob_key_package)
+ .unwrap()
+ .set_group_info_ext(extension_list)
+ .build()
+ .await
+ .unwrap()
+ .welcome_messages
+ .remove(0);
+
+ let (_, context) = bob_client.join_group(None, &welcome_message).await.unwrap();
+
+ assert_eq!(
+ context
+ .group_info_extensions
+ .get_as::<TestExtension>()
+ .unwrap()
+ .unwrap(),
+ ext
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_remove() {
+ let mut group = test_commit_builder_group().await;
+ let test_key_package =
+ test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
+
+ group
+ .commit_builder()
+ .add_member(test_key_package)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ group.apply_pending_commit().await.unwrap();
+
+ let commit_output = group
+ .commit_builder()
+ .remove_member(1)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ let expected_remove = group.remove_proposal(1).unwrap();
+
+ assert_commit_builder_output(group, commit_output, vec![expected_remove], 0);
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_psk() {
+ let mut group = test_commit_builder_group().await;
+ let test_psk = ExternalPskId::new(vec![1]);
+
+ group
+ .config
+ .secret_store()
+ .insert(test_psk.clone(), PreSharedKey::from(vec![1]));
+
+ let commit_output = group
+ .commit_builder()
+ .add_external_psk(test_psk.clone())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ let key_id = JustPreSharedKeyID::External(test_psk);
+ let expected_psk = group.psk_proposal(key_id).unwrap();
+
+ assert_commit_builder_output(group, commit_output, vec![expected_psk], 0)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_group_context_ext() {
+ let mut group = test_commit_builder_group().await;
+ let mut test_ext = ExtensionList::default();
+ test_ext
+ .set_from(RequiredCapabilitiesExt::default())
+ .unwrap();
+
+ let commit_output = group
+ .commit_builder()
+ .set_group_context_ext(test_ext.clone())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ let expected_ext = group.group_context_extensions_proposal(test_ext);
+
+ assert_commit_builder_output(group, commit_output, vec![expected_ext], 0);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_reinit() {
+ let mut group = test_commit_builder_group().await;
+ let test_group_id = "foo".as_bytes().to_vec();
+ let test_cipher_suite = TEST_CIPHER_SUITE;
+ let test_protocol_version = TEST_PROTOCOL_VERSION;
+ let mut test_ext = ExtensionList::default();
+
+ test_ext
+ .set_from(RequiredCapabilitiesExt::default())
+ .unwrap();
+
+ let commit_output = group
+ .commit_builder()
+ .reinit(
+ Some(test_group_id.clone()),
+ test_protocol_version,
+ test_cipher_suite,
+ test_ext.clone(),
+ )
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ let expected_reinit = group
+ .reinit_proposal(
+ Some(test_group_id),
+ test_protocol_version,
+ test_cipher_suite,
+ test_ext,
+ )
+ .unwrap();
+
+ assert_commit_builder_output(group, commit_output, vec![expected_reinit], 0);
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_custom_proposal() {
+ let mut group = test_commit_builder_group().await;
+
+ let proposal = CustomProposal::new(42.into(), vec![0, 1]);
+
+ let commit_output = group
+ .commit_builder()
+ .custom_proposal(proposal.clone())
+ .build()
+ .await
+ .unwrap();
+
+ assert_commit_builder_output(group, commit_output, vec![Proposal::Custom(proposal)], 0);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_chaining() {
+ let mut group = test_commit_builder_group().await;
+ let kp1 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
+ let kp2 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let expected_adds = vec![
+ group.add_proposal(kp1.clone()).unwrap(),
+ group.add_proposal(kp2.clone()).unwrap(),
+ ];
+
+ let commit_output = group
+ .commit_builder()
+ .add_member(kp1)
+ .unwrap()
+ .add_member(kp2)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ assert_commit_builder_output(group, commit_output, expected_adds, 2);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_empty_commit() {
+ let mut group = test_commit_builder_group().await;
+
+ let commit_output = group.commit_builder().build().await.unwrap();
+
+ assert_commit_builder_output(group, commit_output, vec![], 0);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_authenticated_data() {
+ let mut group = test_commit_builder_group().await;
+ let test_data = "test".as_bytes().to_vec();
+
+ let commit_output = group
+ .commit_builder()
+ .authenticated_data(test_data.clone())
+ .build()
+ .await
+ .unwrap();
+
+ assert_eq!(
+ commit_output
+ .commit_message
+ .into_plaintext()
+ .unwrap()
+ .content
+ .authenticated_data,
+ test_data
+ );
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_multiple_welcome_messages() {
+ let mut group = test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
+ let options = CommitOptions::new().with_single_welcome_message(false);
+ b.mls_rules(DefaultMlsRules::new().with_commit_options(options))
+ })
+ .await;
+
+ let (alice, alice_kp) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "a").await;
+
+ let (bob, bob_kp) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "b").await;
+
+ group
+ .group
+ .propose_add(alice_kp.clone(), vec![])
+ .await
+ .unwrap();
+
+ group
+ .group
+ .propose_add(bob_kp.clone(), vec![])
+ .await
+ .unwrap();
+ // Commit both add proposals by reference; each joiner should get its own Welcome.
+ let output = group.group.commit(Vec::new()).await.unwrap();
+ let welcomes = output.welcome_messages;
+
+ let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ for (client, kp) in [(alice, alice_kp), (bob, bob_kp)] {
+ let kp_ref = kp.key_package_reference(&cs).await.unwrap().unwrap();
+
+ let welcome = welcomes
+ .iter()
+ .find(|w| w.welcome_key_package_references().contains(&&kp_ref))
+ .unwrap();
+
+ client.join_group(None, welcome).await.unwrap();
+
+ assert_eq!(welcome.clone().into_welcome().unwrap().secrets.len(), 1);
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_can_change_credential() {
+ let cs = TEST_CIPHER_SUITE;
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, cs, 3).await;
+ let (identity, secret_key) = get_test_signing_identity(cs, b"member").await;
+
+ let commit_output = groups[0]
+ .group
+ .commit_builder()
+ .set_new_signing_identity(secret_key, identity.clone())
+ .build()
+ .await
+ .unwrap();
+
+ // Check that the credential was updated in the committer's state.
+ groups[0].process_pending_commit().await.unwrap();
+ let new_member = groups[0].group.roster().member_with_index(0).unwrap();
+
+ assert_eq!(
+ new_member.signing_identity.credential,
+ get_test_basic_credential(b"member".to_vec())
+ );
+
+ assert_eq!(
+ new_member.signing_identity.signature_key,
+ identity.signature_key
+ );
+
+ // Check that the credential was updated in another member's state.
+ groups[1]
+ .process_message(commit_output.commit_message)
+ .await
+ .unwrap();
+
+ let new_member = groups[1].group.roster().member_with_index(0).unwrap();
+
+ assert_eq!(
+ new_member.signing_identity.credential,
+ get_test_basic_credential(b"member".to_vec())
+ );
+
+ assert_eq!(
+ new_member.signing_identity.signature_key,
+ identity.signature_key
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_includes_tree_if_no_ratchet_tree_ext() {
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(CommitOptions::new().with_ratchet_tree_extension(false)),
+ )
+ .await
+ .group;
+
+ let commit = group.commit(vec![]).await.unwrap();
+
+ group.apply_pending_commit().await.unwrap();
+
+ let new_tree = group.export_tree();
+
+ assert_eq!(new_tree, commit.ratchet_tree.unwrap())
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_does_not_include_tree_if_ratchet_tree_ext() {
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(CommitOptions::new().with_ratchet_tree_extension(true)),
+ )
+ .await
+ .group;
+
+ let commit = group.commit(vec![]).await.unwrap();
+
+ assert!(commit.ratchet_tree.is_none());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_includes_external_commit_group_info_if_requested() {
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(
+ CommitOptions::new()
+ .with_allow_external_commit(true)
+ .with_ratchet_tree_extension(false),
+ ),
+ )
+ .await
+ .group;
+
+ let commit = group.commit(vec![]).await.unwrap();
+
+ let info = commit
+ .external_commit_group_info
+ .unwrap()
+ .into_group_info()
+ .unwrap();
+
+ assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
+ assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_includes_external_commit_and_tree_if_requested() {
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(
+ CommitOptions::new()
+ .with_allow_external_commit(true)
+ .with_ratchet_tree_extension(true),
+ ),
+ )
+ .await
+ .group;
+
+ let commit = group.commit(vec![]).await.unwrap();
+
+ let info = commit
+ .external_commit_group_info
+ .unwrap()
+ .into_group_info()
+ .unwrap();
+
+ assert!(info.extensions.has_extension(ExtensionType::RATCHET_TREE));
+ assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_does_not_include_external_commit_group_info_if_not_requested() {
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(CommitOptions::new().with_allow_external_commit(false)),
+ )
+ .await
+ .group;
+
+ let commit = group.commit(vec![]).await.unwrap();
+
+ assert!(commit.external_commit_group_info.is_none());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn member_identity_is_validated_against_new_extensions() {
+ let alice = client_with_test_extension(b"alice").await;
+ let mut alice = alice.create_group(ExtensionList::new()).await.unwrap();
+
+ let bob = client_with_test_extension(b"bob").await;
+ let bob_kp = bob.generate_key_package_message().await.unwrap();
+
+ let mut extension_list = ExtensionList::new();
+ let extension = TestExtension { foo: b'a' };
+ extension_list.set_from(extension).unwrap();
+
+ let res = alice
+ .commit_builder()
+ .add_member(bob_kp)
+ .unwrap()
+ .set_group_context_ext(extension_list.clone())
+ .unwrap()
+ .build()
+ .await;
+ // `bob` does not start with 'a' (the new `TestExtension::foo`), so validation fails.
+ assert!(res.is_err());
+ // `alex` starts with 'a', so adding it together with the extension succeeds.
+ let alex = client_with_test_extension(b"alex").await;
+
+ alice
+ .commit_builder()
+ .add_member(alex.generate_key_package_message().await.unwrap())
+ .unwrap()
+ .set_group_context_ext(extension_list.clone())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn server_identity_is_validated_against_new_extensions() {
+ let alice = client_with_test_extension(b"alice").await;
+ let mut alice = alice.create_group(ExtensionList::new()).await.unwrap();
+
+ let mut extension_list = ExtensionList::new();
+ let extension = TestExtension { foo: b'a' };
+ extension_list.set_from(extension).unwrap();
+
+ let (alex_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"alex").await;
+
+ let mut alex_extensions = extension_list.clone();
+
+ alex_extensions
+ .set_from(ExternalSendersExt {
+ allowed_senders: vec![alex_server],
+ })
+ .unwrap();
+
+ let res = alice
+ .commit_builder()
+ .set_group_context_ext(alex_extensions)
+ .unwrap()
+ .build()
+ .await;
+ // An external sender whose identity starts with 'a' is rejected by this provider.
+ assert!(res.is_err());
+ // `bob` does not start with 'a', so it is accepted as an external sender.
+ let (bob_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
+
+ let mut bob_extensions = extension_list;
+
+ bob_extensions
+ .set_from(ExternalSendersExt {
+ allowed_senders: vec![bob_server],
+ })
+ .unwrap();
+
+ alice
+ .commit_builder()
+ .set_group_context_ext(bob_extensions)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+ }
+ /// Identity provider whose validation depends on `TestExtension` in the given extensions.
+ #[derive(Debug, Clone)]
+ struct IdentityProviderWithExtension(BasicIdentityProvider);
+
+ #[derive(Clone, Debug)]
+ #[cfg_attr(feature = "std", derive(thiserror::Error))]
+ #[cfg_attr(feature = "std", error("test error"))]
+ struct IdentityProviderWithExtensionError {}
+
+ impl IntoAnyError for IdentityProviderWithExtensionError {
+ #[cfg(feature = "std")]
+ fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
+ Ok(self.into())
+ }
+ }
+
+ impl IdentityProviderWithExtension {
+ // True if the identity's first byte equals `TestExtension::foo`, or if no extension list
+ // or no `TestExtension` is provided. Panics if the identity cannot be resolved or is empty.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn starts_with_foo(
+ &self,
+ identity: &SigningIdentity,
+ _timestamp: Option<MlsTime>,
+ extensions: Option<&ExtensionList>,
+ ) -> bool {
+ if let Some(extensions) = extensions {
+ if let Some(ext) = extensions.get_as::<TestExtension>().unwrap() {
+ self.identity(identity, extensions).await.unwrap()[0] == ext.foo
+ } else {
+ true
+ }
+ } else {
+ true
+ }
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl IdentityProvider for IdentityProviderWithExtension {
+ type Error = IdentityProviderWithExtensionError;
+
+ async fn validate_member(
+ &self,
+ identity: &SigningIdentity,
+ timestamp: Option<MlsTime>,
+ extensions: Option<&ExtensionList>,
+ ) -> Result<(), Self::Error> {
+ self.starts_with_foo(identity, timestamp, extensions)
+ .await
+ .then_some(())
+ .ok_or(IdentityProviderWithExtensionError {})
+ }
+
+ async fn validate_external_sender(
+ &self,
+ identity: &SigningIdentity,
+ timestamp: Option<MlsTime>,
+ extensions: Option<&ExtensionList>,
+ ) -> Result<(), Self::Error> {
+ (!self.starts_with_foo(identity, timestamp, extensions).await)
+ .then_some(())
+ .ok_or(IdentityProviderWithExtensionError {})
+ }
+
+ async fn identity(
+ &self,
+ signing_identity: &SigningIdentity,
+ extensions: &ExtensionList,
+ ) -> Result<Vec<u8>, Self::Error> {
+ self.0
+ .identity(signing_identity, extensions)
+ .await
+ .map_err(|_| IdentityProviderWithExtensionError {})
+ }
+
+ async fn valid_successor(
+ &self,
+ _predecessor: &SigningIdentity,
+ _successor: &SigningIdentity,
+ _extensions: &ExtensionList,
+ ) -> Result<bool, Self::Error> {
+ Ok(true)
+ }
+
+ fn supported_types(&self) -> Vec<CredentialType> {
+ self.0.supported_types()
+ }
+ }
+
+ type ExtensionClientConfig = WithIdentityProvider<
+ IdentityProviderWithExtension,
+ WithCryptoProvider<TestCryptoProvider, BaseConfig>,
+ >;
+ /// Client supporting `TEST_EXTENSION_TYPE` and using `IdentityProviderWithExtension`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn client_with_test_extension(name: &[u8]) -> Client<ExtensionClientConfig> {
+ let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, name).await;
+
+ ClientBuilder::new()
+ .crypto_provider(TestCryptoProvider::new())
+ .extension_types(vec![TEST_EXTENSION_TYPE.into()])
+ .identity_provider(IdentityProviderWithExtension(BasicIdentityProvider::new()))
+ .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
+ .build()
+ }
+}
diff --git a/src/group/confirmation_tag.rs b/src/group/confirmation_tag.rs
new file mode 100644
index 0000000..409b382
--- /dev/null
+++ b/src/group/confirmation_tag.rs
@@ -0,0 +1,150 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::CipherSuiteProvider;
+use crate::{client::MlsError, group::transcript_hash::ConfirmedTranscriptHash};
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+
+/// MAC computed over the confirmed transcript hash with the confirmation key
+/// (see `ConfirmationTag::create`).
+///
+/// The bytes are encoded through `mls_rs_codec::byte_vec`. Note that the
+/// derived `PartialEq` is an ordinary byte comparison, not a constant-time one.
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct ConfirmationTag(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ Vec<u8>,
+);
+
+impl Debug for ConfirmationTag {
+    /// Formats the tag bytes through `mls_rs_core::debug::pretty_bytes`,
+    /// labelled `ConfirmationTag`.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let pretty = mls_rs_core::debug::pretty_bytes(&self.0).named("ConfirmationTag");
+        Debug::fmt(&pretty, f)
+    }
+}
+
+impl Deref for ConfirmationTag {
+    type Target = Vec<u8>;
+
+    /// Exposes the raw MAC bytes of the tag.
+    fn deref(&self) -> &Vec<u8> {
+        let ConfirmationTag(bytes) = self;
+        bytes
+    }
+}
+
+impl ConfirmationTag {
+    /// Computes the confirmation tag as the cipher suite's MAC of
+    /// `confirmed_transcript_hash` keyed with `confirmation_key`.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MlsError::CryptoProviderError`] if the provider's MAC fails.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn create<P: CipherSuiteProvider>(
+        confirmation_key: &[u8],
+        confirmed_transcript_hash: &ConfirmedTranscriptHash,
+        cipher_suite_provider: &P,
+    ) -> Result<Self, MlsError> {
+        cipher_suite_provider
+            .mac(confirmation_key, confirmed_transcript_hash)
+            .await
+            .map(ConfirmationTag)
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+    }
+
+    /// Recomputes the expected tag from `confirmation_key` and
+    /// `confirmed_transcript_hash` and reports whether it equals `self`.
+    ///
+    /// The comparison inspects every byte rather than stopping at the first
+    /// mismatch. Otherwise a peer submitting forged tags could learn from
+    /// timing how many leading bytes of the secret-keyed MAC it guessed right.
+    ///
+    /// # Errors
+    ///
+    /// Propagates any error from [`ConfirmationTag::create`].
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn matches<P: CipherSuiteProvider>(
+        &self,
+        confirmation_key: &[u8],
+        confirmed_transcript_hash: &ConfirmedTranscriptHash,
+        cipher_suite_provider: &P,
+    ) -> Result<bool, MlsError> {
+        let expected = ConfirmationTag::create(
+            confirmation_key,
+            confirmed_transcript_hash,
+            cipher_suite_provider,
+        )
+        .await?;
+
+        Ok(Self::constant_time_eq(&expected, self))
+    }
+
+    /// Byte-wise equality that always folds over every byte instead of
+    /// returning at the first difference.
+    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
+        // The expected tag's length is set by the cipher suite's MAC and is not
+        // secret, so an early exit on a length mismatch reveals nothing about
+        // the tag bytes themselves.
+        if a.len() != b.len() {
+            return false;
+        }
+
+        let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
+
+        // `black_box` discourages the optimizer from turning the fold back into
+        // an early-exit comparison.
+        core::hint::black_box(diff) == 0
+    }
+}
+
+#[cfg(test)]
+impl ConfirmationTag {
+    /// Test-only tag: the MAC of empty input under an all-zero key whose length
+    /// is the cipher suite's KDF extract size.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the provider's MAC operation fails.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn empty<P: CipherSuiteProvider>(cipher_suite_provider: &P) -> Self {
+        let zero_key = alloc::vec![0u8; cipher_suite_provider.kdf_extract_size()];
+        let tag_bytes = cipher_suite_provider.mac(&zero_key, &[]).await.unwrap();
+
+        ConfirmationTag(tag_bytes)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider};
+
+    #[cfg(target_arch = "wasm32")]
+    use wasm_bindgen_test::wasm_bindgen_test as test;
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_confirmation_tag_matching() {
+        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+            let provider = test_cipher_suite_provider(cipher_suite);
+
+            let hash_a = ConfirmedTranscriptHash::from(b"foo_a".to_vec());
+            let hash_b = ConfirmedTranscriptHash::from(b"foo_b".to_vec());
+            let key_a = b"bar_a".to_vec();
+            let key_b = b"bar_b".to_vec();
+
+            let tag = ConfirmationTag::create(&key_a, &hash_a, &provider)
+                .await
+                .unwrap();
+
+            // (key, transcript hash, whether the tag created above should match)
+            let cases = [
+                (&key_a, &hash_a, true),
+                (&key_b, &hash_a, false),
+                (&key_a, &hash_b, false),
+            ];
+
+            for (key, hash, expected) in cases {
+                let matches = tag.matches(key, hash, &provider).await.unwrap();
+                assert_eq!(matches, expected);
+            }
+        }
+    }
+}
diff --git a/src/group/context.rs b/src/group/context.rs
new file mode 100644
index 0000000..4ec23a9
--- /dev/null
+++ b/src/group/context.rs
@@ -0,0 +1,98 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+use crate::{cipher_suite::CipherSuite, protocol_version::ProtocolVersion, ExtensionList};
+
+use super::ConfirmedTranscriptHash;
+
+/// Group-wide state identifying a group and its current epoch: protocol
+/// version, cipher suite, group id, epoch number, tree hash, confirmed
+/// transcript hash and group context extensions.
+///
+/// Encoding is derived (`MlsEncode`/`MlsDecode`), presumably field by field in
+/// declaration order, so the field order should be treated as fixed.
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct GroupContext {
+ pub(crate) protocol_version: ProtocolVersion,
+ pub(crate) cipher_suite: CipherSuite,
+ // Encoded as a variable-length byte vector (`mls_rs_codec::byte_vec`).
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub(crate) group_id: Vec<u8>,
+ pub(crate) epoch: u64,
+ // Encoded as a variable-length byte vector (`mls_rs_codec::byte_vec`).
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub(crate) tree_hash: Vec<u8>,
+ pub(crate) confirmed_transcript_hash: ConfirmedTranscriptHash,
+ pub(crate) extensions: ExtensionList,
+}
+
+impl Debug for GroupContext {
+    /// Field-by-field debug output; the group id and tree hash go through the
+    /// `mls_rs_core::debug` pretty printers.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let group_id = mls_rs_core::debug::pretty_group_id(&self.group_id);
+        let tree_hash = mls_rs_core::debug::pretty_bytes(&self.tree_hash);
+
+        let mut out = f.debug_struct("GroupContext");
+        out.field("protocol_version", &self.protocol_version);
+        out.field("cipher_suite", &self.cipher_suite);
+        out.field("group_id", &group_id);
+        out.field("epoch", &self.epoch);
+        out.field("tree_hash", &tree_hash);
+        out.field("confirmed_transcript_hash", &self.confirmed_transcript_hash);
+        out.field("extensions", &self.extensions);
+        out.finish()
+    }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl GroupContext {
+    /// Builds the context of a brand-new group: epoch 0 and an empty confirmed
+    /// transcript hash.
+    pub(crate) fn new_group(
+        protocol_version: ProtocolVersion,
+        cipher_suite: CipherSuite,
+        group_id: Vec<u8>,
+        tree_hash: Vec<u8>,
+        extensions: ExtensionList,
+    ) -> Self {
+        let confirmed_transcript_hash = ConfirmedTranscriptHash::from(vec![]);
+
+        Self {
+            protocol_version,
+            cipher_suite,
+            group_id,
+            epoch: 0,
+            tree_hash,
+            confirmed_transcript_hash,
+            extensions,
+        }
+    }
+
+    /// Protocol version the group is currently using.
+    pub fn version(&self) -> ProtocolVersion {
+        let GroupContext { protocol_version, .. } = self;
+        *protocol_version
+    }
+
+    /// Cipher suite the group is currently using.
+    pub fn cipher_suite(&self) -> CipherSuite {
+        let GroupContext { cipher_suite, .. } = self;
+        *cipher_suite
+    }
+
+    /// Unique identifier of this group.
+    pub fn group_id(&self) -> &[u8] {
+        self.group_id.as_slice()
+    }
+
+    /// Epoch number of the group's current state.
+    pub fn epoch(&self) -> u64 {
+        let GroupContext { epoch, .. } = self;
+        *epoch
+    }
+
+    /// Group context extensions stored in this context.
+    pub fn extensions(&self) -> &ExtensionList {
+        &self.extensions
+    }
+}
diff --git a/src/group/epoch.rs b/src/group/epoch.rs
new file mode 100644
index 0000000..58352d6
--- /dev/null
+++ b/src/group/epoch.rs
@@ -0,0 +1,165 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+#[cfg(feature = "psk")]
+use crate::psk::PreSharedKey;
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+use crate::tree_kem::node::NodeIndex;
+#[cfg(feature = "prior_epoch")]
+use crate::{crypto::SignaturePublicKey, group::GroupContext, tree_kem::node::LeafIndex};
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use zeroize::Zeroizing;
+
+#[cfg(all(feature = "prior_epoch", feature = "private_message"))]
+use super::ciphertext_processor::GroupStateProvider;
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+use crate::group::secret_tree::SecretTree;
+
+/// Snapshot of an earlier epoch: its group context, this member's leaf
+/// index, the epoch secrets and the members' signature public keys.
+#[cfg(feature = "prior_epoch")]
+#[derive(Debug, Clone, MlsEncode, MlsDecode, MlsSize, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct PriorEpoch {
+ pub(crate) context: GroupContext,
+ pub(crate) self_index: LeafIndex,
+ pub(crate) secrets: EpochSecrets,
+ // NOTE(review): presumably indexed by leaf index, with `None` for empty
+ // leaves — confirm against the code that fills this in.
+ pub(crate) signature_public_keys: Vec<Option<SignaturePublicKey>>,
+}
+
+#[cfg(feature = "prior_epoch")]
+impl PriorEpoch {
+    /// Epoch number recorded in this snapshot's group context.
+    #[inline(always)]
+    pub(crate) fn epoch_id(&self) -> u64 {
+        let context = &self.context;
+        context.epoch
+    }
+
+    /// Identifier of the group this snapshot belongs to.
+    #[inline(always)]
+    pub(crate) fn group_id(&self) -> &[u8] {
+        self.context.group_id.as_slice()
+    }
+}
+
+#[cfg(all(feature = "private_message", feature = "prior_epoch"))]
+impl GroupStateProvider for PriorEpoch {
+    // Every accessor reads directly from the stored snapshot.
+
+    fn self_index(&self) -> LeafIndex {
+        self.self_index
+    }
+
+    fn group_context(&self) -> &GroupContext {
+        &self.context
+    }
+
+    fn epoch_secrets(&self) -> &EpochSecrets {
+        &self.secrets
+    }
+
+    fn epoch_secrets_mut(&mut self) -> &mut EpochSecrets {
+        &mut self.secrets
+    }
+}
+
+/// Secrets a member keeps for one epoch.
+///
+/// Which fields exist depends on the `psk`, `secret_tree_access` and
+/// `private_message` features, so the derived encoding differs between
+/// feature sets.
+/// NOTE(review): `Debug` is derived, so formatting this struct formats the
+/// secrets through their own `Debug` impls — confirm none expose raw key bytes.
+#[derive(Debug, Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct EpochSecrets {
+ #[cfg(feature = "psk")]
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub(crate) resumption_secret: PreSharedKey,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub(crate) sender_data_secret: SenderDataSecret,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ pub(crate) secret_tree: SecretTree<NodeIndex>,
+}
+
+/// Sender data secret of an epoch.
+///
+/// The bytes are held in `Zeroizing` so they are wiped when the value is
+/// dropped, and are encoded through `mls_rs_codec::byte_vec`.
+#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct SenderDataSecret(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ Zeroizing<Vec<u8>>,
+);
+
+impl Debug for SenderDataSecret {
+    // NOTE(review): this hands the secret bytes to `pretty_bytes`; confirm its
+    // output is acceptable wherever `Debug` output may be logged.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let pretty = mls_rs_core::debug::pretty_bytes(&self.0).named("SenderDataSecret");
+        Debug::fmt(&pretty, f)
+    }
+}
+
+impl AsRef<[u8]> for SenderDataSecret {
+    /// Borrows the secret as a byte slice.
+    fn as_ref(&self) -> &[u8] {
+        self.0.as_slice()
+    }
+}
+
+impl Deref for SenderDataSecret {
+    type Target = Vec<u8>;
+
+    /// Borrows the secret bytes held inside the zeroizing wrapper.
+    fn deref(&self) -> &Vec<u8> {
+        &*self.0
+    }
+}
+
+impl From<Vec<u8>> for SenderDataSecret {
+    /// Takes ownership of `bytes` and wraps them so they are zeroized on drop.
+    fn from(bytes: Vec<u8>) -> Self {
+        let wrapped = Zeroizing::new(bytes);
+        SenderDataSecret(wrapped)
+    }
+}
+
+impl From<Zeroizing<Vec<u8>>> for SenderDataSecret {
+    /// Adopts an already-zeroizing buffer as-is.
+    fn from(secret: Zeroizing<Vec<u8>>) -> Self {
+        SenderDataSecret(secret)
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use mls_rs_core::crypto::CipherSuiteProvider;
+
+    use super::*;
+    use crate::cipher_suite::CipherSuite;
+    use crate::crypto::test_utils::test_cipher_suite_provider;
+
+    #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+    use crate::group::secret_tree::test_utils::get_test_tree;
+
+    #[cfg(feature = "prior_epoch")]
+    use crate::group::test_utils::get_test_group_context_with_id;
+
+    use crate::group::test_utils::random_bytes;
+
+    /// Builds epoch secrets filled with random bytes whose length is the
+    /// cipher suite's KDF extract size, plus a test secret tree when that
+    /// field is compiled in.
+    pub(crate) fn get_test_epoch_secrets(cipher_suite: CipherSuite) -> EpochSecrets {
+        let secret_len = test_cipher_suite_provider(cipher_suite).kdf_extract_size();
+
+        #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+        let secret_tree = get_test_tree(random_bytes(secret_len), 2);
+
+        EpochSecrets {
+            #[cfg(feature = "psk")]
+            resumption_secret: random_bytes(secret_len).into(),
+            sender_data_secret: random_bytes(secret_len).into(),
+            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+            secret_tree,
+        }
+    }
+
+    /// Builds a `PriorEpoch` owned by leaf 0, with random secrets, no signature
+    /// public keys, and the context from `get_test_group_context_with_id`
+    /// (presumably `id` becomes the epoch number — confirm in that helper).
+    #[cfg(feature = "prior_epoch")]
+    pub(crate) fn get_test_epoch_with_id(
+        group_id: Vec<u8>,
+        cipher_suite: CipherSuite,
+        id: u64,
+    ) -> PriorEpoch {
+        let context = get_test_group_context_with_id(group_id, id, cipher_suite);
+        let secrets = get_test_epoch_secrets(cipher_suite);
+
+        PriorEpoch {
+            context,
+            self_index: LeafIndex(0),
+            secrets,
+            signature_public_keys: Vec::new(),
+        }
+    }
+}
diff --git a/src/group/exported_tree.rs b/src/group/exported_tree.rs
new file mode 100644
index 0000000..acf507f
--- /dev/null
+++ b/src/group/exported_tree.rs
@@ -0,0 +1,51 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::{borrow::Cow, vec::Vec};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+use crate::{client::MlsError, tree_kem::node::NodeVec};
+
+/// Exportable form of a ratchet tree's node list.
+///
+/// The nodes are held in a `Cow`, so the tree can either borrow an existing
+/// `NodeVec` (`new_borrowed`) or own one (`new`, `into_owned`).
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Debug, MlsSize, MlsEncode, MlsDecode, PartialEq, Clone)]
+pub struct ExportedTree<'a>(pub(crate) Cow<'a, NodeVec>);
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl<'a> ExportedTree<'a> {
+    /// Wraps an owned node list.
+    pub(crate) fn new(node_data: NodeVec) -> Self {
+        let nodes = Cow::Owned(node_data);
+        Self(nodes)
+    }
+
+    /// Wraps a borrowed node list without cloning it.
+    pub(crate) fn new_borrowed(node_data: &'a NodeVec) -> Self {
+        let nodes = Cow::Borrowed(node_data);
+        Self(nodes)
+    }
+
+    /// Serializes the tree with the MLS codec.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if encoding fails.
+    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
+        let encoded = self.mls_encode_to_vec()?;
+        Ok(encoded)
+    }
+
+    /// Number of bytes that [`Self::to_bytes`] produces.
+    pub fn byte_size(&self) -> usize {
+        self.mls_encoded_len()
+    }
+
+    /// Converts into a tree that owns its nodes, cloning them if borrowed.
+    pub fn into_owned(self) -> ExportedTree<'static> {
+        let nodes: NodeVec = self.0.into_owned();
+        ExportedTree(Cow::Owned(nodes))
+    }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl ExportedTree<'static> {
+    /// Decodes an owned tree from bytes such as those produced by
+    /// [`ExportedTree::to_bytes`].
+    ///
+    /// NOTE(review): bytes left over after the encoded tree are not rejected.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if `bytes` does not start with a valid encoding.
+    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
+        let mut reader = bytes;
+        let tree = Self::mls_decode(&mut reader)?;
+        Ok(tree)
+    }
+}
+
+impl From<ExportedTree<'_>> for NodeVec {
+    /// Extracts the node list, cloning it only if the tree was borrowed.
+    fn from(value: ExportedTree) -> Self {
+        let ExportedTree(nodes) = value;
+        nodes.into_owned()
+    }
+}
diff --git a/src/group/external_commit.rs b/src/group/external_commit.rs
new file mode 100644
index 0000000..34b1042
--- /dev/null
+++ b/src/group/external_commit.rs
@@ -0,0 +1,266 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs_core::{crypto::SignatureSecretKey, identity::SigningIdentity};
+
+use crate::{
+ client_config::ClientConfig,
+ group::{
+ cipher_suite_provider,
+ epoch::SenderDataSecret,
+ key_schedule::{InitSecret, KeySchedule},
+ proposal::{ExternalInit, Proposal, RemoveProposal},
+ EpochSecrets, ExternalPubExt, LeafIndex, LeafNode, MlsError, TreeKemPrivate,
+ },
+ Group, MlsMessage,
+};
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+use crate::group::secret_tree::SecretTree;
+
+#[cfg(feature = "custom_proposal")]
+use crate::group::{
+ framing::MlsMessagePayload,
+ message_processor::{EventOrContent, MessageProcessor},
+ message_signature::AuthenticatedContent,
+ message_verifier::verify_plaintext_authentication,
+ CustomProposal,
+};
+
+use alloc::vec;
+use alloc::vec::Vec;
+
+#[cfg(feature = "psk")]
+use mls_rs_core::psk::{ExternalPskId, PreSharedKey};
+
+#[cfg(feature = "psk")]
+use crate::group::{
+ PreSharedKeyProposal, {JustPreSharedKeyID, PreSharedKeyID},
+};
+
+use super::{validate_group_info_joiner, ExportedTree};
+
+/// A builder that aids with the construction of an external commit.
+///
+/// Created from the joiner's signature key, signing identity and client
+/// config; the `with_*` methods add optional inputs before `build` is called.
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))]
+pub struct ExternalCommitBuilder<C: ClientConfig> {
+ // Passed to `LeafNode::generate` and handed to the joined group in `build`.
+ signer: SignatureSecretKey,
+ // Identity placed in the joiner's new leaf node.
+ signing_identity: SigningIdentity,
+ config: C,
+ // Ratchet tree to use when the GroupInfo does not carry one (`with_tree_data`).
+ tree_data: Option<ExportedTree<'static>>,
+ // Leaf index of an old copy of this client to remove (`with_removal`).
+ to_remove: Option<u32>,
+ #[cfg(feature = "psk")]
+ external_psks: Vec<ExternalPskId>,
+ // Plaintext authenticated data attached to the commit message.
+ authenticated_data: Vec<u8>,
+ #[cfg(feature = "custom_proposal")]
+ custom_proposals: Vec<Proposal>,
+ // Messages from group members holding custom proposals; checked in `build`.
+ #[cfg(feature = "custom_proposal")]
+ received_custom_proposals: Vec<MlsMessage>,
+}
+
+impl<C: ClientConfig> ExternalCommitBuilder<C> {
+ /// Creates a builder with no tree data, removal, PSKs, custom proposals or
+ /// authenticated data set.
+ pub(crate) fn new(
+ signer: SignatureSecretKey,
+ signing_identity: SigningIdentity,
+ config: C,
+ ) -> Self {
+ Self {
+ tree_data: None,
+ to_remove: None,
+ authenticated_data: Vec::new(),
+ signer,
+ signing_identity,
+ config,
+ #[cfg(feature = "psk")]
+ external_psks: Vec::new(),
+ #[cfg(feature = "custom_proposal")]
+ custom_proposals: Vec::new(),
+ #[cfg(feature = "custom_proposal")]
+ received_custom_proposals: Vec::new(),
+ }
+ }
+
+ #[must_use]
+ /// Use external tree data if the GroupInfo message does not contain a
+ /// [`RatchetTreeExt`](crate::extension::built_in::RatchetTreeExt).
+ pub fn with_tree_data(self, tree_data: ExportedTree<'static>) -> Self {
+ Self {
+ tree_data: Some(tree_data),
+ ..self
+ }
+ }
+
+ #[must_use]
+ /// Propose the removal of an old version of the client as part of the external commit.
+ /// Only one such proposal is allowed; calling this again replaces the earlier index.
+ pub fn with_removal(self, to_remove: u32) -> Self {
+ Self {
+ to_remove: Some(to_remove),
+ ..self
+ }
+ }
+
+ #[must_use]
+ /// Add plaintext authenticated data to the resulting commit message.
+ /// Calling this again replaces the earlier data.
+ pub fn with_authenticated_data(self, data: Vec<u8>) -> Self {
+ Self {
+ authenticated_data: data,
+ ..self
+ }
+ }
+
+ #[cfg(feature = "psk")]
+ #[must_use]
+ /// Add an external PSK to the group as part of the external commit.
+ pub fn with_external_psk(mut self, psk: ExternalPskId) -> Self {
+ self.external_psks.push(psk);
+ self
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[must_use]
+ /// Insert a [`CustomProposal`] into the current commit that is being built.
+ pub fn with_custom_proposal(mut self, proposal: CustomProposal) -> Self {
+ self.custom_proposals.push(Proposal::Custom(proposal));
+ self
+ }
+
+ #[cfg(all(feature = "custom_proposal", feature = "by_ref_proposal"))]
+ #[must_use]
+ /// Insert a [`CustomProposal`] received from a current group member into the current
+ /// commit that is being built.
+ ///
+ /// # Warning
+ ///
+ /// The authenticity of the proposal is NOT fully verified. It is only verified the
+ /// same way as by [`ExternalGroup`](`crate::external_client::ExternalGroup`).
+ /// The proposal MUST be a public (plaintext) message, else the [`Self::build`]
+ /// function will fail with `MlsError::UnexpectedMessageType`.
+ pub fn with_received_custom_proposal(mut self, proposal: MlsMessage) -> Self {
+ self.received_custom_proposals.push(proposal);
+ self
+ }
+
+ /// Build the external commit using a GroupInfo message provided by an existing group member.
+ ///
+ /// Returns the joined group, with the commit already applied, together with
+ /// the commit message to send to the existing members.
+ ///
+ /// # Errors
+ ///
+ /// - `MlsError::UnsupportedProtocolVersion` if the message's version is not supported.
+ /// - `MlsError::UnexpectedMessageType` if `group_info` is not a GroupInfo, or a
+ /// received custom proposal is not a public message.
+ /// - `MlsError::MissingExternalPubExtension` if the GroupInfo lacks an external public key.
+ /// - Any error from validation, key generation, joining or committing.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn build(self, group_info: MlsMessage) -> Result<(Group<C>, MlsMessage), MlsError> {
+ let protocol_version = group_info.version;
+
+ if !self.config.version_supported(protocol_version) {
+ return Err(MlsError::UnsupportedProtocolVersion(protocol_version));
+ }
+
+ let group_info = group_info
+ .into_group_info()
+ .ok_or(MlsError::UnexpectedMessageType)?;
+
+ let cipher_suite = cipher_suite_provider(
+ self.config.crypto_provider(),
+ group_info.group_context.cipher_suite,
+ )?;
+
+ // An external commit needs the group's external public key, which is
+ // published as a GroupInfo extension.
+ let external_pub_ext = group_info
+ .extensions
+ .get_as::<ExternalPubExt>()?
+ .ok_or(MlsError::MissingExternalPubExtension)?;
+
+ // Validate the GroupInfo and obtain the public tree; `tree_data` is
+ // presumably used only when the GroupInfo carries no ratchet tree
+ // (per `with_tree_data`) — confirm in `validate_group_info_joiner`.
+ let public_tree = validate_group_info_joiner(
+ protocol_version,
+ &group_info,
+ self.tree_data,
+ &self.config.identity_provider(),
+ &cipher_suite,
+ )
+ .await?;
+
+ // The joiner's new leaf node; it is passed to `commit_internal` below.
+ let (leaf_node, _) = LeafNode::generate(
+ &cipher_suite,
+ self.config.leaf_properties(),
+ self.signing_identity,
+ &self.signer,
+ self.config.lifetime(),
+ )
+ .await?;
+
+ // Derive the init secret against the external public key; `kem_output`
+ // is carried to the other members in the ExternalInit proposal.
+ let (init_secret, kem_output) =
+ InitSecret::encode_for_external(&cipher_suite, &external_pub_ext.external_pub).await?;
+
+ // Empty placeholder secrets for the joined state.
+ // NOTE(review): presumably superseded when the commit below is applied — confirm.
+ let epoch_secrets = EpochSecrets {
+ #[cfg(feature = "psk")]
+ resumption_secret: PreSharedKey::new(vec![]),
+ sender_data_secret: SenderDataSecret::from(vec![]),
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree: SecretTree::empty(),
+ };
+
+ let (mut group, _) = Group::join_with(
+ self.config,
+ group_info,
+ public_tree,
+ KeySchedule::new(init_secret),
+ epoch_secrets,
+ TreeKemPrivate::new_for_external(),
+ None,
+ self.signer,
+ )
+ .await?;
+
+ // Wrap each requested external PSK id in a `PreSharedKeyID`; the cipher
+ // suite provider is passed along (presumably to generate the PSK nonce).
+ #[cfg(feature = "psk")]
+ let psk_ids = self
+ .external_psks
+ .into_iter()
+ .map(|psk_id| PreSharedKeyID::new(JustPreSharedKeyID::External(psk_id), &cipher_suite))
+ .collect::<Result<Vec<_>, MlsError>>()?;
+
+ // ExternalInit is always the first proposal in the commit.
+ let mut proposals = vec![Proposal::ExternalInit(ExternalInit { kem_output })];
+
+ #[cfg(feature = "psk")]
+ proposals.extend(
+ psk_ids
+ .into_iter()
+ .map(|psk| Proposal::Psk(PreSharedKeyProposal { psk })),
+ );
+
+ #[cfg(feature = "custom_proposal")]
+ {
+ let mut custom_proposals = self.custom_proposals;
+ proposals.append(&mut custom_proposals);
+ }
+
+ // Received custom proposals must be public messages. Each one is
+ // authenticated against the joined group's state (with `None` for the two
+ // optional arguments) and then fed to the group's message processing so it
+ // can be referenced by the commit.
+ #[cfg(all(feature = "custom_proposal", feature = "by_ref_proposal"))]
+ for message in self.received_custom_proposals {
+ let MlsMessagePayload::Plain(plaintext) = message.payload else {
+ return Err(MlsError::UnexpectedMessageType);
+ };
+
+ let auth_content = AuthenticatedContent::from(plaintext.clone());
+
+ verify_plaintext_authentication(&cipher_suite, plaintext, None, None, &group.state)
+ .await?;
+
+ group
+ .process_event_or_content(EventOrContent::Content(auth_content), true, None)
+ .await?;
+ }
+
+ // Optionally remove an old copy of this client (see `with_removal`).
+ if let Some(r) = self.to_remove {
+ proposals.push(Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(r),
+ }));
+ }
+
+ let commit_output = group
+ .commit_internal(
+ proposals,
+ Some(&leaf_node),
+ self.authenticated_data,
+ Default::default(),
+ None,
+ None,
+ )
+ .await?;
+
+ // Apply the commit locally before handing the group back to the caller.
+ group.apply_pending_commit().await?;
+
+ Ok((group, commit_output.commit_message))
+ }
+}
diff --git a/src/group/framing.rs b/src/group/framing.rs
new file mode 100644
index 0000000..8663b96
--- /dev/null
+++ b/src/group/framing.rs
@@ -0,0 +1,741 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use core::ops::Deref;
+
+use crate::{client::MlsError, tree_kem::node::LeafIndex, KeyPackage, KeyPackageRef};
+
+use super::{Commit, FramedContentAuthData, GroupInfo, MembershipTag, Welcome};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::{group::Proposal, mls_rules::ProposalRef};
+
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{
+ crypto::{CipherSuite, CipherSuiteProvider},
+ protocol_version::ProtocolVersion,
+};
+use zeroize::ZeroizeOnDrop;
+
+#[cfg(feature = "private_message")]
+use alloc::boxed::Box;
+
+#[cfg(feature = "custom_proposal")]
+use crate::group::proposal::{CustomProposal, ProposalOrRef};
+
+/// Kind of content carried by a framed MLS message.
+///
+/// `#[repr(u8)]` gives each variant an explicit discriminant, which
+/// presumably serves as its encoded value. The `Application` and `Proposal`
+/// variants only exist with the `private_message` and `by_ref_proposal`
+/// features respectively.
+#[derive(Copy, Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[repr(u8)]
+pub enum ContentType {
+ #[cfg(feature = "private_message")]
+ Application = 1u8,
+ #[cfg(feature = "by_ref_proposal")]
+ Proposal = 2u8,
+ Commit = 3u8,
+}
+
+impl From<&Content> for ContentType {
+    /// Maps each content variant to the content type with the same name.
+    fn from(content: &Content) -> Self {
+        match *content {
+            #[cfg(feature = "private_message")]
+            Content::Application(_) => Self::Application,
+            #[cfg(feature = "by_ref_proposal")]
+            Content::Proposal(_) => Self::Proposal,
+            Content::Commit(_) => Self::Commit,
+        }
+    }
+}
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+#[non_exhaustive]
+/// Description of a [`MlsMessage`] sender.
+///
+/// Each variant has an explicit `u8` discriminant (`#[repr(u8)]`).
+pub enum Sender {
+ /// Current group member index.
+ Member(u32) = 1u8,
+ /// An external entity sending a proposal, identified by an index
+ /// in the current
+ /// [`ExternalSendersExt`](crate::extension::ExternalSendersExt) stored in
+ /// group context extensions.
+ #[cfg(feature = "by_ref_proposal")]
+ External(u32) = 2u8,
+ /// A new member proposing their own addition to the group.
+ #[cfg(feature = "by_ref_proposal")]
+ NewMemberProposal = 3u8,
+ /// A member sending an external commit.
+ NewMemberCommit = 4u8,
+}
+
+impl From<LeafIndex> for Sender {
+    /// A leaf index always designates a current group member.
+    fn from(leaf_index: LeafIndex) -> Self {
+        let index: u32 = *leaf_index;
+        Self::Member(index)
+    }
+}
+
+impl From<u32> for Sender {
+    /// Interprets a raw `u32` as the leaf index of a current group member.
+    fn from(index: u32) -> Self {
+        Self::Member(index)
+    }
+}
+
+/// Application payload bytes carried by an application message.
+///
+/// `ZeroizeOnDrop` is derived, so the buffer is zeroized when the value is
+/// dropped; the bytes are encoded through `mls_rs_codec::byte_vec`.
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, ZeroizeOnDrop)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct ApplicationData(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ Vec<u8>,
+);
+
+impl Debug for ApplicationData {
+    /// Formats the payload through `mls_rs_core::debug::pretty_bytes`,
+    /// labelled `ApplicationData`.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let pretty = mls_rs_core::debug::pretty_bytes(&self.0).named("ApplicationData");
+        Debug::fmt(&pretty, f)
+    }
+}
+
+impl From<Vec<u8>> for ApplicationData {
+    /// Wraps raw application bytes without copying them.
+    fn from(bytes: Vec<u8>) -> Self {
+        ApplicationData(bytes)
+    }
+}
+
+impl Deref for ApplicationData {
+    type Target = [u8];
+
+    /// Borrows the application bytes as a slice.
+    fn deref(&self) -> &[u8] {
+        self.0.as_slice()
+    }
+}
+
+impl ApplicationData {
+    /// Borrows the underlying message content as raw bytes.
+    pub fn as_bytes(&self) -> &[u8] {
+        self.0.as_slice()
+    }
+}
+
+/// Body of a framed message.
+///
+/// The explicit `u8` discriminants mirror those of [`ContentType`]; proposal
+/// and commit bodies are boxed.
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+pub(crate) enum Content {
+ #[cfg(feature = "private_message")]
+ Application(ApplicationData) = 1u8,
+ #[cfg(feature = "by_ref_proposal")]
+ Proposal(alloc::boxed::Box<Proposal>) = 2u8,
+ Commit(alloc::boxed::Box<Commit>) = 3u8,
+}
+
+impl Content {
+    /// The [`ContentType`] tag that corresponds to this content.
+    pub fn content_type(&self) -> ContentType {
+        ContentType::from(self)
+    }
+}
+
+/// Framed content sent without encryption, with its authentication data and
+/// optional membership tag.
+///
+/// The codec traits are implemented by hand below: on decode, a membership tag
+/// is read only when the sender is `Sender::Member`.
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub(crate) struct PublicMessage {
+ pub content: FramedContent,
+ pub auth: FramedContentAuthData,
+ pub membership_tag: Option<MembershipTag>,
+}
+
+impl MlsSize for PublicMessage {
+    /// Content length plus auth length plus the membership tag length, if any.
+    fn mls_encoded_len(&self) -> usize {
+        // An absent membership tag contributes nothing to the encoding.
+        let tag_len = match &self.membership_tag {
+            Some(tag) => tag.mls_encoded_len(),
+            None => 0,
+        };
+
+        self.content.mls_encoded_len() + self.auth.mls_encoded_len() + tag_len
+    }
+}
+
+impl MlsEncode for PublicMessage {
+    /// Writes the content, then the auth data, then the membership tag if present.
+    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+        self.content.mls_encode(writer)?;
+        self.auth.mls_encode(writer)?;
+
+        if let Some(tag) = &self.membership_tag {
+            tag.mls_encode(writer)?;
+        }
+
+        Ok(())
+    }
+}
+
+impl MlsDecode for PublicMessage {
+    /// Reads the content and its auth data; a membership tag follows only when
+    /// the sender is a current member.
+    fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
+        let content = FramedContent::mls_decode(reader)?;
+        let auth = FramedContentAuthData::mls_decode(reader, content.content_type())?;
+
+        let membership_tag = if matches!(content.sender, Sender::Member(_)) {
+            Some(MembershipTag::mls_decode(reader)?)
+        } else {
+            None
+        };
+
+        Ok(Self {
+            content,
+            auth,
+            membership_tag,
+        })
+    }
+}
+
+/// Decrypted body of a `PrivateMessage`: the content plus its authentication
+/// data.
+///
+/// The content is encoded without a content-type tag (see the hand-written
+/// codec impls and `PrivateMessageContent::mls_decode`).
+#[cfg(feature = "private_message")]
+#[derive(Clone, Debug, PartialEq)]
+pub(crate) struct PrivateMessageContent {
+ pub content: Content,
+ pub auth: FramedContentAuthData,
+}
+
+#[cfg(feature = "private_message")]
+impl MlsSize for PrivateMessageContent {
+    /// Length of the content body (no content-type tag) plus the auth data.
+    fn mls_encoded_len(&self) -> usize {
+        let auth_len = self.auth.mls_encoded_len();
+
+        let body_len = match &self.content {
+            Content::Application(data) => data.mls_encoded_len(),
+            #[cfg(feature = "by_ref_proposal")]
+            Content::Proposal(proposal) => proposal.mls_encoded_len(),
+            Content::Commit(commit) => commit.mls_encoded_len(),
+        };
+
+        body_len + auth_len
+    }
+}
+
+#[cfg(feature = "private_message")]
+impl MlsEncode for PrivateMessageContent {
+    /// Writes the content body without its content-type tag, then the auth data.
+    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+        let body_written = match &self.content {
+            Content::Application(data) => data.mls_encode(writer),
+            #[cfg(feature = "by_ref_proposal")]
+            Content::Proposal(proposal) => proposal.mls_encode(writer),
+            Content::Commit(commit) => commit.mls_encode(writer),
+        };
+        body_written?;
+
+        self.auth.mls_encode(writer)
+    }
+}
+
+#[cfg(feature = "private_message")]
+impl PrivateMessageContent {
+    /// Decodes the plaintext of a `PrivateMessage` whose content type is
+    /// already known from the outer message.
+    ///
+    /// The content body is read without a type tag (the type comes from
+    /// `content_type`), followed by the authentication data. Any bytes left in
+    /// `reader` after that are padding and must all be zero.
+    ///
+    /// # Errors
+    ///
+    /// Returns a codec error if the content or auth data fail to decode, or
+    /// `mls_rs_codec::Error::Custom(5)` if a padding byte is non-zero.
+    pub(crate) fn mls_decode(
+        reader: &mut &[u8],
+        content_type: ContentType,
+    ) -> Result<Self, mls_rs_codec::Error> {
+        let content = match content_type {
+            ContentType::Application => Content::Application(ApplicationData::mls_decode(reader)?),
+            #[cfg(feature = "by_ref_proposal")]
+            ContentType::Proposal => Content::Proposal(Box::new(Proposal::mls_decode(reader)?)),
+            ContentType::Commit => Content::Commit(Box::new(Commit::mls_decode(reader)?)),
+        };
+
+        let auth = FramedContentAuthData::mls_decode(reader, content.content_type())?;
+
+        // Everything after the auth data is padding and must be all zero bytes.
+        // Custom error code 5 is this crate's "non-zero padding bytes discovered".
+        if reader.iter().any(|&byte| byte != 0) {
+            return Err(mls_rs_codec::Error::Custom(5));
+        }
+
+        Ok(Self { content, auth })
+    }
+}
+
+/// Header fields of a `PrivateMessage` (group id, epoch, content type and
+/// authenticated data), built with `From<&PrivateMessage>`.
+///
+/// As the name suggests, this is presumably the additional authenticated data
+/// for the message's AEAD encryption — confirm at the encryption call site.
+#[cfg(feature = "private_message")]
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+pub struct PrivateContentAAD {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub group_id: Vec<u8>,
+ pub epoch: u64,
+ pub content_type: ContentType,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub authenticated_data: Vec<u8>,
+}
+
+#[cfg(feature = "private_message")]
+impl Debug for PrivateContentAAD {
+    /// Field-by-field debug output using the `mls_rs_core::debug` pretty printers.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        use mls_rs_core::debug::{pretty_bytes, pretty_group_id};
+
+        let group_id = pretty_group_id(&self.group_id);
+        let authenticated_data = pretty_bytes(&self.authenticated_data);
+
+        let mut out = f.debug_struct("PrivateContentAAD");
+        out.field("group_id", &group_id);
+        out.field("epoch", &self.epoch);
+        out.field("content_type", &self.content_type);
+        out.field("authenticated_data", &authenticated_data);
+        out.finish()
+    }
+}
+
+/// Encrypted MLS message: header fields sent in the clear, plus the encrypted
+/// sender data and the ciphertext of the content.
+///
+/// Encoding is derived, so the field order is significant.
+#[cfg(feature = "private_message")]
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub struct PrivateMessage {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub group_id: Vec<u8>,
+ pub epoch: u64,
+ pub content_type: ContentType,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub authenticated_data: Vec<u8>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub encrypted_sender_data: Vec<u8>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub ciphertext: Vec<u8>,
+}
+
+#[cfg(feature = "private_message")]
+impl Debug for PrivateMessage {
+    /// Field-by-field debug output; byte fields go through the
+    /// `mls_rs_core::debug` pretty printers.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        use mls_rs_core::debug::{pretty_bytes, pretty_group_id};
+
+        let group_id = pretty_group_id(&self.group_id);
+        let authenticated_data = pretty_bytes(&self.authenticated_data);
+        let encrypted_sender_data = pretty_bytes(&self.encrypted_sender_data);
+        let ciphertext = pretty_bytes(&self.ciphertext);
+
+        let mut out = f.debug_struct("PrivateMessage");
+        out.field("group_id", &group_id);
+        out.field("epoch", &self.epoch);
+        out.field("content_type", &self.content_type);
+        out.field("authenticated_data", &authenticated_data);
+        out.field("encrypted_sender_data", &encrypted_sender_data);
+        out.field("ciphertext", &ciphertext);
+        out.finish()
+    }
+}
+
+#[cfg(feature = "private_message")]
+impl From<&PrivateMessage> for PrivateContentAAD {
+    /// Copies the header fields of `ciphertext`; the encrypted sender data and
+    /// the ciphertext itself are not included.
+    fn from(ciphertext: &PrivateMessage) -> Self {
+        let PrivateMessage {
+            group_id,
+            epoch,
+            content_type,
+            authenticated_data,
+            ..
+        } = ciphertext;
+
+        Self {
+            group_id: group_id.clone(),
+            epoch: *epoch,
+            content_type: *content_type,
+            authenticated_data: authenticated_data.clone(),
+        }
+    }
+}
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ ::safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+/// An MLS protocol message for sending data over the wire.
+///
+/// Pairs a protocol version with a payload: a public or private message, a
+/// welcome, a group info or a key package (see `MlsMessage::wire_format`).
+pub struct MlsMessage {
+ pub(crate) version: ProtocolVersion,
+ pub(crate) payload: MlsMessagePayload,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+#[allow(dead_code)]
+impl MlsMessage {
+    pub(crate) fn new(version: ProtocolVersion, payload: MlsMessagePayload) -> MlsMessage {
+        MlsMessage { version, payload }
+    }
+
+    #[inline(always)]
+    pub(crate) fn into_plaintext(self) -> Option<PublicMessage> {
+        let MlsMessagePayload::Plain(plaintext) = self.payload else {
+            return None;
+        };
+        Some(plaintext)
+    }
+
+    #[cfg(feature = "private_message")]
+    #[inline(always)]
+    pub(crate) fn into_ciphertext(self) -> Option<PrivateMessage> {
+        let MlsMessagePayload::Cipher(ciphertext) = self.payload else {
+            return None;
+        };
+        Some(ciphertext)
+    }
+
+    #[inline(always)]
+    pub(crate) fn into_welcome(self) -> Option<Welcome> {
+        let MlsMessagePayload::Welcome(welcome) = self.payload else {
+            return None;
+        };
+        Some(welcome)
+    }
+
+    /// Consumes the message and returns its [`GroupInfo`] payload, or `None` if the
+    /// message has any other wire format.
+    #[inline(always)]
+    pub fn into_group_info(self) -> Option<GroupInfo> {
+        let MlsMessagePayload::GroupInfo(info) = self.payload else { return None };
+        Some(info)
+    }
+
+    /// Borrows the [`GroupInfo`] payload, or returns `None` if the message has any
+    /// other wire format.
+    #[inline(always)]
+    pub fn as_group_info(&self) -> Option<&GroupInfo> {
+        let MlsMessagePayload::GroupInfo(info) = &self.payload else { return None };
+        Some(info)
+    }
+
+    /// Consumes the message and returns its [`KeyPackage`] payload, or `None` if the
+    /// message has any other wire format.
+    #[inline(always)]
+    pub fn into_key_package(self) -> Option<KeyPackage> {
+        let MlsMessagePayload::KeyPackage(key_package) = self.payload else { return None };
+        Some(key_package)
+    }
+
+    /// The wire format value describing the contents of this message.
+    pub fn wire_format(&self) -> WireFormat {
+        match &self.payload {
+            MlsMessagePayload::KeyPackage(_) => WireFormat::KeyPackage,
+            MlsMessagePayload::GroupInfo(_) => WireFormat::GroupInfo,
+            MlsMessagePayload::Welcome(_) => WireFormat::Welcome,
+            #[cfg(feature = "private_message")]
+            MlsMessagePayload::Cipher(_) => WireFormat::PrivateMessage,
+            MlsMessagePayload::Plain(_) => WireFormat::PublicMessage,
+        }
+    }
+
+    /// The epoch that this message belongs to.
+    ///
+    /// Returns `None` if the message is a [`WireFormat::KeyPackage`] or [`WireFormat::Welcome`].
+    #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn epoch(&self) -> Option<u64> {
+        let epoch = match &self.payload {
+            MlsMessagePayload::Plain(plaintext) => plaintext.content.epoch,
+            #[cfg(feature = "private_message")]
+            MlsMessagePayload::Cipher(ciphertext) => ciphertext.epoch,
+            MlsMessagePayload::GroupInfo(info) => info.group_context.epoch,
+            MlsMessagePayload::KeyPackage(_) | MlsMessagePayload::Welcome(_) => return None,
+        };
+        Some(epoch)
+    }
+
+    /// The cipher suite of a group info, welcome or key package message, otherwise `None`.
+    #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn cipher_suite(&self) -> Option<CipherSuite> {
+        Some(match &self.payload {
+            MlsMessagePayload::GroupInfo(info) => info.group_context.cipher_suite,
+            MlsMessagePayload::Welcome(welcome) => welcome.cipher_suite,
+            MlsMessagePayload::KeyPackage(key_package) => key_package.cipher_suite,
+            _ => return None,
+        })
+    }
+
+    /// The group this message belongs to, or `None` for key packages and welcome messages.
+    pub fn group_id(&self) -> Option<&[u8]> {
+        let group_id: &[u8] = match &self.payload {
+            MlsMessagePayload::Plain(plaintext) => &plaintext.content.group_id,
+            #[cfg(feature = "private_message")]
+            MlsMessagePayload::Cipher(ciphertext) => &ciphertext.group_id,
+            MlsMessagePayload::GroupInfo(info) => &info.group_context.group_id,
+            MlsMessagePayload::KeyPackage(_) | MlsMessagePayload::Welcome(_) => return None,
+        };
+        Some(group_id)
+    }
+
+    /// Deserialize a message from transport.
+    #[inline(never)]
+    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
+        let mut reader = bytes;
+        Self::mls_decode(&mut reader).map_err(Into::into)
+    }
+
+    /// Serialize a message for transport.
+    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
+        self.mls_encode_to_vec().map_err(Into::into)
+    }
+
+    /// Custom proposals committed by value in a plaintext commit message. Any other message,
+    /// including a plaintext message that is not a commit, yields an empty list.
+    #[cfg(feature = "custom_proposal")]
+    pub fn custom_proposals_by_value(&self) -> Vec<&CustomProposal> {
+        match &self.payload {
+            MlsMessagePayload::Plain(PublicMessage {
+                content: FramedContent { content: Content::Commit(commit), .. },
+                ..
+            }) => Self::find_custom_proposals(commit),
+            _ => Vec::new(),
+        }
+    }
+
+    /// For a welcome message, the key package references of every member that can join
+    /// with it; any other message yields an empty list.
+    pub fn welcome_key_package_references(&self) -> Vec<&KeyPackageRef> {
+        match &self.payload {
+            MlsMessagePayload::Welcome(w) => w.secrets.iter().map(|s| &s.new_member).collect(),
+            _ => Vec::new(),
+        }
+    }
+
+    /// If this is a key package, return its key package reference.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn key_package_reference<C: CipherSuiteProvider>(
+        &self,
+        cipher_suite: &C,
+    ) -> Result<Option<KeyPackageRef>, MlsError> {
+        match &self.payload {
+            MlsMessagePayload::KeyPackage(kp) => kp.to_reference(cipher_suite).await.map(Some),
+            _ => Ok(None),
+        }
+    }
+
+    /// If this is a plaintext message, return the proposal reference of its content, which
+    /// can be matched e.g. with
+    /// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals).
+    /// The content type is not checked: any plaintext message yields a reference.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn into_proposal_reference<C: CipherSuiteProvider>(
+        self,
+        cipher_suite: &C,
+    ) -> Result<Option<Vec<u8>>, MlsError> {
+        let MlsMessagePayload::Plain(public_message) = self.payload else { return Ok(None) };
+        let reference = ProposalRef::from_content(cipher_suite, &public_message.into()).await?;
+        Ok(Some(reference.to_vec()))
+    }
+}
+
+#[cfg(feature = "custom_proposal")]
+impl MlsMessage {
+    /// Custom proposals that `commit` carries by value; proposal references are skipped.
+    fn find_custom_proposals(commit: &Commit) -> Vec<&CustomProposal> {
+        let mut found = Vec::new();
+        for entry in commit.proposals.iter() {
+            if let ProposalOrRef::Proposal(proposal) = entry {
+                if let crate::group::Proposal::Custom(custom) = proposal.as_ref() {
+                    found.push(custom);
+                }
+            }
+        }
+
+        found
+    }
+}
+
+#[allow(clippy::large_enum_variant)] // payloads are stored inline rather than boxed
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[repr(u16)] // the explicit discriminants below mirror the `WireFormat` values
+pub(crate) enum MlsMessagePayload { // body of an `MlsMessage`, one variant per wire format
+ Plain(PublicMessage) = 1u16, // WireFormat::PublicMessage
+ #[cfg(feature = "private_message")] // only compiled with the private_message feature
+ Cipher(PrivateMessage) = 2u16, // WireFormat::PrivateMessage
+ Welcome(Welcome) = 3u16, // WireFormat::Welcome
+ GroupInfo(GroupInfo) = 4u16, // WireFormat::GroupInfo
+ KeyPackage(KeyPackage) = 5u16, // WireFormat::KeyPackage
+}
+
+impl From<PublicMessage> for MlsMessagePayload {
+    fn from(message: PublicMessage) -> Self {
+        MlsMessagePayload::Plain(message) // public messages travel as the `Plain` variant
+    }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
+#[derive(
+ Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, MlsSize, MlsEncode, MlsDecode,
+)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u16)] // the discriminants are serialized values and must not change
+#[non_exhaustive] // downstream matches need a wildcard arm
+/// Content description of an [`MlsMessage`]; the `u16` values match the `MlsMessagePayload` tags.
+pub enum WireFormat {
+ PublicMessage = 1u16, // MlsMessagePayload::Plain
+ PrivateMessage = 2u16, // MlsMessagePayload::Cipher; declared even without private_message
+ Welcome = 3u16,
+ GroupInfo = 4u16,
+ KeyPackage = 5u16,
+}
+
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct FramedContent { // encoded in field order; `Debug` is implemented manually
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub group_id: Vec<u8>,
+ pub epoch: u64,
+ pub sender: Sender, // e.g. `Sender::Member(leaf_index)`
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub authenticated_data: Vec<u8>, // kept in the clear; part of the AAD for private messages
+ pub content: Content, // e.g. application data, a proposal or a commit
+}
+
+impl Debug for FramedContent {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Byte vectors go through the mls-rs `debug` helpers; other fields use their own Debug.
+        use mls_rs_core::debug::{pretty_bytes, pretty_group_id};
+
+        let group_id = pretty_group_id(&self.group_id);
+        let authenticated_data = pretty_bytes(&self.authenticated_data);
+
+        f.debug_struct("FramedContent")
+            .field("group_id", &group_id)
+            .field("epoch", &self.epoch)
+            .field("sender", &self.sender)
+            .field("authenticated_data", &authenticated_data)
+            .field("content", &self.content)
+            .finish()
+    }
+}
+
+impl FramedContent {
+    pub fn content_type(&self) -> ContentType {
+        Content::content_type(&self.content) // the type is determined by the wrapped content
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    #[cfg(feature = "private_message")]
+    use crate::group::test_utils::random_bytes;
+
+    use crate::group::{AuthenticatedContent, MessageSignature};
+
+    use super::*;
+
+    use alloc::boxed::Box;
+
+    /// Commit fixture from member 1 at epoch 0; not a valid commit, so never validate it.
+    pub(crate) fn get_test_auth_content() -> AuthenticatedContent {
+        let empty_commit = Commit {
+            proposals: Default::default(),
+            path: None,
+        };
+
+        let auth = FramedContentAuthData {
+            signature: MessageSignature::empty(),
+            confirmation_tag: None,
+        };
+
+        let content = FramedContent {
+            group_id: Vec::new(),
+            epoch: 0,
+            sender: Sender::Member(1),
+            authenticated_data: Vec::new(),
+            content: Content::Commit(Box::new(empty_commit)),
+        };
+
+        AuthenticatedContent { wire_format: WireFormat::PublicMessage, content, auth }
+    }
+
+    /// Private content fixture: 1024 random application bytes and a random 128-byte signature.
+    #[cfg(feature = "private_message")]
+    pub(crate) fn get_test_ciphertext_content() -> PrivateMessageContent {
+        let payload = random_bytes(1024);
+        let signature = MessageSignature::from(random_bytes(128));
+        PrivateMessageContent {
+            content: Content::Application(payload.into()),
+            auth: FramedContentAuthData { signature, confirmation_tag: None },
+        }
+    }
+
+    impl AsRef<[u8]> for ApplicationData {
+        fn as_ref(&self) -> &[u8] {
+            &self.0
+        }
+    }
+}
+
+#[cfg(feature = "private_message")]
+#[cfg(test)]
+mod tests {
+    use assert_matches::assert_matches;
+
+    use crate::{
+        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+        crypto::test_utils::test_cipher_suite_provider,
+        group::{
+            framing::test_utils::get_test_ciphertext_content,
+            proposal_ref::test_utils::auth_content_from_proposal, RemoveProposal,
+        },
+    };
+
+    use super::*;
+
+    /// Encodes `original`, appends 128 padding bytes all equal to `fill` and decodes the
+    /// result, using the content type of `original` as the decoding context.
+    fn decode_with_padding(
+        original: &PrivateMessageContent,
+        fill: u8,
+    ) -> Result<PrivateMessageContent, mls_rs_codec::Error> {
+        let mut bytes = original.mls_encode_to_vec().unwrap();
+        bytes.extend_from_slice(&[fill; 128]);
+
+        PrivateMessageContent::mls_decode(&mut &*bytes, (&original.content).into())
+    }
+
+    #[test]
+    fn test_mls_ciphertext_content_mls_encoding() {
+        // Zero padding after the encoded content is accepted and ignored.
+        let ciphertext_content = get_test_ciphertext_content();
+        let decoded = decode_with_padding(&ciphertext_content, 0).unwrap();
+
+        assert_eq!(ciphertext_content, decoded);
+    }
+
+    #[test]
+    fn test_mls_ciphertext_content_non_zero_padding_error() {
+        // Non-zero padding is rejected with a custom codec error.
+        let ciphertext_content = get_test_ciphertext_content();
+        let decoded = decode_with_padding(&ciphertext_content, 1);
+
+        assert_matches!(decoded, Err(mls_rs_codec::Error::Custom(_)));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn proposal_ref() {
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let remove = Proposal::Remove(RemoveProposal {
+            to_remove: LeafIndex(0),
+        });
+        let test_auth = auth_content_from_proposal(remove, Sender::External(0));
+
+        let expected_ref = ProposalRef::from_content(&cs, &test_auth).await.unwrap();
+
+        // The membership tag is arbitrary: it is not part of the referenced content.
+        let membership_tag = cs.mac(&[1, 2, 3], &[1, 2, 3]).await.unwrap().into();
+
+        let test_message = MlsMessage::new(
+            TEST_PROTOCOL_VERSION,
+            MlsMessagePayload::Plain(PublicMessage {
+                content: test_auth.content,
+                auth: test_auth.auth,
+                membership_tag: Some(membership_tag),
+            }),
+        );
+
+        // Wrapping the content in a public message must not change its reference.
+        let computed_ref = test_message.into_proposal_reference(&cs).await.unwrap();
+        assert_eq!(computed_ref, Some(expected_ref.to_vec()));
+    }
+}
diff --git a/src/group/group_info.rs b/src/group/group_info.rs
new file mode 100644
index 0000000..a5e7268
--- /dev/null
+++ b/src/group/group_info.rs
@@ -0,0 +1,95 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::extension::ExtensionList;
+
+use crate::{signer::Signable, tree_kem::node::LeafIndex};
+
+use super::{ConfirmationTag, GroupContext};
+
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+pub struct GroupInfo { // signed over every field but `signature`, encoded as `SignableGroupInfo`
+ pub(crate) group_context: GroupContext,
+ pub(crate) extensions: ExtensionList, // group info extensions, distinct from the context's
+ pub(crate) confirmation_tag: ConfirmationTag,
+ pub(crate) signer: LeafIndex, // leaf index of the member that signed
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub(crate) signature: Vec<u8>, // produced under the GroupInfoTBS signing label
+}
+
+impl Debug for GroupInfo {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Only the raw signature needs the byte pretty-printer; other fields use their own Debug.
+        let signature = mls_rs_core::debug::pretty_bytes(&self.signature);
+
+        let mut out = f.debug_struct("GroupInfo");
+        out.field("group_context", &self.group_context);
+        out.field("extensions", &self.extensions);
+        out.field("confirmation_tag", &self.confirmation_tag);
+        out.field("signer", &self.signer);
+        out.field("signature", &signature);
+        out.finish()
+    }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl GroupInfo {
+    /// The [`GroupContext`] carried by this group info.
+    pub fn group_context(&self) -> &GroupContext {
+        &self.group_context
+    }
+
+    /// Extensions attached to the group info itself, e.g. the ratchet tree. These are
+    /// separate from the extensions stored in the group context.
+    pub fn extensions(&self) -> &ExtensionList {
+        &self.extensions
+    }
+
+    /// Leaf index of the member that generated and signed this group info.
+    pub fn sender(&self) -> u32 {
+        *self.signer
+    }
+}
+
+#[derive(MlsEncode, MlsSize)] // encode-only: the to-be-signed view of a `GroupInfo`
+struct SignableGroupInfo<'a> { // `GroupInfo` fields in the same order, minus `signature`
+ group_context: &'a GroupContext,
+ extensions: &'a ExtensionList,
+ confirmation_tag: &'a ConfirmationTag,
+ signer: LeafIndex,
+}
+
+impl<'a> Signable<'a> for GroupInfo {
+    const SIGN_LABEL: &'static str = "GroupInfoTBS";
+    type SigningContext = ();
+
+    fn signature(&self) -> &[u8] {
+        self.signature.as_slice()
+    }
+
+    fn signable_content(
+        &self,
+        _context: &Self::SigningContext,
+    ) -> Result<Vec<u8>, mls_rs_codec::Error> {
+        let to_be_signed = SignableGroupInfo { // all fields but `signature`, in declaration order
+            signer: self.signer,
+            confirmation_tag: &self.confirmation_tag,
+            extensions: &self.extensions,
+            group_context: &self.group_context,
+        };
+        to_be_signed.mls_encode_to_vec()
+    }
+
+    fn write_signature(&mut self, signature: Vec<u8>) {
+        self.signature = signature;
+    }
+}
diff --git a/src/group/interop_test_vectors.rs b/src/group/interop_test_vectors.rs
new file mode 100644
index 0000000..abd82fc
--- /dev/null
+++ b/src/group/interop_test_vectors.rs
@@ -0,0 +1,9 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+mod framing;
+mod passive_client;
+mod serialization;
+mod tree_kem;
+mod tree_modifications;
diff --git a/src/group/interop_test_vectors/framing.rs b/src/group/interop_test_vectors/framing.rs
new file mode 100644
index 0000000..30e4225
--- /dev/null
+++ b/src/group/interop_test_vectors/framing.rs
@@ -0,0 +1,461 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode};
+use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider, SignaturePublicKey};
+
+use crate::{
+ client::test_utils::{TestClientConfig, TEST_PROTOCOL_VERSION},
+ crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
+ group::{
+ confirmation_tag::ConfirmationTag,
+ epoch::EpochSecrets,
+ framing::{Content, WireFormat},
+ message_processor::{EventOrContent, MessageProcessor},
+ mls_rules::EncryptionOptions,
+ padding::PaddingMode,
+ proposal::{Proposal, RemoveProposal},
+ secret_tree::test_utils::get_test_tree,
+ test_utils::{random_bytes, test_group_custom_config},
+ AuthenticatedContent, Commit, Group, GroupContext, MlsMessage, Sender,
+ },
+ mls_rules::DefaultMlsRules,
+ test_utils::is_edwards,
+ tree_kem::{leaf_node::test_utils::get_basic_test_node, node::LeafIndex},
+};
+
+const FRAMING_N_LEAVES: u32 = 2; // secret-tree size: receiver at leaf 0, sender at leaf 1
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct FramingTestCase { // one message-protection.json entry; field names are the JSON keys
+ #[serde(flatten)]
+ pub context: InteropGroupContext,
+ // Sender signature key pair; Edwards private keys are stored without the public half.
+ #[serde(with = "hex::serde")]
+ pub signature_priv: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub signature_pub: Vec<u8>,
+ // Epoch secrets that `make_group` installs into the test group.
+ #[serde(with = "hex::serde")]
+ pub encryption_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub sender_data_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub membership_key: Vec<u8>,
+ // Encoded `Proposal`, followed by its private and public framings.
+ #[serde(with = "hex::serde")]
+ pub proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub proposal_priv: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub proposal_pub: Vec<u8>,
+ // Encoded `Commit`, followed by its private and public framings.
+ #[serde(with = "hex::serde")]
+ pub commit: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub commit_priv: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub commit_pub: Vec<u8>,
+ // Application data and its private framing (there is no public variant here).
+ #[serde(with = "hex::serde")]
+ pub application: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub application_priv: Vec<u8>,
+}
+
+impl FramingTestCase {
+    /// Random test case for `cs`: fresh keys and secrets, message fields left empty.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    async fn random<P: CipherSuiteProvider>(cs: &P) -> Self {
+        let mut context = InteropGroupContext::random(cs);
+        context.cipher_suite = cs.cipher_suite().into();
+
+        let (generated_priv, signature_pub) = cs.signature_key_generate().await.unwrap();
+        // Edwards keys are stored as their first half only; `make_group` re-appends the public key.
+        let full_len = generated_priv.len();
+        let keep = if is_edwards(*cs.cipher_suite()) { full_len / 2 } else { full_len };
+
+        Self {
+            context,
+            signature_priv: generated_priv[0..keep].to_vec(),
+            signature_pub: signature_pub.to_vec(),
+            encryption_secret: random_bytes(cs.kdf_extract_size()),
+            sender_data_secret: random_bytes(cs.kdf_extract_size()),
+            membership_key: random_bytes(cs.kdf_extract_size()),
+            ..Default::default()
+        }
+    }
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct InteropGroupContext { // group context fields of a JSON vector (byte fields as hex)
+ pub cipher_suite: u16, // numeric cipher suite identifier
+ #[serde(with = "hex::serde")]
+ pub group_id: Vec<u8>,
+ pub epoch: u64,
+ #[serde(with = "hex::serde")]
+ pub tree_hash: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub confirmed_transcript_hash: Vec<u8>,
+}
+
+impl InteropGroupContext {
+    /// Random context for `cs`: ids and hashes sized to `kdf_extract_size`, fixed epoch.
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn random<P: CipherSuiteProvider>(cs: &P) -> Self {
+        let len = cs.kdf_extract_size();
+        let group_id = random_bytes(len);
+        let tree_hash = random_bytes(len);
+        let confirmed_transcript_hash = random_bytes(len);
+        let cipher_suite = cs.cipher_suite().into();
+        Self { cipher_suite, group_id, epoch: 0x121212, tree_hash, confirmed_transcript_hash }
+    }
+}
+
+impl From<InteropGroupContext> for GroupContext {
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn from(interop: InteropGroupContext) -> Self {
+        GroupContext {
+            protocol_version: TEST_PROTOCOL_VERSION, // not carried by the interop JSON
+            cipher_suite: interop.cipher_suite.into(),
+            group_id: interop.group_id,
+            epoch: interop.epoch,
+            tree_hash: interop.tree_hash,
+            confirmed_transcript_hash: interop.confirmed_transcript_hash.into(),
+            extensions: vec![].into(), // the interop JSON has no context extensions either
+        }
+    }
+}
+
+// The test vector can be found here:
+// https://github.com/mlswg/mls-implementations/blob/main/test-vectors/message-protection.json
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn framing_proposal() {
+    #[cfg(not(mls_build_async))]
+    let test_cases: Vec<FramingTestCase> =
+        load_test_case_json!(framing, generate_framing_test_vector());
+
+    #[cfg(mls_build_async)]
+    let test_cases: Vec<FramingTestCase> =
+        load_test_case_json!(framing, generate_framing_test_vector().await);
+
+    for test_case in test_cases.into_iter() {
+        let Some(cs) = try_test_cipher_suite_provider(test_case.context.cipher_suite) else {
+            continue;
+        };
+
+        // Decode the expected proposal once: it is both the input for the locally built
+        // messages and the value every processed message must reproduce.
+        let proposal = Proposal::mls_decode(&mut &*test_case.proposal).unwrap();
+
+        let to_check = vec![
+            test_case.proposal_priv.clone(),
+            test_case.proposal_pub.clone(),
+        ];
+
+        // Wasm uses an incompatible signature secret key format
+        #[cfg(not(target_arch = "wasm32"))]
+        let mut to_check = to_check;
+
+        #[cfg(not(target_arch = "wasm32"))]
+        for enable_encryption in [true, false] {
+            let built = make_group(&test_case, true, enable_encryption, &cs)
+                .await
+                .proposal_message(proposal.clone(), vec![])
+                .await
+                .unwrap()
+                .mls_encode_to_vec()
+                .unwrap();
+
+            to_check.push(built);
+        }
+
+        for message in to_check {
+            match process_message(&test_case, &message, &cs).await {
+                Content::Proposal(p) => assert_eq!(p.as_ref(), &proposal),
+                _ => panic!("received value not proposal"),
+            };
+        }
+    }
+}
+
+// The test vector can be found here:
+// https://github.com/mlswg/mls-implementations/blob/main/test-vectors/message-protection.json
+// Wasm uses incompatible signature secret key format
+#[cfg(not(target_arch = "wasm32"))]
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn framing_application() {
+    #[cfg(not(mls_build_async))]
+    let test_cases: Vec<FramingTestCase> =
+        load_test_case_json!(framing, generate_framing_test_vector());
+
+    #[cfg(mls_build_async)]
+    let test_cases: Vec<FramingTestCase> =
+        load_test_case_json!(framing, generate_framing_test_vector().await);
+
+    for test_case in &test_cases {
+        // Suites without a test cipher suite provider are skipped.
+        let Some(cs) = try_test_cipher_suite_provider(test_case.context.cipher_suite) else {
+            continue;
+        };
+
+        // Encrypt the vector's plaintext locally, then check that both this message and the
+        // vector's reference message decrypt back to that plaintext.
+        let mut sender = make_group(test_case, true, true, &cs).await;
+        let sent = sender.encrypt_application_message(&test_case.application, vec![]).await;
+        let built_priv = sent.unwrap().mls_encode_to_vec().unwrap();
+
+        for message in [&test_case.application_priv, &built_priv] {
+            let Content::Application(data) = process_message(test_case, message, &cs).await else {
+                panic!("decrypted value not application data");
+            };
+
+            assert_eq!(data.as_ref(), &test_case.application);
+        }
+    }
+}
+
+// The test vector can be found here:
+// https://github.com/mlswg/mls-implementations/blob/main/test-vectors/message-protection.json
+// Every case checks the vector's own framings plus, off wasm, freshly built ones.
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn framing_commit() {
+    #[cfg(not(mls_build_async))]
+    let test_cases: Vec<FramingTestCase> =
+        load_test_case_json!(framing, generate_framing_test_vector());
+
+    #[cfg(mls_build_async)]
+    let test_cases: Vec<FramingTestCase> =
+        load_test_case_json!(framing, generate_framing_test_vector().await);
+
+    for test_case in test_cases.into_iter() {
+        let Some(cs) = try_test_cipher_suite_provider(test_case.context.cipher_suite) else {
+            continue;
+        };
+
+        let commit = Commit::mls_decode(&mut &*test_case.commit).unwrap();
+
+        // The vector's private and public framings of the commit are always checked.
+        let to_check = vec![test_case.commit_priv.clone(), test_case.commit_pub.clone()];
+
+        // Wasm uses an incompatible signature secret key format
+        #[cfg(not(target_arch = "wasm32"))]
+        let to_check = {
+            let mut to_check = to_check;
+
+            // Rebuild the full signing key: Edwards keys are stored without the public half.
+            let mut signature_priv = test_case.signature_priv.clone();
+
+            if is_edwards(test_case.context.cipher_suite) {
+                signature_priv.extend(test_case.signature_pub.iter());
+            }
+
+            let mut auth_content = AuthenticatedContent::new_signed(
+                &cs,
+                &test_case.context.clone().into(),
+                Sender::Member(1),
+                Content::Commit(alloc::boxed::Box::new(commit.clone())),
+                &signature_priv.into(),
+                WireFormat::PublicMessage,
+                vec![],
+            )
+            .await
+            .unwrap();
+
+            auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
+
+            // Frame the same signed commit both encrypted and in the clear.
+            for enable_encryption in [true, false] {
+                let built = make_group(&test_case, true, enable_encryption, &cs)
+                    .await
+                    .format_for_wire(auth_content.clone())
+                    .await
+                    .unwrap()
+                    .mls_encode_to_vec()
+                    .unwrap();
+
+                to_check.push(built);
+            }
+
+            to_check
+        };
+
+        // Every framing must decode back to the vector's commit. `commit_priv` is already
+        // in `to_check`, so it needs no separate pass.
+        for message in to_check {
+            match process_message(&test_case, &message, &cs).await {
+                Content::Commit(c) => assert_eq!(&*c, &commit),
+                _ => panic!("received value not commit"),
+            };
+        }
+    }
+}
+
+/// Regenerates the message-protection vectors: one case per cipher suite, with an application
+/// message, a proposal and a commit framed under freshly generated keys and secrets.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn generate_framing_test_vector() -> Vec<FramingTestCase> {
+    let mut test_vector = vec![];
+
+    for cs in CipherSuite::all() {
+        let cs = test_cipher_suite_provider(cs);
+
+        let mut test_case = FramingTestCase::random(&cs).await;
+
+        // Generate private application message
+        test_case.application = cs.random_bytes_vec(42).unwrap();
+        let application_priv = make_group(&test_case, true, true, &cs)
+            .await
+            .encrypt_application_message(&test_case.application, vec![])
+            .await
+            .unwrap();
+
+        test_case.application_priv = application_priv.mls_encode_to_vec().unwrap();
+
+        // Generate private and public proposal message
+        let proposal = Proposal::Remove(RemoveProposal {
+            to_remove: LeafIndex(2),
+        });
+        test_case.proposal = proposal.mls_encode_to_vec().unwrap();
+
+        let mut group = make_group(&test_case, true, false, &cs).await;
+        let proposal_pub = group.proposal_message(proposal.clone(), vec![]).await;
+        test_case.proposal_pub = proposal_pub.unwrap().mls_encode_to_vec().unwrap();
+
+        let mut group = make_group(&test_case, true, true, &cs).await;
+        let proposal_priv = group.proposal_message(proposal, vec![]).await.unwrap();
+        test_case.proposal_priv = proposal_priv.mls_encode_to_vec().unwrap();
+
+        // Generate private and public commit message
+        let commit = Commit {
+            proposals: vec![],
+            path: None,
+        };
+
+        test_case.commit = commit.mls_encode_to_vec().unwrap();
+
+        let mut auth_content = AuthenticatedContent::new_signed(
+            &cs,
+            group.context(),
+            Sender::Member(1),
+            Content::Commit(alloc::boxed::Box::new(commit.clone())),
+            &group.signer,
+            WireFormat::PublicMessage,
+            vec![],
+        )
+        .await
+        .unwrap();
+
+        auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
+
+        let mut group = make_group(&test_case, true, false, &cs).await;
+        let commit_pub = group.format_for_wire(auth_content).await.unwrap();
+        test_case.commit_pub = commit_pub.mls_encode_to_vec().unwrap();
+
+        let mut auth_content = AuthenticatedContent::new_signed(
+            &cs,
+            group.context(), // every `make_group` group shares this context and signer
+            Sender::Member(1),
+            Content::Commit(alloc::boxed::Box::new(commit)),
+            &group.signer,
+            WireFormat::PrivateMessage,
+            vec![],
+        )
+        .await
+        .unwrap();
+
+        auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
+
+        let mut group = make_group(&test_case, true, true, &cs).await;
+        let commit_priv = group.format_for_wire(auth_content).await.unwrap();
+        test_case.commit_priv = commit_priv.mls_encode_to_vec().unwrap();
+
+        test_vector.push(test_case);
+    }
+
+    test_vector
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn make_group<P: CipherSuiteProvider>( // test group rewired with the vector's keys/secrets
+ test_case: &FramingTestCase,
+ for_send: bool, // sender uses leaf 1, receiver uses leaf 0
+ control_encryption_enabled: bool, // forwarded to `EncryptionOptions::new` (no padding)
+ cs: &P,
+) -> Group<TestClientConfig> {
+ let mut group =
+ test_group_custom_config(
+ TEST_PROTOCOL_VERSION,
+ test_case.context.cipher_suite.into(),
+ |b| {
+ b.mls_rules(DefaultMlsRules::default().with_encryption_options(
+ EncryptionOptions::new(control_encryption_enabled, PaddingMode::None),
+ ))
+ },
+ )
+ .await
+ .group;
+
+ // Add a leaf for the sender. It will get index 1.
+ let mut leaf = get_basic_test_node(cs.cipher_suite(), "leaf").await;
+
+ leaf.signing_identity.signature_key = SignaturePublicKey::from(test_case.signature_pub.clone());
+
+ group
+ .state
+ .public_tree
+ .add_leaves(vec![leaf], &group.config.0.identity_provider, cs)
+ .await
+ .unwrap();
+
+ // Convince the group that their index is 1 if they send or 0 if they receive.
+ group.private_tree.self_index = LeafIndex(if for_send { 1 } else { 0 });
+
+ // Convince the group that their signing key is the one from the test case.
+ let mut signature_priv = test_case.signature_priv.clone();
+ // Edwards secret keys are stored without their public half, so re-append it.
+ if is_edwards(test_case.context.cipher_suite) {
+ signature_priv.extend(test_case.signature_pub.iter());
+ }
+
+ group.signer = signature_priv.into();
+
+ // Set the group context and epoch secrets from the test case.
+ let context = GroupContext::from(test_case.context.clone());
+ let secret_tree = get_test_tree(test_case.encryption_secret.clone(), FRAMING_N_LEAVES);
+
+ let secrets = EpochSecrets {
+ secret_tree,
+ resumption_secret: vec![0_u8; cs.kdf_extract_size()].into(), // not in the vector; zeroed
+ sender_data_secret: test_case.sender_data_secret.clone().into(),
+ };
+
+ group.epoch_secrets = secrets;
+ group.state.context = context;
+ let membership_key = test_case.membership_key.clone();
+ group.key_schedule.set_membership_key(membership_key); // presumably feeds public-message membership tags
+
+ group
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn process_message<P: CipherSuiteProvider>(
+    test_case: &FramingTestCase,
+    message: &[u8],
+    cs: &P,
+) -> Content {
+    // Receive as leaf 0; enabling encryption doesn't matter for processing.
+    let mut receiver = make_group(test_case, false, true, cs).await;
+    let decoded = MlsMessage::mls_decode(&mut &*message).unwrap();
+    let EventOrContent::Content(processed) =
+        receiver.get_event_from_incoming_message(decoded).await.unwrap()
+    else {
+        panic!("expected content, got event");
+    };
+    processed.content.content
+}
diff --git a/src/group/interop_test_vectors/passive_client.rs b/src/group/interop_test_vectors/passive_client.rs
new file mode 100644
index 0000000..29588ed
--- /dev/null
+++ b/src/group/interop_test_vectors/passive_client.rs
@@ -0,0 +1,732 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+
+use itertools::Itertools;
+use mls_rs_core::{
+ crypto::{CipherSuite, CipherSuiteProvider, CryptoProvider},
+ identity::SigningIdentity,
+ protocol_version::ProtocolVersion,
+ psk::ExternalPskId,
+ time::MlsTime,
+};
+use rand::{seq::IteratorRandom, Rng, SeedableRng};
+
+use crate::{
+ client_builder::{ClientBuilder, MlsConfig},
+ crypto::test_utils::TestCryptoProvider,
+ group::{ClientConfig, CommitBuilder, ExportedTree},
+ identity::basic::BasicIdentityProvider,
+ key_package::KeyPackageGeneration,
+ mls_rules::CommitOptions,
+ storage_provider::in_memory::InMemoryKeyPackageStorage,
+ test_utils::{
+ all_process_message, generate_basic_client, get_test_basic_credential, get_test_groups,
+ make_test_ext_psk, TEST_EXT_PSK_ID,
+ },
+ tree_kem::Lifetime,
+ Client, Group, MlsMessage,
+};
+
+/// Protocol version used for every generated passive-client vector.
+const VERSION: ProtocolVersion = ProtocolVersion::MLS_10;
+
+/// Key package lifetime covering the whole `u64` range (`0..=u64::MAX`),
+/// presumably so stored vectors never expire when replayed later.
+const ETERNAL_LIFETIME: Lifetime = Lifetime {
+ not_before: 0,
+ not_after: u64::MAX,
+};
+
+/// One passive-client interop test case: the data a client that never sends
+/// messages needs to join a group from a Welcome and then follow it through a
+/// sequence of epochs.
+///
+/// Byte fields are hex-encoded when (de)serialized.
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestCase {
+ /// Numeric cipher suite identifier.
+ pub cipher_suite: u16,
+
+ /// External PSKs registered with the passive client before it joins.
+ pub external_psks: Vec<TestExternalPsk>,
+ /// Serialized `MlsMessage` carrying the passive client's key package.
+ #[serde(with = "hex::serde")]
+ pub key_package: Vec<u8>,
+ /// Signature private key for the key package's signing identity.
+ #[serde(with = "hex::serde")]
+ pub signature_priv: Vec<u8>,
+ /// Private HPKE key of the key package's leaf node.
+ #[serde(with = "hex::serde")]
+ pub encryption_priv: Vec<u8>,
+ /// Private HPKE init key of the key package.
+ #[serde(with = "hex::serde")]
+ pub init_priv: Vec<u8>,
+
+ /// Serialized `MlsMessage` containing the Welcome.
+ #[serde(with = "hex::serde")]
+ pub welcome: Vec<u8>,
+ /// Ratchet tree supplied out of band; left `None` by the generators when the
+ /// tree travels in the Welcome's ratchet tree extension.
+ pub ratchet_tree: Option<TestRatchetTree>,
+ /// Epoch authenticator expected right after joining.
+ #[serde(with = "hex::serde")]
+ pub initial_epoch_authenticator: Vec<u8>,
+
+ /// Epochs to replay after joining, in order.
+ pub epochs: Vec<TestEpoch>,
+}
+
+/// An external PSK (identifier and secret) to register with the passive client.
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestExternalPsk {
+ /// External PSK identifier.
+ #[serde(with = "hex::serde")]
+ pub psk_id: Vec<u8>,
+ /// PSK secret value.
+ #[serde(with = "hex::serde")]
+ pub psk: Vec<u8>,
+}
+
+/// One epoch transition to replay: standalone proposal messages, then the
+/// commit that ends the epoch.
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestEpoch {
+ /// Proposal messages processed, in order, before the commit.
+ pub proposals: Vec<TestMlsMessage>,
+ /// Serialized `MlsMessage` carrying the commit.
+ #[serde(with = "hex::serde")]
+ pub commit: Vec<u8>,
+ /// Epoch authenticator expected after processing the commit.
+ #[serde(with = "hex::serde")]
+ pub epoch_authenticator: Vec<u8>,
+}
+
+/// A serialized `MlsMessage`, hex-encoded when (de)serialized.
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestMlsMessage(#[serde(with = "hex::serde")] pub Vec<u8>);
+
+/// A serialized ratchet tree (`ExportedTree` bytes), hex-encoded when (de)serialized.
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestRatchetTree(#[serde(with = "hex::serde")] pub Vec<u8>);
+
+impl TestEpoch {
+    /// Builds an epoch record from the proposals sent during the epoch, the
+    /// commit that ended it, and the epoch authenticator expected afterwards.
+    ///
+    /// Panics if any message fails to serialize.
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    pub fn new(
+        proposals: Vec<MlsMessage>,
+        commit: &MlsMessage,
+        epoch_authenticator: Vec<u8>,
+    ) -> Self {
+        let mut encoded_proposals = Vec::with_capacity(proposals.len());
+
+        for proposal in proposals {
+            encoded_proposals.push(TestMlsMessage(proposal.to_bytes().unwrap()));
+        }
+
+        let encoded_commit = commit.to_bytes().unwrap();
+
+        Self {
+            proposals: encoded_proposals,
+            commit: encoded_commit,
+            epoch_authenticator,
+        }
+    }
+}
+
+/// Replays the passive-client interop test vectors.
+///
+/// For every vector whose cipher suite the test crypto provider supports, a
+/// client is built around the vector's key package and secrets, joins from the
+/// vector's Welcome, and then processes each epoch's proposals and commit. The
+/// epoch authenticator is checked after joining and after every commit.
+///
+/// Test vectors can be found here:
+/// * https://github.com/mlswg/mls-implementations/blob/main/test-vectors/passive-client-welcome.json
+/// * https://github.com/mlswg/mls-implementations/blob/main/test-vectors/passive-client-handle-commit.json
+/// * https://github.com/mlswg/mls-implementations/blob/main/test-vectors/passive-client-random.json
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn interop_passive_client() {
+    // `load_test_case_json!` takes the vector name and a generator expression
+    // (presumably used when no stored JSON is available). The generators are
+    // async under `mls_build_async`, hence the two cfg variants.
+    #[cfg(mls_build_async)]
+    let (test_cases_wel, test_cases_com, test_cases_rand) = {
+        let test_cases_wel: Vec<TestCase> = load_test_case_json!(
+            interop_passive_client_welcome,
+            generate_passive_client_welcome_tests().await
+        );
+
+        let test_cases_com: Vec<TestCase> = load_test_case_json!(
+            interop_passive_client_handle_commit,
+            generate_passive_client_proposal_tests().await
+        );
+
+        let test_cases_rand: Vec<TestCase> = load_test_case_json!(
+            interop_passive_client_random,
+            generate_passive_client_random_tests().await
+        );
+
+        (test_cases_wel, test_cases_com, test_cases_rand)
+    };
+
+    #[cfg(not(mls_build_async))]
+    let (test_cases_wel, test_cases_com, test_cases_rand) = {
+        let test_cases_wel: Vec<TestCase> = load_test_case_json!(
+            interop_passive_client_welcome,
+            generate_passive_client_welcome_tests()
+        );
+
+        let test_cases_com: Vec<TestCase> = load_test_case_json!(
+            interop_passive_client_handle_commit,
+            generate_passive_client_proposal_tests()
+        );
+
+        let test_cases_rand: Vec<TestCase> = load_test_case_json!(
+            interop_passive_client_random,
+            generate_passive_client_random_tests()
+        );
+
+        (test_cases_wel, test_cases_com, test_cases_rand)
+    };
+
+    // Same order as before: handle-commit, welcome, then random vectors.
+    let all_test_cases = test_cases_com
+        .into_iter()
+        .chain(test_cases_wel)
+        .chain(test_cases_rand);
+
+    for test_case in all_test_cases {
+        let crypto_provider = TestCryptoProvider::new();
+
+        // Skip vectors for cipher suites the test crypto provider does not support.
+        let Some(cs) = crypto_provider.cipher_suite_provider(test_case.cipher_suite.into()) else {
+            continue;
+        };
+
+        let message = MlsMessage::from_bytes(&test_case.key_package).unwrap();
+        let key_package = message.into_key_package().unwrap();
+        let signing_identity = key_package.leaf_node.signing_identity.clone();
+        // `test_case` is owned by the loop, so the field can be moved instead of cloned.
+        let signature_key = test_case.signature_priv.into();
+
+        let mut client_builder = ClientBuilder::new()
+            .crypto_provider(crypto_provider)
+            .identity_provider(BasicIdentityProvider::new());
+
+        for psk in test_case.external_psks {
+            client_builder = client_builder.psk(ExternalPskId::new(psk.psk_id), psk.psk.into());
+        }
+
+        let client = client_builder
+            .signing_identity(signing_identity, signature_key, cs.cipher_suite())
+            .build();
+
+        // Store the vector's key package and secrets so the client can open the Welcome.
+        let key_pckg_gen = KeyPackageGeneration {
+            reference: key_package.to_reference(&cs).await.unwrap(),
+            key_package,
+            init_secret_key: test_case.init_priv.into(),
+            leaf_node_secret_key: test_case.encryption_priv.into(),
+        };
+
+        let (key_package_id, key_package_data) = key_pckg_gen.to_storage().unwrap();
+        client
+            .config
+            .key_package_repo()
+            .insert(key_package_id, key_package_data);
+
+        let welcome = MlsMessage::from_bytes(&test_case.welcome).unwrap();
+
+        // Only present when the tree must be supplied out of band.
+        let tree = test_case
+            .ratchet_tree
+            .map(|t| ExportedTree::from_bytes(&t.0).unwrap());
+
+        let (mut group, _info) = client.join_group(tree, &welcome).await.unwrap();
+
+        assert_eq!(
+            group.epoch_authenticator().unwrap().to_vec(),
+            test_case.initial_epoch_authenticator
+        );
+
+        for epoch in test_case.epochs {
+            for proposal in &epoch.proposals {
+                let message = MlsMessage::from_bytes(&proposal.0).unwrap();
+
+                group
+                    .process_incoming_message_with_time(message, MlsTime::now())
+                    .await
+                    .unwrap();
+            }
+
+            let message = MlsMessage::from_bytes(&epoch.commit).unwrap();
+
+            group
+                .process_incoming_message_with_time(message, MlsTime::now())
+                .await
+                .unwrap();
+
+            assert_eq!(
+                epoch.epoch_authenticator,
+                group.epoch_authenticator().unwrap().to_vec()
+            );
+        }
+    }
+}
+
+/// Creates a fresh passive client ("Arnold") and adds it to the group through a
+/// commit from `groups[0]`, optionally also carrying the test external PSK.
+///
+/// Every group in `groups` processes the commit. The passive client's own
+/// `Group` is never created or pushed into `groups`; only what it needs in
+/// order to join (key package, private keys, Welcome) is captured in the
+/// returned [`TestCase`]. The returned case has no epochs and no out-of-band
+/// ratchet tree.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn invite_passive_client<P: CipherSuiteProvider>(
+    groups: &mut [Group<impl MlsConfig>],
+    with_psk: bool,
+    cs: &P,
+) -> TestCase {
+    let crypto_provider = TestCryptoProvider::new();
+
+    let (signature_secret, signature_public) = cs.signature_key_generate().await.unwrap();
+    let passive_identity = SigningIdentity::new(
+        get_test_basic_credential(b"Arnold".to_vec()),
+        signature_public,
+    );
+    let storage = InMemoryKeyPackageStorage::new();
+
+    let passive_client = ClientBuilder::new()
+        .crypto_provider(crypto_provider)
+        .identity_provider(BasicIdentityProvider::new())
+        .key_package_repo(storage.clone())
+        .key_package_lifetime(ETERNAL_LIFETIME.not_after - ETERNAL_LIFETIME.not_before)
+        .key_package_not_before(ETERNAL_LIFETIME.not_before)
+        .signing_identity(passive_identity, signature_secret.clone(), cs.cipher_suite())
+        .build();
+
+    let key_package_msg = passive_client.generate_key_package_message().await.unwrap();
+
+    // The repo handed to the builder is a clone of `storage`; reading back the
+    // first entry here yields the secrets of the key package just generated.
+    let (_, passive_secrets) = storage.key_packages()[0].clone();
+
+    let builder = groups[0]
+        .commit_builder()
+        .add_member(key_package_msg.clone())
+        .unwrap();
+
+    let builder = if with_psk {
+        builder
+            .add_external_psk(ExternalPskId::new(TEST_EXT_PSK_ID.to_vec()))
+            .unwrap()
+    } else {
+        builder
+    };
+
+    let commit_output = builder.build().await.unwrap();
+
+    all_process_message(groups, &commit_output.commit_message, 0, true).await;
+
+    let test_psk = TestExternalPsk {
+        psk_id: TEST_EXT_PSK_ID.to_vec(),
+        psk: make_test_ext_psk(),
+    };
+
+    let external_psks = if with_psk { vec![test_psk] } else { Vec::new() };
+
+    TestCase {
+        cipher_suite: cs.cipher_suite().into(),
+        external_psks,
+        key_package: key_package_msg.to_bytes().unwrap(),
+        encryption_priv: passive_secrets.leaf_node_key.to_vec(),
+        init_priv: passive_secrets.init_key.to_vec(),
+        welcome: commit_output.welcome_messages[0].to_bytes().unwrap(),
+        initial_epoch_authenticator: groups[0].epoch_authenticator().unwrap().to_vec(),
+        epochs: Vec::new(),
+        signature_priv: signature_secret.to_vec(),
+        ratchet_tree: None,
+    }
+}
+
+/// Generates the "handle commit" passive-client test vectors.
+///
+/// For each cipher suite supported by `TestCryptoProvider`, a group from
+/// `get_test_groups(.., 7, ..)` invites a passive client (with the test
+/// external PSK) and then advances one more epoch. Starting from that shared
+/// state, every returned test case adds exactly one further epoch, committed:
+/// * by value: add, remove, external PSK, resumption PSK, group context
+///   extensions, and all of these in a single commit;
+/// * by reference: one case per single proposal (add, update, remove,
+///   external PSK, resumption PSK, group context extensions) committed by
+///   `groups[5]`, plus one case with all of them committed by `groups[4]`.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn generate_passive_client_proposal_tests() -> Vec<TestCase> {
+ let mut test_cases: Vec<TestCase> = vec![];
+
+ for cs in CipherSuite::all() {
+ let crypto_provider = TestCryptoProvider::new();
+ let Some(cs) = crypto_provider.cipher_suite_provider(cs) else {
+ continue;
+ };
+
+ let mut groups =
+ get_test_groups(VERSION, cs.cipher_suite(), 7, None, false, &crypto_provider).await;
+
+ let mut partial_test_case = invite_passive_client(&mut groups, true, &cs).await;
+
+ // Create a new epoch s.t. the passive member can process resumption PSK from the current one
+ let commit = groups[0].commit(vec![]).await.unwrap();
+ all_process_message(&mut groups, &commit.commit_message, 0, true).await;
+
+ partial_test_case.epochs.push(TestEpoch::new(
+ vec![],
+ &commit.commit_message,
+ groups[0].epoch_authenticator().unwrap().to_vec(),
+ ));
+
+ let psk = ExternalPskId::new(TEST_EXT_PSK_ID.to_vec());
+ let key_pckg = create_key_package(cs.cipher_suite()).await;
+
+ // Create by value proposals. Each commit is made on a clone of a group, so
+ // `groups` itself stays at the shared starting epoch.
+ let test_case = commit_by_value(
+ &mut groups[3].clone(),
+ |b| b.add_member(key_pckg.clone()).unwrap(),
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ let test_case = commit_by_value(
+ &mut groups[3].clone(),
+ |b| b.remove_member(5).unwrap(),
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ let test_case = commit_by_value(
+ &mut groups[1].clone(),
+ |b| b.add_external_psk(psk.clone()).unwrap(),
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ // Resumption PSK from the previous epoch, which the passive client also saw.
+ let test_case = commit_by_value(
+ &mut groups[5].clone(),
+ |b| b.add_resumption_psk(groups[1].current_epoch() - 1).unwrap(),
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ let test_case = commit_by_value(
+ &mut groups[2].clone(),
+ |b| b.set_group_context_ext(Default::default()).unwrap(),
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ // All by-value proposal kinds in a single commit.
+ let test_case = commit_by_value(
+ &mut groups[3].clone(),
+ |b| {
+ b.add_member(key_pckg)
+ .unwrap()
+ .remove_member(5)
+ .unwrap()
+ .add_external_psk(psk.clone())
+ .unwrap()
+ .add_resumption_psk(groups[4].current_epoch() - 1)
+ .unwrap()
+ .set_group_context_ext(Default::default())
+ .unwrap()
+ },
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ // Create by reference proposals, each paired with the position in `groups`
+ // of the member that sent it.
+ let add = groups[0]
+ .propose_add(create_key_package(cs.cipher_suite()).await, vec![])
+ .await
+ .unwrap();
+
+ let add = (add, 0);
+
+ let update = (groups[1].propose_update(vec![]).await.unwrap(), 1);
+ let remove = (groups[2].propose_remove(2, vec![]).await.unwrap(), 2);
+
+ let ext_psk = groups[3]
+ .propose_external_psk(psk.clone(), vec![])
+ .await
+ .unwrap();
+
+ let ext_psk = (ext_psk, 3);
+
+ let last_ep = groups[3].current_epoch() - 1;
+
+ let res_psk = groups[3]
+ .propose_resumption_psk(last_ep, vec![])
+ .await
+ .unwrap();
+
+ let res_psk = (res_psk, 3);
+
+ let grp_ext = groups[4]
+ .propose_group_context_extensions(Default::default(), vec![])
+ .await
+ .unwrap();
+
+ let grp_ext = (grp_ext, 4);
+
+ let proposals = [add, update, remove, ext_psk, res_psk, grp_ext];
+
+ // One test case per proposal, each committed by `groups[5]`.
+ for (p, sender) in &proposals {
+ // Fresh copy per proposal so every case starts from the same state.
+ let mut groups = groups.clone();
+
+ all_process_message(&mut groups, p, *sender, false).await;
+
+ let commit = groups[5].commit(vec![]).await.unwrap().commit_message;
+
+ groups[5].apply_pending_commit().await.unwrap();
+ let auth = groups[5].epoch_authenticator().unwrap().to_vec();
+
+ let mut test_case = partial_test_case.clone();
+ let epoch = TestEpoch::new(vec![p.clone()], &commit, auth);
+ test_case.epochs.push(epoch);
+
+ test_cases.push(test_case);
+ }
+
+ // Finally, commit all by-reference proposals at once from `groups[4]`. It is
+ // fed every proposal except its own (sender 4), presumably because the
+ // proposer already holds that one.
+ let mut group = groups[4].clone();
+
+ for (p, _) in proposals.iter().filter(|(_, i)| *i != 4) {
+ group.process_incoming_message(p.clone()).await.unwrap();
+ }
+
+ let commit = group.commit(vec![]).await.unwrap().commit_message;
+ group.apply_pending_commit().await.unwrap();
+ let auth = group.epoch_authenticator().unwrap().to_vec();
+ let mut test_case = partial_test_case.clone();
+ let proposals = proposals.into_iter().map(|(p, _)| p).collect();
+ let epoch = TestEpoch::new(proposals, &commit, auth);
+ test_case.epochs.push(epoch);
+ test_cases.push(test_case);
+ }
+
+ test_cases
+}
+
+/// Lets `proposal_adder` populate a fresh commit builder on `group`, builds and
+/// applies that commit, and returns `partial_test_case` extended by one epoch
+/// holding the commit (with no standalone proposals) and the epoch
+/// authenticator `group` reaches afterwards.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn commit_by_value<F, C: MlsConfig>(
+    group: &mut Group<C>,
+    proposal_adder: F,
+    partial_test_case: TestCase,
+) -> TestCase
+where
+    F: FnOnce(CommitBuilder<C>) -> CommitBuilder<C>,
+{
+    let commit_message = proposal_adder(group.commit_builder())
+        .build()
+        .await
+        .unwrap()
+        .commit_message;
+
+    group.apply_pending_commit().await.unwrap();
+
+    let authenticator = group.epoch_authenticator().unwrap().to_vec();
+
+    let mut extended = partial_test_case;
+    extended
+        .epochs
+        .push(TestEpoch::new(Vec::new(), &commit_message, authenticator));
+
+    extended
+}
+
+/// Returns a key package message from a throwaway basic client on cipher
+/// suite `cs`, created with the fixed client argument `0xbeef` and an
+/// effectively unbounded key package lifetime.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn create_key_package(cs: CipherSuite) -> MlsMessage {
+    let crypto = TestCryptoProvider::new();
+
+    let throwaway_client = generate_basic_client(
+        cs,
+        VERSION,
+        0xbeef,
+        None,
+        false,
+        &crypto,
+        Some(ETERNAL_LIFETIME),
+    )
+    .await;
+
+    throwaway_client.generate_key_package_message().await.unwrap()
+}
+
+/// Generates the "welcome" passive-client test vectors.
+///
+/// For each supported cipher suite, and for every combination of ratchet tree
+/// in a Welcome extension (yes/no), external PSK (no/yes) and required commit
+/// path (yes/no), a 16-member group proposes removing member 7 and then invites
+/// a passive client. When the tree is not sent as an extension, it is exported
+/// into the test case so the client can receive it out of band.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn generate_passive_client_welcome_tests() -> Vec<TestCase> {
+    let mut cases = Vec::new();
+
+    for suite in CipherSuite::all() {
+        let crypto_provider = TestCryptoProvider::new();
+
+        let Some(cs) = crypto_provider.cipher_suite_provider(suite) else {
+            continue;
+        };
+
+        for with_tree_in_extension in [true, false] {
+            let variants = [false, true].into_iter().cartesian_product([true, false]);
+
+            for (with_psk, with_path) in variants {
+                let commit_options = CommitOptions::new()
+                    .with_path_required(with_path)
+                    .with_ratchet_tree_extension(with_tree_in_extension);
+
+                let mut groups = get_test_groups(
+                    VERSION,
+                    cs.cipher_suite(),
+                    16,
+                    Some(commit_options),
+                    false,
+                    &crypto_provider,
+                )
+                .await;
+
+                // Propose removing member 7 so the passive member joins in their place.
+                let removal = groups[0].propose_remove(7, vec![]).await.unwrap();
+                all_process_message(&mut groups, &removal, 0, false).await;
+
+                let mut case = invite_passive_client(&mut groups, with_psk, &cs).await;
+
+                if !with_tree_in_extension {
+                    let exported = groups[0].export_tree().to_bytes().unwrap();
+                    case.ratchet_tree = Some(TestRatchetTree(exported));
+                }
+
+                cases.push(case);
+            }
+        }
+    }
+
+    cases
+}
+
+/// Generates the "random" passive-client test vectors.
+///
+/// For each supported cipher suite, a creator plus 10 added clients (11
+/// members) invite a passive client, then run 100 rounds. Each round is one
+/// commit removing a random (possibly empty) subset of members, followed by
+/// one commit of 1 to 29 new members added by reference. Every commit is
+/// recorded as an epoch of the single test case produced per cipher suite.
+///
+/// The RNG seed comes from `rand::random()`, so output differs between runs;
+/// with the `std` feature the seed is printed so a run can be reproduced.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn generate_passive_client_random_tests() -> Vec<TestCase> {
+ let mut test_cases: Vec<TestCase> = vec![];
+
+ for cs in CipherSuite::all() {
+ let crypto = TestCryptoProvider::new();
+ let Some(csp) = crypto.cipher_suite_provider(cs) else {
+ continue;
+ };
+
+ let creator =
+ generate_basic_client(cs, VERSION, 0, None, false, &crypto, Some(ETERNAL_LIFETIME))
+ .await;
+
+ let creator_group = creator.create_group(Default::default()).await.unwrap();
+
+ let mut groups = vec![creator_group];
+
+ let mut new_clients = Vec::new();
+
+ for i in 0..10 {
+ new_clients.push(
+ generate_basic_client(
+ cs,
+ VERSION,
+ i + 1,
+ None,
+ false,
+ &crypto,
+ Some(ETERNAL_LIFETIME),
+ )
+ .await,
+ )
+ }
+
+ add_random_members(0, &mut groups, new_clients, None).await;
+
+ let mut test_case = invite_passive_client(&mut groups, false, &csp).await;
+
+ // Presumably the passive client's leaf: it joins after the creator and the
+ // 10 initial members. NOTE(review): its `Group` is never pushed into `groups`
+ // (see invite_passive_client), so the filters below exclude nothing in
+ // practice; the check is defensive.
+ let passive_client_index = 11;
+
+ let seed: <rand::rngs::StdRng as SeedableRng>::Seed = rand::random();
+ let mut rng = rand::rngs::StdRng::from_seed(seed);
+ #[cfg(feature = "std")]
+ println!("generating random commits for seed {}", hex::encode(seed));
+
+ // Next client argument for `generate_basic_client`; 0..=10 are already used.
+ let mut next_free_idx = 11;
+ for _ in 0..100 {
+ // We keep the passive client and another member to send. `num_removed` is
+ // at most `groups.len() - 3`, so at least three members (the sender
+ // included) survive the removal commit.
+ let num_removed = rng.gen_range(0..groups.len() - 2);
+ let num_added = rng.gen_range(1..30);
+
+ let mut members = (0..groups.len())
+ .filter(|i| groups[*i].current_member_index() != passive_client_index)
+ .choose_multiple(&mut rng, num_removed + 1);
+
+ // The last chosen member commits; the rest are removed.
+ let sender = members.pop().unwrap();
+
+ remove_members(members, sender, &mut groups, Some(&mut test_case)).await;
+
+ // Pick a committer for the add round among the surviving members.
+ let sender = (0..groups.len())
+ .filter(|i| groups[*i].current_member_index() != passive_client_index)
+ .choose(&mut rng)
+ .unwrap();
+
+ let mut new_clients = Vec::new();
+
+ for i in 0..num_added {
+ new_clients.push(
+ generate_basic_client(
+ cs,
+ VERSION,
+ next_free_idx + i,
+ None,
+ false,
+ &crypto,
+ Some(ETERNAL_LIFETIME),
+ )
+ .await,
+ );
+ }
+
+ add_random_members(sender, &mut groups, new_clients, Some(&mut test_case)).await;
+
+ next_free_idx += num_added;
+ }
+
+ test_cases.push(test_case);
+ }
+
+ test_cases
+}
+
+/// Adds every client in `clients` to the group through `groups[committer]`.
+///
+/// The committer sends one by-reference Add proposal per client, which all of
+/// `groups` process. It then commits them, everyone processes the commit, and
+/// each new client joins from the first Welcome, with the committer's exported
+/// ratchet tree supplied explicitly. The new groups are appended to `groups`.
+/// When `test_case` is given, the proposals, commit and resulting epoch
+/// authenticator are recorded as a new epoch.
+///
+/// `committer` is a position in `groups`. Panics on any protocol failure, or if
+/// `clients` is non-empty and the commit produced no Welcome message.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn add_random_members<C: MlsConfig>(
+    committer: usize,
+    groups: &mut Vec<Group<C>>,
+    clients: Vec<Client<C>>,
+    test_case: Option<&mut TestCase>,
+) {
+    let committer_index = groups[committer].current_member_index() as usize;
+
+    let mut key_packages = Vec::with_capacity(clients.len());
+
+    for client in &clients {
+        key_packages.push(client.generate_key_package_message().await.unwrap());
+    }
+
+    let mut add_proposals = Vec::with_capacity(key_packages.len());
+    let committer_group = &mut groups[committer];
+
+    for key_package in key_packages {
+        add_proposals.push(
+            committer_group
+                .propose_add(key_package, vec![])
+                .await
+                .unwrap(),
+        );
+    }
+
+    for proposal in &add_proposals {
+        all_process_message(groups, proposal, committer_index, false).await;
+    }
+
+    let commit_output = groups[committer].commit(vec![]).await.unwrap();
+
+    all_process_message(groups, &commit_output.commit_message, committer_index, true).await;
+
+    let auth = groups[committer].epoch_authenticator().unwrap().to_vec();
+    let epoch = TestEpoch::new(add_proposals, &commit_output.commit_message, auth);
+
+    if let Some(tc) = test_case {
+        tc.epochs.push(epoch);
+    }
+
+    let tree_data = groups[committer].export_tree().into_owned();
+
+    for client in &clients {
+        // Borrow the shared Welcome rather than cloning it for every joiner.
+        // Indexing stays inside the loop so an empty `clients` never touches it.
+        let welcome = &commit_output.welcome_messages[0];
+
+        let (group, _) = client
+            .join_group(Some(tree_data.clone()), welcome)
+            .await
+            .unwrap();
+
+        groups.push(group);
+    }
+}
+
+/// Removes the members at positions `removed_members` of `groups` with a single
+/// by-value commit from `groups[committer]`. All of `groups` process the
+/// commit, the epoch is optionally recorded in `test_case`, and finally the
+/// removed members' `Group`s are dropped from `groups`.
+///
+/// `removed_members` and `committer` are positions in `groups`, not leaf indices.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn remove_members<C: MlsConfig>(
+    removed_members: Vec<usize>,
+    committer: usize,
+    groups: &mut Vec<Group<C>>,
+    test_case: Option<&mut TestCase>,
+) {
+    // Translate vector positions into leaf indices before anything changes.
+    let mut leaf_indexes = Vec::with_capacity(removed_members.len());
+
+    for &position in &removed_members {
+        leaf_indexes.push(groups[position].current_member_index());
+    }
+
+    let builder = leaf_indexes
+        .into_iter()
+        .fold(groups[committer].commit_builder(), |builder, leaf| {
+            builder.remove_member(leaf).unwrap()
+        });
+
+    let commit = builder.build().await.unwrap().commit_message;
+    let committer_index = groups[committer].current_member_index() as usize;
+    all_process_message(groups, &commit, committer_index, true).await;
+
+    let epoch = TestEpoch::new(
+        vec![],
+        &commit,
+        groups[committer].epoch_authenticator().unwrap().to_vec(),
+    );
+
+    if let Some(tc) = test_case {
+        tc.epochs.push(epoch);
+    }
+
+    // Drop removed members, addressing them by their original positions.
+    let mut position = 0;
+
+    groups.retain(|_| {
+        let keep = !removed_members.contains(&position);
+        position += 1;
+        keep
+    });
+}
diff --git a/src/group/interop_test_vectors/serialization.rs b/src/group/interop_test_vectors/serialization.rs
new file mode 100644
index 0000000..cbaf6fa
--- /dev/null
+++ b/src/group/interop_test_vectors/serialization.rs
@@ -0,0 +1,169 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode};
+
+use mls_rs_core::extension::ExtensionList;
+
+use crate::{
+ group::{
+ framing::ContentType,
+ proposal::{
+ AddProposal, ExternalInit, PreSharedKeyProposal, ReInitProposal, RemoveProposal,
+ UpdateProposal,
+ },
+ Commit, GroupSecrets, MlsMessage,
+ },
+ tree_kem::node::NodeVec,
+};
+
+/// Hex-encoded MLS structures from the interop `messages.json` vector.
+///
+/// Each field holds the serialization of the structure it is named after;
+/// `serialization` checks that each one decodes and re-encodes byte-for-byte.
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct TestCase {
+ /// `MlsMessage` carrying a Welcome.
+ #[serde(with = "hex::serde")]
+ mls_welcome: Vec<u8>,
+ /// `MlsMessage` carrying a GroupInfo.
+ #[serde(with = "hex::serde")]
+ mls_group_info: Vec<u8>,
+ /// `MlsMessage` carrying a KeyPackage.
+ #[serde(with = "hex::serde")]
+ mls_key_package: Vec<u8>,
+
+ /// Ratchet tree, decoded as a `NodeVec`.
+ #[serde(with = "hex::serde")]
+ ratchet_tree: Vec<u8>,
+ /// `GroupSecrets` structure.
+ #[serde(with = "hex::serde")]
+ group_secrets: Vec<u8>,
+
+ /// `AddProposal` body.
+ #[serde(with = "hex::serde")]
+ add_proposal: Vec<u8>,
+ /// `UpdateProposal` body.
+ #[serde(with = "hex::serde")]
+ update_proposal: Vec<u8>,
+ /// `RemoveProposal` body.
+ #[serde(with = "hex::serde")]
+ remove_proposal: Vec<u8>,
+ /// `PreSharedKeyProposal` body.
+ #[serde(with = "hex::serde")]
+ pre_shared_key_proposal: Vec<u8>,
+ /// `ReInitProposal` body.
+ #[serde(with = "hex::serde")]
+ re_init_proposal: Vec<u8>,
+ /// `ExternalInit` proposal body.
+ #[serde(with = "hex::serde")]
+ external_init_proposal: Vec<u8>,
+ /// Group context extensions proposal, decoded as an `ExtensionList`.
+ #[serde(with = "hex::serde")]
+ group_context_extensions_proposal: Vec<u8>,
+
+ /// `Commit` structure.
+ #[serde(with = "hex::serde")]
+ commit: Vec<u8>,
+
+ /// `MlsMessage` holding a public message with application content.
+ #[serde(with = "hex::serde")]
+ public_message_application: Vec<u8>,
+ /// `MlsMessage` holding a public message with proposal content.
+ #[serde(with = "hex::serde")]
+ public_message_proposal: Vec<u8>,
+ /// `MlsMessage` holding a public message with commit content.
+ #[serde(with = "hex::serde")]
+ public_message_commit: Vec<u8>,
+ /// `MlsMessage` holding a private (ciphertext) message.
+ #[serde(with = "hex::serde")]
+ private_message: Vec<u8>,
+}
+
+/// Checks that every structure in the interop messages vector decodes and
+/// re-encodes to exactly the original bytes, and that each message has the
+/// expected kind.
+///
+/// The test vector can be found here:
+/// https://github.com/mlswg/mls-implementations/blob/main/test-vectors/messages.json
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn serialization() {
+    /// Decodes `bytes` as `T` and asserts that re-encoding yields the same bytes.
+    fn assert_round_trip<T: MlsDecode + MlsEncode>(bytes: &[u8]) {
+        let mut reader = bytes;
+        let decoded = T::mls_decode(&mut reader).unwrap();
+        assert_eq!(decoded.mls_encode_to_vec().unwrap(), bytes);
+    }
+
+    /// Parses a public message, checks it re-encodes losslessly, and checks its content type.
+    fn assert_public_message(bytes: &[u8], expected: ContentType) {
+        let message = MlsMessage::from_bytes(bytes).unwrap();
+        assert_eq!(message.mls_encode_to_vec().unwrap(), bytes);
+        let content_type = message.into_plaintext().unwrap().content.content_type();
+        assert_eq!(content_type, expected);
+    }
+
+    let test_cases: Vec<TestCase> = load_test_case_json!(serialization, Vec::<TestCase>::new());
+
+    for test_case in test_cases {
+        // Top-level MlsMessage wrappers: must be the right message kind and
+        // re-serialize losslessly.
+        let welcome = MlsMessage::from_bytes(&test_case.mls_welcome).unwrap();
+        welcome.clone().into_welcome().unwrap();
+        assert_eq!(&welcome.to_bytes().unwrap(), &test_case.mls_welcome);
+
+        let group_info = MlsMessage::from_bytes(&test_case.mls_group_info).unwrap();
+        group_info.clone().into_group_info().unwrap();
+        assert_eq!(&group_info.to_bytes().unwrap(), &test_case.mls_group_info);
+
+        let key_package = MlsMessage::from_bytes(&test_case.mls_key_package).unwrap();
+        key_package.clone().into_key_package().unwrap();
+        assert_eq!(&key_package.to_bytes().unwrap(), &test_case.mls_key_package);
+
+        // Bare structures: decode and re-encode byte-for-byte.
+        assert_round_trip::<NodeVec>(&test_case.ratchet_tree);
+        assert_round_trip::<GroupSecrets>(&test_case.group_secrets);
+        assert_round_trip::<AddProposal>(&test_case.add_proposal);
+        assert_round_trip::<UpdateProposal>(&test_case.update_proposal);
+        assert_round_trip::<RemoveProposal>(&test_case.remove_proposal);
+        assert_round_trip::<ReInitProposal>(&test_case.re_init_proposal);
+        assert_round_trip::<PreSharedKeyProposal>(&test_case.pre_shared_key_proposal);
+        assert_round_trip::<ExternalInit>(&test_case.external_init_proposal);
+        assert_round_trip::<ExtensionList>(&test_case.group_context_extensions_proposal);
+        assert_round_trip::<Commit>(&test_case.commit);
+
+        // Framed messages: public ones must carry the expected content type,
+        // the private one must be a ciphertext.
+        assert_public_message(&test_case.public_message_application, ContentType::Application);
+        assert_public_message(&test_case.public_message_proposal, ContentType::Proposal);
+        assert_public_message(&test_case.public_message_commit, ContentType::Commit);
+
+        let private_message = MlsMessage::from_bytes(&test_case.private_message).unwrap();
+        assert_eq!(&private_message.mls_encode_to_vec().unwrap(), &test_case.private_message);
+        private_message.into_ciphertext().unwrap();
+    }
+}
diff --git a/src/group/interop_test_vectors/tree_kem.rs b/src/group/interop_test_vectors/tree_kem.rs
new file mode 100644
index 0000000..0a04312
--- /dev/null
+++ b/src/group/interop_test_vectors/tree_kem.rs
@@ -0,0 +1,185 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::test_utils::TEST_PROTOCOL_VERSION,
+ crypto::test_utils::try_test_cipher_suite_provider,
+ group::{
+ confirmation_tag::ConfirmationTag, framing::Content, message_processor::MessageProcessor,
+ message_signature::AuthenticatedContent, test_utils::GroupWithoutKeySchedule, Commit,
+ GroupContext, PathSecret, Sender,
+ },
+ identity::basic::BasicIdentityProvider,
+ tree_kem::{
+ node::{LeafIndex, NodeVec},
+ TreeKemPrivate, TreeKemPublic, UpdatePath,
+ },
+ WireFormat,
+};
+use alloc::vec;
+use alloc::vec::Vec;
+use mls_rs_codec::MlsDecode;
+use mls_rs_core::{crypto::CipherSuiteProvider, extension::ExtensionList};
+
+/// One TreeKEM interop test case for a single cipher suite.
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct TreeKemTestCase {
+ /// Numeric cipher suite identifier.
+ pub cipher_suite: u16,
+
+ /// Group ID placed in the group context.
+ #[serde(with = "hex::serde")]
+ pub group_id: Vec<u8>,
+ /// Epoch placed in the group context.
+ epoch: u64,
+ /// Confirmed transcript hash placed in the group context.
+ #[serde(with = "hex::serde")]
+ confirmed_transcript_hash: Vec<u8>,
+ /// Serialized public ratchet tree (`NodeVec`) that every update path is applied to.
+ #[serde(with = "hex::serde")]
+ ratchet_tree: Vec<u8>,
+
+ /// Private state of the leaves whose view of the tree is checked.
+ leaves_private: Vec<TestLeafPrivate>,
+ /// Update paths; each is processed by every leaf in `leaves_private` except its sender.
+ update_paths: Vec<TestUpdatePath>,
+}
+
+/// Private data held by one leaf of the tree.
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct TestLeafPrivate {
+ /// Leaf index.
+ index: u32,
+ /// Private HPKE key of the leaf node.
+ #[serde(with = "hex::serde")]
+ encryption_priv: Vec<u8>,
+ /// Signature private key of the leaf.
+ #[serde(with = "hex::serde")]
+ signature_priv: Vec<u8>,
+ /// Path secrets this leaf knows, one per tree node it covers.
+ path_secrets: Vec<TestPathSecretPrivate>,
+}
+
+/// A path secret for a single tree node.
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct TestPathSecretPrivate {
+ /// Tree node index, matched against the node indices of the leaf's direct path.
+ node: u32,
+ /// Path secret bytes.
+ #[serde(with = "hex::serde")]
+ path_secret: Vec<u8>,
+}
+
+/// An update path sent by one leaf, with the expected results of processing it.
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct TestUpdatePath {
+ /// Leaf index of the sender.
+ sender: u32,
+ /// Serialized `UpdatePath`.
+ #[serde(with = "hex::serde")]
+ update_path: Vec<u8>,
+ /// Tree hash expected after merging the path.
+ #[serde(with = "hex::serde")]
+ tree_hash_after: Vec<u8>,
+ /// Commit secret expected from processing the path.
+ #[serde(with = "hex::serde")]
+ commit_secret: Vec<u8>,
+}
+
+/// Replays the TreeKEM interop vectors.
+///
+/// For every supported cipher suite and every leaf in `leaves_private`, the
+/// leaf's private tree is rebuilt from its path secrets. Each derived public key
+/// is checked against the public tree. Every update path sent by another leaf
+/// is then processed as a commit, and the resulting commit secret and tree hash
+/// are compared with the expected values.
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn tree_kem() {
+ // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/treekem.json
+
+ let test_cases: Vec<TreeKemTestCase> =
+ load_test_case_json!(interop_tree_kem, Vec::<TreeKemTestCase>::new());
+
+ for test_case in test_cases {
+ // `None` presumably means the test crypto provider lacks this suite; skip it.
+ let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
+ continue;
+ };
+
+ // Import the public ratchet tree
+ let nodes = NodeVec::mls_decode(&mut &*test_case.ratchet_tree).unwrap();
+
+ let mut tree =
+ TreeKemPublic::import_node_data(nodes, &BasicIdentityProvider, &Default::default())
+ .await
+ .unwrap();
+
+ // Construct GroupContext
+ let group_context = GroupContext {
+ protocol_version: TEST_PROTOCOL_VERSION,
+ cipher_suite: cs.cipher_suite(),
+ group_id: test_case.group_id,
+ epoch: test_case.epoch,
+ tree_hash: tree.tree_hash(&cs).await.unwrap(),
+ confirmed_transcript_hash: test_case.confirmed_transcript_hash.into(),
+ extensions: ExtensionList::new(),
+ };
+
+ for leaf in test_case.leaves_private.iter() {
+ // Construct the private ratchet tree
+ let mut tree_private = TreeKemPrivate::new(LeafIndex(leaf.index));
+
+ // Set and validate HPKE keys on direct path
+ let path = tree.nodes.direct_copath(tree_private.self_index);
+
+ // Rebuild the key list: one entry per direct-path node (`None` where this
+ // leaf has no path secret); the leaf's own key is inserted at 0 afterwards.
+ tree_private.secret_keys = Vec::new();
+
+ for dp in path {
+ let dp = dp.path;
+
+ // The path secret, if any, this leaf holds for node `dp`.
+ let secret = leaf
+ .path_secrets
+ .iter()
+ .find_map(|s| (s.node == dp).then_some(s.path_secret.clone()));
+
+ let private_key = if let Some(secret) = secret {
+ let (secret_key, public_key) = PathSecret::from(secret)
+ .to_hpke_key_pair(&cs)
+ .await
+ .unwrap();
+
+ // The derived public key must match the one stored in the public tree.
+ let tree_public = &tree.nodes.borrow_as_parent(dp).unwrap().public_key;
+ assert_eq!(&public_key, tree_public);
+
+ Some(secret_key)
+ } else {
+ None
+ };
+
+ tree_private.secret_keys.push(private_key);
+ }
+
+ // Set HPKE key for leaf
+ tree_private
+ .secret_keys
+ .insert(0, Some(leaf.encryption_priv.clone().into()));
+
+ // A leaf does not process its own update path.
+ let paths = test_case
+ .update_paths
+ .iter()
+ .filter(|path| path.sender != leaf.index);
+
+ for update_path in paths {
+ let mut group = GroupWithoutKeySchedule::new(cs.cipher_suite()).await;
+ group.state.context = group_context.clone();
+ group.state.public_tree = tree.clone();
+ group.private_tree = tree_private.clone();
+
+ let path = UpdatePath::mls_decode(&mut &*update_path.update_path).unwrap();
+
+ let commit = Commit {
+ proposals: vec![],
+ path: Some(path),
+ };
+
+ let mut auth_content = AuthenticatedContent::new(
+ &group_context,
+ Sender::Member(update_path.sender),
+ Content::Commit(alloc::boxed::Box::new(commit)),
+ vec![],
+ WireFormat::PublicMessage,
+ );
+
+ // Placeholder tag; presumably not verified on this path — TODO confirm.
+ auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
+
+ // Hack not to increment epoch: process_commit presumably advances the epoch by
+ // one, so starting one lower keeps the provisional context at the vector's
+ // epoch. NOTE(review): underflows if a vector ever uses epoch 0.
+ group.state.context.epoch -= 1;
+
+ group.process_commit(auth_content, None).await.unwrap();
+
+ // Check that we got the expected commit secret and correctly merged the update path.
+ // This implies that we computed the path secrets correctly.
+ let commit_secret = group.secrets.unwrap().1;
+
+ assert_eq!(&*commit_secret, &update_path.commit_secret);
+
+ let new_tree = &mut group.provisional_public_state.unwrap().public_tree;
+ let new_tree_hash = new_tree.tree_hash(&cs).await.unwrap();
+
+ assert_eq!(&new_tree_hash, &update_path.tree_hash_after);
+ }
+ }
+ }
+}
diff --git a/src/group/interop_test_vectors/tree_modifications.rs b/src/group/interop_test_vectors/tree_modifications.rs
new file mode 100644
index 0000000..a172e0c
--- /dev/null
+++ b/src/group/interop_test_vectors/tree_modifications.rs
@@ -0,0 +1,177 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::boxed::Box;
+use alloc::vec;
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode};
+use mls_rs_core::crypto::CipherSuite;
+
+use crate::{
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
+ group::{
+ proposal::{AddProposal, Proposal, ProposalOrRef, RemoveProposal, UpdateProposal},
+ proposal_cache::test_utils::CommitReceiver,
+ proposal_ref::ProposalRef,
+ test_utils::TEST_GROUP,
+ LeafIndex, Sender, TreeKemPublic,
+ },
+ identity::basic::BasicIdentityProvider,
+ key_package::test_utils::test_key_package,
+ tree_kem::{
+ leaf_node::test_utils::default_properties, node::NodeVec, test_utils::TreeWithSigners,
+ },
+};
+
+/// One entry of the tree-operations interop test vector: a serialized tree, one proposal
+/// applied to it on behalf of `proposal_sender`, and the serialized tree that results.
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct TreeModsTestCase {
+    /// MLS-encoded `NodeVec` of the tree before the proposal (hex in JSON).
+    #[serde(with = "hex::serde")]
+    pub tree_before: Vec<u8>,
+    /// MLS-encoded `Proposal` (hex in JSON).
+    #[serde(with = "hex::serde")]
+    pub proposal: Vec<u8>,
+    /// Member index recorded as the sender of the proposal.
+    pub proposal_sender: u32,
+    /// MLS-encoded `NodeVec` of the tree after the proposal (hex in JSON).
+    #[serde(with = "hex::serde")]
+    pub tree_after: Vec<u8>,
+}
+
+impl TreeModsTestCase {
+    /// Applies `proposal` to `tree_before` and records the encoded inputs and result.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    async fn new(tree_before: TreeKemPublic, proposal: Proposal, proposal_sender: u32) -> Self {
+        // Encode the proposal first so it can be moved (not cloned) into `apply_proposal`.
+        let encoded_proposal = proposal.mls_encode_to_vec().unwrap();
+        let tree_after = apply_proposal(proposal, proposal_sender, &tree_before).await;
+
+        let encoded_before = tree_before.nodes.mls_encode_to_vec().unwrap();
+        let encoded_after = tree_after.nodes.mls_encode_to_vec().unwrap();
+
+        Self {
+            tree_before: encoded_before,
+            proposal: encoded_proposal,
+            proposal_sender,
+            tree_after: encoded_after,
+        }
+    }
+}
+
+/// Builds the tree-operations test vector from freshly generated trees.
+///
+/// Produces one update, three adds and three removes; every add and remove is sent by
+/// member 2.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn generate_tree_mods_tests() -> Vec<TreeModsTestCase> {
+    let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+    let mut cases = Vec::new();
+
+    // Update: member 6 of a full 8-member tree updates its own leaf.
+    let full_tree = TreeWithSigners::make_full_tree(8, &cs).await;
+    let update = generate_update(6, &full_tree).await;
+    cases.push(TreeModsTestCase::new(full_tree.tree, update, 6).await);
+
+    // Add in the middle: member 3 is removed first, leaving a blank leaf to fill.
+    let mut gapped_tree = TreeWithSigners::make_full_tree(6, &cs).await;
+    gapped_tree.remove_member(3);
+    let add = generate_add().await;
+    cases.push(TreeModsTestCase::new(gapped_tree.tree, add, 2).await);
+
+    // Add at the end (6 members), then add at the end where the tree grows (8 members).
+    for size in [6, 8] {
+        let tree_before = TreeWithSigners::make_full_tree(size, &cs).await;
+        let add = generate_add().await;
+        cases.push(TreeModsTestCase::new(tree_before.tree, add, 2).await);
+    }
+
+    // Removes as (tree size, removed leaf): in the middle, at the end, and at the end where
+    // the tree shrinks.
+    for (size, removed_leaf) in [(8, 2), (8, 7), (9, 8)] {
+        let tree_before = TreeWithSigners::make_full_tree(size, &cs).await;
+        let remove = generate_remove(removed_leaf);
+        cases.push(TreeModsTestCase::new(tree_before.tree, remove, 2).await);
+    }
+
+    cases
+}
+
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn tree_modifications_interop() {
+    // Test vector source:
+    // https://github.com/mlswg/mls-implementations/blob/main/test-vectors/tree-operations.json
+
+    // Every case in the vector uses cipher suite 1, so there is nothing to check without it.
+    let Some(_) = try_test_cipher_suite_provider(*CipherSuite::CURVE25519_AES128) else {
+        return;
+    };
+
+    #[cfg(not(mls_build_async))]
+    let test_cases: Vec<TreeModsTestCase> =
+        load_test_case_json!(tree_modifications_interop, generate_tree_mods_tests());
+
+    #[cfg(mls_build_async)]
+    let test_cases: Vec<TreeModsTestCase> =
+        load_test_case_json!(tree_modifications_interop, generate_tree_mods_tests().await);
+
+    for case in test_cases {
+        let before_nodes = NodeVec::mls_decode(&mut &*case.tree_before).unwrap();
+
+        let tree_before = TreeKemPublic::import_node_data(
+            before_nodes,
+            &BasicIdentityProvider,
+            &Default::default(),
+        )
+        .await
+        .unwrap();
+
+        let proposal = Proposal::mls_decode(&mut &*case.proposal).unwrap();
+
+        // Re-apply the proposal locally and compare the encoded result with the vector.
+        let encoded_after = apply_proposal(proposal, case.proposal_sender, &tree_before)
+            .await
+            .nodes
+            .mls_encode_to_vec()
+            .unwrap();
+
+        assert_eq!(encoded_after, case.tree_after);
+    }
+}
+
+/// Runs `proposal`, cached by reference as if sent by member `sender`, through a
+/// `CommitReceiver` built over `tree_before`, and returns the resulting public tree.
+///
+/// NOTE(review): `Sender::Member(0)` / `LeafIndex(1)` presumably select the committer and
+/// the receiving member; confirm against `CommitReceiver::new`.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn apply_proposal(
+    proposal: Proposal,
+    sender: u32,
+    tree_before: &TreeKemPublic,
+) -> TreeKemPublic {
+    let cs = test_cipher_suite_provider(CipherSuite::CURVE25519_AES128);
+    let proposal_ref = ProposalRef::new_fake(b"fake ref".to_vec());
+
+    // The commit only lists the proposal by reference; the proposal itself is cached below.
+    let by_reference = vec![ProposalOrRef::Reference(proposal_ref.clone())];
+
+    let outcome = CommitReceiver::new(tree_before, Sender::Member(0), LeafIndex(1), cs)
+        .cache(proposal_ref, proposal, Sender::Member(sender))
+        .receive(by_reference)
+        .await
+        .unwrap();
+
+    outcome.public_tree
+}
+
+/// Builds an `Add` proposal carrying a freshly generated test key package (named "Roger").
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn generate_add() -> Proposal {
+    let add = AddProposal {
+        key_package: test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "Roger").await,
+    };
+
+    Proposal::Add(Box::new(add))
+}
+
+/// Builds a `Remove` proposal targeting leaf `i`.
+#[cfg_attr(coverage_nightly, coverage(off))]
+fn generate_remove(i: u32) -> Proposal {
+    Proposal::Remove(RemoveProposal {
+        to_remove: LeafIndex(i),
+    })
+}
+
+/// Builds an `Update` proposal for member `i`: a copy of its current leaf node is passed
+/// through `LeafNode::update` with that member's signer and default properties.
+///
+/// Panics if member `i` has no signer or no leaf node in `tree`.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn generate_update(i: u32, tree: &TreeWithSigners) -> Proposal {
+    let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+    let signer = tree.signers[i as usize].as_ref().unwrap();
+    let current_leaf = tree.tree.get_leaf_node(LeafIndex(i)).unwrap();
+    let mut leaf_node = current_leaf.clone();
+
+    leaf_node
+        .update(&cs, TEST_GROUP, i, default_properties(), None, signer)
+        .await
+        .unwrap();
+
+    Proposal::Update(UpdateProposal { leaf_node })
+}
diff --git a/src/group/key_schedule.rs b/src/group/key_schedule.rs
new file mode 100644
index 0000000..77c1d65
--- /dev/null
+++ b/src/group/key_schedule.rs
@@ -0,0 +1,988 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::extension::ExternalPubExt;
+use crate::group::{GroupContext, MembershipTag};
+use crate::psk::secret::PskSecret;
+#[cfg(feature = "psk")]
+use crate::psk::PreSharedKey;
+use crate::tree_kem::path_secret::PathSecret;
+use crate::CipherSuiteProvider;
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+use crate::group::SecretTree;
+
+use alloc::vec;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+use zeroize::Zeroizing;
+
+use crate::crypto::{HpkeContextR, HpkeContextS, HpkePublicKey, HpkeSecretKey};
+
+use super::epoch::{EpochSecrets, SenderDataSecret};
+use super::message_signature::AuthenticatedContent;
+
+/// Secrets the MLS key schedule keeps for the current epoch.
+///
+/// Persisted through the `MlsEncode`/`MlsDecode` derives (and `serde` when enabled);
+/// presumably field order is part of those encodings, so do not reorder fields.
+#[derive(Clone, PartialEq, Eq, Default, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct KeySchedule {
+    /// Input to `export_secret`.
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+    exporter_secret: Zeroizing<Vec<u8>>,
+    /// Secret derived with the `"authentication"` label.
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+    pub authentication_secret: Zeroizing<Vec<u8>>,
+    /// Seed for the external HPKE key pair (see `get_external_key_pair`).
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+    external_secret: Zeroizing<Vec<u8>>,
+    /// Key used by `get_membership_tag`.
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+    membership_key: Zeroizing<Vec<u8>>,
+    /// Extracted together with the next commit secret to start the next epoch.
+    init_secret: InitSecret,
+}
+
+impl Debug for KeySchedule {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        use mls_rs_core::debug::pretty_bytes;
+
+        let mut out = f.debug_struct("KeySchedule");
+
+        // Byte secrets go through `pretty_bytes` instead of the default `Vec` formatting.
+        out.field("exporter_secret", &pretty_bytes(&self.exporter_secret));
+        out.field("authentication_secret", &pretty_bytes(&self.authentication_secret));
+        out.field("external_secret", &pretty_bytes(&self.external_secret));
+        out.field("membership_key", &pretty_bytes(&self.membership_key));
+        out.field("init_secret", &self.init_secret);
+
+        out.finish()
+    }
+}
+
+/// Everything produced by one run of the key schedule.
+pub(crate) struct KeyScheduleDerivationResult {
+    /// Secrets handed to the epoch state (sender data, plus feature-gated PSK / secret tree).
+    pub(crate) epoch_secrets: EpochSecrets,
+    /// Key schedule state for the new epoch.
+    pub(crate) key_schedule: KeySchedule,
+    /// Key derived with the `"confirm"` label.
+    pub(crate) confirmation_key: Zeroizing<Vec<u8>>,
+    /// Joiner secret; only `KeySchedule::from_key_schedule` fills it, otherwise it is empty.
+    pub(crate) joiner_secret: JoinerSecret,
+}
+
+impl KeySchedule {
+    /// Creates a key schedule holding only `init_secret`; every other secret is left empty.
+    pub fn new(init_secret: InitSecret) -> Self {
+        Self {
+            init_secret,
+            ..Self::default()
+        }
+    }
+
+    /// Recovers the init secret an external joiner encapsulated in `kem_output`, using this
+    /// epoch's external key pair, and returns a key schedule seeded with it.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn derive_for_external<P: CipherSuiteProvider>(
+        &self,
+        kem_output: &[u8],
+        cipher_suite: &P,
+    ) -> Result<KeySchedule, MlsError> {
+        let (external_secret, external_pub) = self.get_external_key_pair(cipher_suite).await?;
+
+        InitSecret::decode_for_external(cipher_suite, kem_output, &external_secret, &external_pub)
+            .await
+            .map(KeySchedule::new)
+    }
+
+    /// Returns the derived epoch as well as the joiner secret required for building welcome
+    /// messages.
+    ///
+    /// `joiner_secret = ExpandWithLabel(Extract(init_secret, commit_secret), "joiner",
+    /// GroupContext)`; the rest of the derivation is delegated to `from_joiner`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn from_key_schedule<P: CipherSuiteProvider>(
+        last_key_schedule: &KeySchedule,
+        commit_secret: &PathSecret,
+        context: &GroupContext,
+        #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+        secret_tree_size: u32,
+        psk_secret: &PskSecret,
+        cipher_suite_provider: &P,
+    ) -> Result<KeyScheduleDerivationResult, MlsError> {
+        let joiner_seed = cipher_suite_provider
+            .kdf_extract(&last_key_schedule.init_secret.0, commit_secret)
+            .await
+            .map_err(Self::crypto_error)?;
+
+        let joiner_secret: JoinerSecret = kdf_expand_with_label(
+            cipher_suite_provider,
+            &joiner_seed,
+            b"joiner",
+            &context.mls_encode_to_vec()?,
+            None,
+        )
+        .await?
+        .into();
+
+        let derived = Self::from_joiner(
+            cipher_suite_provider,
+            &joiner_secret,
+            context,
+            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+            secret_tree_size,
+            psk_secret,
+        )
+        .await?;
+
+        // `from_joiner` leaves the joiner secret empty; replace it with the real one.
+        Ok(KeyScheduleDerivationResult {
+            joiner_secret,
+            ..derived
+        })
+    }
+
+    /// Derives the epoch secret from `joiner_secret`, `psk_secret` and the encoded `context`,
+    /// then expands it into the full set of epoch secrets.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn from_joiner<P: CipherSuiteProvider>(
+        cipher_suite_provider: &P,
+        joiner_secret: &JoinerSecret,
+        context: &GroupContext,
+        #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+        secret_tree_size: u32,
+        psk_secret: &PskSecret,
+    ) -> Result<KeyScheduleDerivationResult, MlsError> {
+        let epoch_seed =
+            get_pre_epoch_secret(cipher_suite_provider, psk_secret, joiner_secret).await?;
+        let encoded_context = context.mls_encode_to_vec()?;
+
+        // epoch_secret = ExpandWithLabel(epoch_seed, "epoch", GroupContext, default length)
+        let epoch_secret = kdf_expand_with_label(
+            cipher_suite_provider,
+            &epoch_seed,
+            b"epoch",
+            &encoded_context,
+            None,
+        )
+        .await?;
+
+        Self::from_epoch_secret(
+            cipher_suite_provider,
+            &epoch_secret,
+            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+            secret_tree_size,
+        )
+        .await
+    }
+
+    /// Expands a freshly generated random epoch secret (of `kdf_extract_size` bytes).
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn from_random_epoch_secret<P: CipherSuiteProvider>(
+        cipher_suite_provider: &P,
+        #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+        secret_tree_size: u32,
+    ) -> Result<KeyScheduleDerivationResult, MlsError> {
+        let secret_len = cipher_suite_provider.kdf_extract_size();
+
+        let epoch_secret = cipher_suite_provider
+            .random_bytes_vec(secret_len)
+            .map(Zeroizing::new)
+            .map_err(Self::crypto_error)?;
+
+        Self::from_epoch_secret(
+            cipher_suite_provider,
+            &epoch_secret,
+            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+            secret_tree_size,
+        )
+        .await
+    }
+
+    /// Derives every per-epoch secret from `epoch_secret` with `DeriveSecret` and the
+    /// labels below. The returned joiner secret is empty.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn from_epoch_secret<P: CipherSuiteProvider>(
+        cipher_suite_provider: &P,
+        epoch_secret: &[u8],
+        #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+        secret_tree_size: u32,
+    ) -> Result<KeyScheduleDerivationResult, MlsError> {
+        let producer = SecretsProducer::new(cipher_suite_provider, epoch_secret);
+
+        // Secrets handed to the epoch state.
+        let epoch_secrets = EpochSecrets {
+            #[cfg(feature = "psk")]
+            resumption_secret: PreSharedKey::from(producer.derive(b"resumption").await?),
+            sender_data_secret: SenderDataSecret::from(producer.derive(b"sender data").await?),
+            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+            secret_tree: SecretTree::new(secret_tree_size, producer.derive(b"encryption").await?),
+        };
+
+        // Secrets kept by the key schedule itself; `init` seeds the next epoch.
+        let key_schedule = KeySchedule {
+            exporter_secret: producer.derive(b"exporter").await?,
+            authentication_secret: producer.derive(b"authentication").await?,
+            external_secret: producer.derive(b"external").await?,
+            membership_key: producer.derive(b"membership").await?,
+            init_secret: InitSecret(producer.derive(b"init").await?),
+        };
+
+        let confirmation_key = producer.derive(b"confirm").await?;
+
+        Ok(KeyScheduleDerivationResult {
+            key_schedule,
+            confirmation_key,
+            joiner_secret: JoinerSecret::from(Zeroizing::new(Vec::new())),
+            epoch_secrets,
+        })
+    }
+
+    /// MLS exporter: `ExpandWithLabel(DeriveSecret(exporter_secret, label), "exported",
+    /// Hash(context), len)`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn export_secret<P: CipherSuiteProvider>(
+        &self,
+        label: &[u8],
+        context: &[u8],
+        len: usize,
+        cipher_suite: &P,
+    ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+        let derived = kdf_derive_secret(cipher_suite, &self.exporter_secret, label).await?;
+        let hashed_context = cipher_suite.hash(context).await.map_err(Self::crypto_error)?;
+
+        kdf_expand_with_label(cipher_suite, &derived, b"exported", &hashed_context, Some(len))
+            .await
+    }
+
+    /// Computes the membership tag of `content` in `context`, keyed by this epoch's
+    /// membership key.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn get_membership_tag<P: CipherSuiteProvider>(
+        &self,
+        content: &AuthenticatedContent,
+        context: &GroupContext,
+        cipher_suite_provider: &P,
+    ) -> Result<MembershipTag, MlsError> {
+        let key: &[u8] = &self.membership_key;
+
+        MembershipTag::create(content, context, key, cipher_suite_provider).await
+    }
+
+    /// Derives this epoch's external HPKE key pair from `external_secret` via `kem_derive`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn get_external_key_pair<P: CipherSuiteProvider>(
+        &self,
+        cipher_suite: &P,
+    ) -> Result<(HpkeSecretKey, HpkePublicKey), MlsError> {
+        let derived = cipher_suite.kem_derive(&self.external_secret).await;
+
+        derived.map_err(Self::crypto_error)
+    }
+
+    /// Wraps the public half of the external key pair in an `ExternalPubExt`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn get_external_key_pair_ext<P: CipherSuiteProvider>(
+        &self,
+        cipher_suite: &P,
+    ) -> Result<ExternalPubExt, MlsError> {
+        self.get_external_key_pair(cipher_suite)
+            .await
+            .map(|(_, external_pub)| ExternalPubExt { external_pub })
+    }
+
+    /// Wraps a crypto provider failure in `MlsError::CryptoProviderError`.
+    fn crypto_error<E: IntoAnyError>(err: E) -> MlsError {
+        MlsError::CryptoProviderError(err.into_any_error())
+    }
+}
+
+/// The structure `kdf_expand_with_label` MLS-encodes and passes to `kdf_expand` as its info
+/// input.
+#[derive(MlsEncode, MlsSize)]
+struct Label<'a> {
+    /// Requested output length in bytes.
+    length: u16,
+    /// `"MLS 1.0 "` followed by the caller-supplied label.
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    label: Vec<u8>,
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    context: &'a [u8],
+}
+
+impl<'a> Label<'a> {
+    /// Builds a label, prefixing `label` with `"MLS 1.0 "`.
+    fn new(length: u16, label: &'a [u8], context: &'a [u8]) -> Self {
+        let mut prefixed = b"MLS 1.0 ".to_vec();
+        prefixed.extend_from_slice(label);
+
+        Self {
+            length,
+            label: prefixed,
+            context,
+        }
+    }
+}
+
+/// `ExpandWithLabel`: expands `secret` with the encoded `Label { len, "MLS 1.0 " + label,
+/// context }` as the KDF info.
+///
+/// `len` defaults to the cipher suite's `kdf_extract_size()` when `None`.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn kdf_expand_with_label<P: CipherSuiteProvider>(
+    cipher_suite_provider: &P,
+    secret: &[u8],
+    label: &[u8],
+    context: &[u8],
+    len: Option<usize>,
+) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+    let output_len = len.unwrap_or(cipher_suite_provider.kdf_extract_size());
+
+    // NOTE(review): `as u16` silently truncates lengths above u16::MAX, since the label's
+    // length field is a u16; callers are presumably expected to stay within that range.
+    let kdf_label = Label::new(output_len as u16, label, context).mls_encode_to_vec()?;
+
+    cipher_suite_provider
+        .kdf_expand(secret, &kdf_label, output_len)
+        .await
+        .map_err(|err| MlsError::CryptoProviderError(err.into_any_error()))
+}
+
+/// `DeriveSecret`: `kdf_expand_with_label` with an empty context and the default
+/// (`kdf_extract_size`) output length.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn kdf_derive_secret<P: CipherSuiteProvider>(
+    cipher_suite_provider: &P,
+    secret: &[u8],
+    label: &[u8],
+) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+    let empty_context: &[u8] = &[];
+    let default_len = None;
+
+    kdf_expand_with_label(cipher_suite_provider, secret, label, empty_context, default_len).await
+}
+
+/// Joiner secret of an epoch, zeroized on drop.
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+pub(crate) struct JoinerSecret(#[mls_codec(with = "mls_rs_codec::byte_vec")] Zeroizing<Vec<u8>>);
+
+impl Debug for JoinerSecret {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Render the bytes through `pretty_bytes`, labelled with the type name.
+        Debug::fmt(
+            &mls_rs_core::debug::pretty_bytes(&self.0).named("JoinerSecret"),
+            f,
+        )
+    }
+}
+
+impl From<Zeroizing<Vec<u8>>> for JoinerSecret {
+    fn from(secret: Zeroizing<Vec<u8>>) -> Self {
+        JoinerSecret(secret)
+    }
+}
+
+/// Runs `kdf_extract` over `joiner_secret` and `psk_secret`; the output is the shared seed
+/// for both the epoch secret and the welcome secret.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn get_pre_epoch_secret<P: CipherSuiteProvider>(
+    cipher_suite_provider: &P,
+    psk_secret: &PskSecret,
+    joiner_secret: &JoinerSecret,
+) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+    let extracted = cipher_suite_provider
+        .kdf_extract(&joiner_secret.0, psk_secret)
+        .await;
+
+    extracted.map_err(|err| MlsError::CryptoProviderError(err.into_any_error()))
+}
+
+/// Derives labelled secrets (`DeriveSecret`) from a single epoch secret.
+struct SecretsProducer<'a, P: CipherSuiteProvider> {
+    cipher_suite_provider: &'a P,
+    epoch_secret: &'a [u8],
+}
+
+impl<'a, P: CipherSuiteProvider> SecretsProducer<'a, P> {
+    fn new(cipher_suite_provider: &'a P, epoch_secret: &'a [u8]) -> Self {
+        Self {
+            epoch_secret,
+            cipher_suite_provider,
+        }
+    }
+
+    // TODO document somewhere in the crypto provider that the RFC defines the length of all
+    // secrets as KDF extract size but then inputs secrets as MAC keys etc, therefore, we
+    // require that these lengths match in the crypto provider
+    /// Returns `DeriveSecret(epoch_secret, label)`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn derive(&self, label: &[u8]) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+        let Self {
+            cipher_suite_provider,
+            epoch_secret,
+        } = self;
+
+        kdf_derive_secret(*cipher_suite_provider, epoch_secret, label).await
+    }
+}
+
+/// HPKE exporter label used to derive the init secret of an external commit
+/// (RFC 9420 external initialization).
+const EXPORTER_CONTEXT: &[u8] = b"MLS 1.0 external init secret";
+
+/// Init secret that seeds the next epoch of the key schedule, zeroized on drop.
+#[derive(Clone, Eq, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct InitSecret(
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+    Zeroizing<Vec<u8>>,
+);
+
+impl Debug for InitSecret {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Render the bytes through `pretty_bytes`, labelled with the type name.
+        Debug::fmt(
+            &mls_rs_core::debug::pretty_bytes(&self.0).named("InitSecret"),
+            f,
+        )
+    }
+}
+
+impl InitSecret {
+    /// Returns init secret and KEM output to be used when creating an external commit.
+    ///
+    /// Sets up an HPKE sender context to `external_pub` (empty info) and exports
+    /// `kdf_extract_size` bytes under [`EXPORTER_CONTEXT`].
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn encode_for_external<P: CipherSuiteProvider>(
+        cipher_suite: &P,
+        external_pub: &HpkePublicKey,
+    ) -> Result<(Self, Vec<u8>), MlsError> {
+        let (kem_output, hpke_context) = cipher_suite
+            .hpke_setup_s(external_pub, &[])
+            .await
+            .map_err(Self::provider_error)?;
+
+        let exported = hpke_context
+            .export(EXPORTER_CONTEXT, cipher_suite.kdf_extract_size())
+            .await
+            .map_err(Self::provider_error)?;
+
+        Ok((Self(Zeroizing::new(exported)), kem_output))
+    }
+
+    /// Receiver side of `encode_for_external`: opens `kem_output` with the external key pair
+    /// and exports the same init secret.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn decode_for_external<P: CipherSuiteProvider>(
+        cipher_suite: &P,
+        kem_output: &[u8],
+        external_secret: &HpkeSecretKey,
+        external_pub: &HpkePublicKey,
+    ) -> Result<Self, MlsError> {
+        let hpke_context = cipher_suite
+            .hpke_setup_r(kem_output, external_secret, external_pub, &[])
+            .await
+            .map_err(Self::provider_error)?;
+
+        let exported = hpke_context
+            .export(EXPORTER_CONTEXT, cipher_suite.kdf_extract_size())
+            .await
+            .map_err(Self::provider_error)?;
+
+        Ok(Self(Zeroizing::new(exported)))
+    }
+
+    /// Wraps a crypto provider / HPKE context failure in `MlsError::CryptoProviderError`.
+    fn provider_error<E: IntoAnyError>(err: E) -> MlsError {
+        MlsError::CryptoProviderError(err.into_any_error())
+    }
+}
+
+/// AEAD key and nonce derived from the welcome secret, used to seal and open Welcome
+/// payloads (presumably the encrypted group info; confirm at call sites).
+pub(crate) struct WelcomeSecret<'a, P: CipherSuiteProvider> {
+    cipher_suite: &'a P,
+    /// `ExpandWithLabel(welcome_secret, "key", "", aead_key_size)`.
+    key: Zeroizing<Vec<u8>>,
+    /// `ExpandWithLabel(welcome_secret, "nonce", "", aead_nonce_size)`.
+    nonce: Zeroizing<Vec<u8>>,
+}
+
+impl<'a, P: CipherSuiteProvider> WelcomeSecret<'a, P> {
+    /// Derives the welcome key and nonce from `joiner_secret` and `psk_secret`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn from_joiner_secret(
+        cipher_suite: &'a P,
+        joiner_secret: &JoinerSecret,
+        psk_secret: &PskSecret,
+    ) -> Result<WelcomeSecret<'a, P>, MlsError> {
+        let welcome_secret = get_welcome_secret(cipher_suite, joiner_secret, psk_secret).await?;
+        let no_context: &[u8] = &[];
+
+        let key = kdf_expand_with_label(
+            cipher_suite,
+            &welcome_secret,
+            b"key",
+            no_context,
+            Some(cipher_suite.aead_key_size()),
+        )
+        .await?;
+
+        let nonce = kdf_expand_with_label(
+            cipher_suite,
+            &welcome_secret,
+            b"nonce",
+            no_context,
+            Some(cipher_suite.aead_nonce_size()),
+        )
+        .await?;
+
+        Ok(WelcomeSecret {
+            cipher_suite,
+            key,
+            nonce,
+        })
+    }
+
+    /// Seals `plaintext` with the welcome key and nonce, passing `None` as the
+    /// associated-data argument.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, MlsError> {
+        let sealed = self
+            .cipher_suite
+            .aead_seal(&self.key, plaintext, None, &self.nonce)
+            .await;
+
+        sealed.map_err(|err| MlsError::CryptoProviderError(err.into_any_error()))
+    }
+
+    /// Opens `ciphertext` with the welcome key and nonce, passing `None` as the
+    /// associated-data argument.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn decrypt(&self, ciphertext: &[u8]) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+        let opened = self
+            .cipher_suite
+            .aead_open(&self.key, ciphertext, None, &self.nonce)
+            .await;
+
+        opened.map_err(|err| MlsError::CryptoProviderError(err.into_any_error()))
+    }
+}
+
+/// `DeriveSecret(epoch_seed, "welcome")`, where the seed comes from `get_pre_epoch_secret`.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn get_welcome_secret<P: CipherSuiteProvider>(
+    cipher_suite: &P,
+    joiner_secret: &JoinerSecret,
+    psk_secret: &PskSecret,
+) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+    let seed = get_pre_epoch_secret(cipher_suite, psk_secret, joiner_secret).await?;
+    let welcome = kdf_derive_secret(cipher_suite, &seed, b"welcome").await?;
+
+    Ok(welcome)
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use alloc::vec;
+    use alloc::vec::Vec;
+    use mls_rs_core::crypto::CipherSuiteProvider;
+    use zeroize::Zeroizing;
+
+    use crate::{cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider};
+
+    use super::{InitSecret, JoinerSecret, KeySchedule};
+
+    #[cfg(all(feature = "rfc_compliant", not(mls_build_async)))]
+    use mls_rs_core::error::IntoAnyError;
+
+    #[cfg(all(feature = "rfc_compliant", not(mls_build_async)))]
+    use super::MlsError;
+
+    /// Test-only conversion exposing the raw joiner secret bytes.
+    impl From<JoinerSecret> for Vec<u8> {
+        fn from(mut secret: JoinerSecret) -> Self {
+            // Move the bytes out of the `Zeroizing` wrapper, leaving an empty vector behind.
+            core::mem::take(&mut *secret.0)
+        }
+    }
+
+    /// Builds a key schedule for `cipher_suite` filled with fixed bytes: all ones for the
+    /// derived secrets and all zeroes for the init secret.
+    pub(crate) fn get_test_key_schedule(cipher_suite: CipherSuite) -> KeySchedule {
+        let secret_len = test_cipher_suite_provider(cipher_suite).kdf_extract_size();
+        let ones = || Zeroizing::new(vec![1u8; secret_len]);
+
+        KeySchedule {
+            exporter_secret: ones(),
+            authentication_secret: ones(),
+            external_secret: ones(),
+            membership_key: ones(),
+            init_secret: InitSecret::new(vec![0u8; secret_len]),
+        }
+    }
+
+    impl InitSecret {
+        /// Wraps raw bytes as an init secret.
+        pub fn new(init_secret: Vec<u8>) -> Self {
+            Self(Zeroizing::new(init_secret))
+        }
+
+        /// Generates `kdf_extract_size` random bytes as an init secret.
+        #[cfg(all(feature = "rfc_compliant", test, not(mls_build_async)))]
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        pub fn random<P: CipherSuiteProvider>(cipher_suite: &P) -> Result<Self, MlsError> {
+            let bytes = cipher_suite
+                .random_bytes_vec(cipher_suite.kdf_extract_size())
+                .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+            Ok(Self(Zeroizing::new(bytes)))
+        }
+    }
+
+    #[cfg(feature = "rfc_compliant")]
+    impl KeySchedule {
+        /// Overwrites the membership key (test helper).
+        pub fn set_membership_key(&mut self, key: Vec<u8>) {
+            self.membership_key = Zeroizing::new(key);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::client::test_utils::TEST_PROTOCOL_VERSION;
+ use crate::crypto::test_utils::try_test_cipher_suite_provider;
+ use crate::group::key_schedule::{
+ get_welcome_secret, kdf_derive_secret, kdf_expand_with_label,
+ };
+ use crate::group::GroupContext;
+ use alloc::string::String;
+ use alloc::vec::Vec;
+ use mls_rs_codec::MlsEncode;
+ use mls_rs_core::crypto::CipherSuiteProvider;
+ use mls_rs_core::extension::ExtensionList;
+
+ #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
+ use crate::{
+ crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
+ group::{
+ key_schedule::KeyScheduleDerivationResult, test_utils::random_bytes, InitSecret,
+ PskSecret,
+ },
+ };
+
+ #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
+ use alloc::{string::ToString, vec};
+
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+ use zeroize::Zeroizing;
+
+ use super::test_utils::get_test_key_schedule;
+ use super::KeySchedule;
+
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct TestCase {
+ cipher_suite: u16,
+ #[serde(with = "hex::serde")]
+ group_id: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ initial_init_secret: Vec<u8>,
+ epochs: Vec<KeyScheduleEpoch>,
+ }
+
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct KeyScheduleEpoch {
+ #[serde(with = "hex::serde")]
+ commit_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ psk_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ confirmed_transcript_hash: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ tree_hash: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ group_context: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ joiner_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ welcome_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ init_secret: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ sender_data_secret: Vec<u8>,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ #[serde(with = "hex::serde")]
+ encryption_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ exporter_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ epoch_authenticator: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ external_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ confirmation_key: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ membership_key: Vec<u8>,
+ #[cfg(feature = "psk")]
+ #[serde(with = "hex::serde")]
+ resumption_psk: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ external_pub: Vec<u8>,
+
+ exporter: KeyScheduleExporter,
+ }
+
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct KeyScheduleExporter {
+ label: String,
+ #[serde(with = "hex::serde")]
+ context: Vec<u8>,
+ length: usize,
+ #[serde(with = "hex::serde")]
+ secret: Vec<u8>,
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_key_schedule() {
+ let test_cases: Vec<TestCase> =
+ load_test_case_json!(key_schedule_test_vector, generate_test_vector());
+
+ for test_case in test_cases {
+ let Some(cs_provider) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
+ continue;
+ };
+
+ let mut key_schedule = get_test_key_schedule(cs_provider.cipher_suite());
+ key_schedule.init_secret.0 = Zeroizing::new(test_case.initial_init_secret);
+
+ for (i, epoch) in test_case.epochs.into_iter().enumerate() {
+ let context = GroupContext {
+ protocol_version: TEST_PROTOCOL_VERSION,
+ cipher_suite: cs_provider.cipher_suite(),
+ group_id: test_case.group_id.clone(),
+ epoch: i as u64,
+ tree_hash: epoch.tree_hash,
+ confirmed_transcript_hash: epoch.confirmed_transcript_hash.into(),
+ extensions: ExtensionList::new(),
+ };
+
+ assert_eq!(context.mls_encode_to_vec().unwrap(), epoch.group_context);
+
+ let psk = epoch.psk_secret.into();
+ let commit = epoch.commit_secret.into();
+
+ let key_schedule_res = KeySchedule::from_key_schedule(
+ &key_schedule,
+ &commit,
+ &context,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ 32,
+ &psk,
+ &cs_provider,
+ )
+ .await
+ .unwrap();
+
+ key_schedule = key_schedule_res.key_schedule;
+
+ let welcome =
+ get_welcome_secret(&cs_provider, &key_schedule_res.joiner_secret, &psk)
+ .await
+ .unwrap();
+
+ assert_eq!(*welcome, epoch.welcome_secret);
+
+ let expected: Vec<u8> = key_schedule_res.joiner_secret.into();
+ assert_eq!(epoch.joiner_secret, expected);
+
+ assert_eq!(&key_schedule.init_secret.0.to_vec(), &epoch.init_secret);
+
+ assert_eq!(
+ epoch.sender_data_secret,
+ *key_schedule_res.epoch_secrets.sender_data_secret.to_vec()
+ );
+
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ assert_eq!(
+ epoch.encryption_secret,
+ *key_schedule_res.epoch_secrets.secret_tree.get_root_secret()
+ );
+
+ assert_eq!(epoch.exporter_secret, key_schedule.exporter_secret.to_vec());
+
+ assert_eq!(
+ epoch.epoch_authenticator,
+ key_schedule.authentication_secret.to_vec()
+ );
+
+ assert_eq!(epoch.external_secret, key_schedule.external_secret.to_vec());
+
+ assert_eq!(
+ epoch.confirmation_key,
+ key_schedule_res.confirmation_key.to_vec()
+ );
+
+ assert_eq!(epoch.membership_key, key_schedule.membership_key.to_vec());
+
+ #[cfg(feature = "psk")]
+ {
+ let expected: Vec<u8> =
+ key_schedule_res.epoch_secrets.resumption_secret.to_vec();
+
+ assert_eq!(epoch.resumption_psk, expected);
+ }
+
+ let (_external_sec, external_pub) = key_schedule
+ .get_external_key_pair(&cs_provider)
+ .await
+ .unwrap();
+
+ assert_eq!(epoch.external_pub, *external_pub);
+
+ let exp = epoch.exporter;
+
+ let exported = key_schedule
+ .export_secret(exp.label.as_bytes(), &exp.context, exp.length, &cs_provider)
+ .await
+ .unwrap();
+
+ assert_eq!(exported.to_vec(), exp.secret);
+ }
+ }
+ }
+
+ #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_test_vector() -> Vec<TestCase> {
+ let mut test_cases = vec![];
+
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let cs_provider = test_cipher_suite_provider(cipher_suite);
+ let key_size = cs_provider.kdf_extract_size();
+
+ let mut group_context = GroupContext {
+ protocol_version: TEST_PROTOCOL_VERSION,
+ cipher_suite: cs_provider.cipher_suite(),
+ group_id: b"my group 5".to_vec(),
+ epoch: 0,
+ tree_hash: random_bytes(key_size),
+ confirmed_transcript_hash: random_bytes(key_size).into(),
+ extensions: Default::default(),
+ };
+
+ let initial_init_secret = InitSecret::random(&cs_provider).unwrap();
+ let mut key_schedule = get_test_key_schedule(cs_provider.cipher_suite());
+ key_schedule.init_secret = initial_init_secret.clone();
+
+ let commit_secret = random_bytes(key_size).into();
+ let psk_secret = PskSecret::new(&cs_provider);
+
+ let key_schedule_res = KeySchedule::from_key_schedule(
+ &key_schedule,
+ &commit_secret,
+ &group_context,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ 32,
+ &psk_secret,
+ &cs_provider,
+ )
+ .unwrap();
+
+ key_schedule = key_schedule_res.key_schedule.clone();
+
+ let epoch1 = KeyScheduleEpoch::new(
+ key_schedule_res,
+ psk_secret,
+ commit_secret.to_vec(),
+ &group_context,
+ &cs_provider,
+ );
+
+ group_context.epoch += 1;
+ group_context.confirmed_transcript_hash = random_bytes(key_size).into();
+ group_context.tree_hash = random_bytes(key_size);
+
+ let commit_secret = random_bytes(key_size).into();
+ let psk_secret = PskSecret::new(&cs_provider);
+
+ let key_schedule_res = KeySchedule::from_key_schedule(
+ &key_schedule,
+ &commit_secret,
+ &group_context,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ 32,
+ &psk_secret,
+ &cs_provider,
+ )
+ .unwrap();
+
+ let epoch2 = KeyScheduleEpoch::new(
+ key_schedule_res,
+ psk_secret,
+ commit_secret.to_vec(),
+ &group_context,
+ &cs_provider,
+ );
+
+ let test_case = TestCase {
+ cipher_suite: cs_provider.cipher_suite().into(),
+ group_id: group_context.group_id.clone(),
+ initial_init_secret: initial_init_secret.0.to_vec(),
+ epochs: vec![epoch1, epoch2],
+ };
+
+ test_cases.push(test_case);
+ }
+
+ test_cases
+ }
+
+    /// Stand-in for the key-schedule test-vector generator in builds where it
+    /// does not exist: the real generator (and `KeyScheduleEpoch::new`) is only
+    /// compiled when `mls_build_async` is off *and* `rfc_compliant` is enabled.
+    ///
+    /// # Panics
+    ///
+    /// Always panics. The message names both conditions, because this stub is
+    /// also selected in synchronous builds without `rfc_compliant`.
+    #[cfg(not(all(not(mls_build_async), feature = "rfc_compliant")))]
+    fn generate_test_vector() -> Vec<TestCase> {
+        panic!("Tests cannot be generated in async mode or without the rfc_compliant feature");
+    }
+
+    #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
+    impl KeyScheduleEpoch {
+        /// Build the expected values for one epoch of the key-schedule test
+        /// vector from a freshly derived key schedule.
+        ///
+        /// Every derivation is unwrapped: this is only used to generate test
+        /// vectors, where a failure is a bug.
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        fn new<P: CipherSuiteProvider>(
+            key_schedule_res: KeyScheduleDerivationResult,
+            psk_secret: PskSecret,
+            commit_secret: Vec<u8>,
+            group_context: &GroupContext,
+            cs: &P,
+        ) -> Self {
+            let schedule = &key_schedule_res.key_schedule;
+
+            let (_external_sec, external_pub) = schedule.get_external_key_pair(cs).unwrap();
+
+            // Fixed exporter inputs; the expected exporter output is derived from them.
+            let label = "exporter label 15".to_string();
+            let context = b"exporter context".to_vec();
+            let length = 64;
+
+            let secret = schedule
+                .export_secret(label.as_bytes(), &context, length, cs)
+                .unwrap()
+                .to_vec();
+
+            let exporter = KeyScheduleExporter {
+                label,
+                context,
+                length,
+                secret,
+            };
+
+            let welcome_secret =
+                get_welcome_secret(cs, &key_schedule_res.joiner_secret, &psk_secret)
+                    .unwrap()
+                    .to_vec();
+
+            KeyScheduleEpoch {
+                commit_secret,
+                welcome_secret,
+                psk_secret: psk_secret.to_vec(),
+                group_context: group_context.mls_encode_to_vec().unwrap(),
+                joiner_secret: key_schedule_res.joiner_secret.into(),
+                init_secret: schedule.init_secret.0.to_vec(),
+                sender_data_secret: key_schedule_res.epoch_secrets.sender_data_secret.to_vec(),
+                #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+                encryption_secret: key_schedule_res.epoch_secrets.secret_tree.get_root_secret(),
+                exporter_secret: schedule.exporter_secret.to_vec(),
+                epoch_authenticator: schedule.authentication_secret.to_vec(),
+                external_secret: schedule.external_secret.to_vec(),
+                confirmation_key: key_schedule_res.confirmation_key.to_vec(),
+                membership_key: schedule.membership_key.to_vec(),
+                #[cfg(feature = "psk")]
+                resumption_psk: key_schedule_res.epoch_secrets.resumption_secret.to_vec(),
+                external_pub: external_pub.to_vec(),
+                exporter,
+                confirmed_transcript_hash: group_context.confirmed_transcript_hash.to_vec(),
+                tree_hash: group_context.tree_hash.clone(),
+            }
+        }
+    }
+
+    /// One `expand_with_label` case from the `crypto-basics` interop test vector.
+    #[derive(Debug, serde::Serialize, serde::Deserialize)]
+    struct ExpandWithLabelTestCase {
+        /// Input secret (hex-encoded in the JSON).
+        #[serde(with = "hex::serde")]
+        secret: Vec<u8>,
+        /// Label passed to `kdf_expand_with_label`.
+        label: String,
+        /// Context passed to `kdf_expand_with_label` (hex-encoded in the JSON).
+        #[serde(with = "hex::serde")]
+        context: Vec<u8>,
+        /// Requested output length.
+        length: usize,
+        /// Expected output (hex-encoded in the JSON).
+        #[serde(with = "hex::serde")]
+        out: Vec<u8>,
+    }
+
+    /// One `derive_secret` case from the `crypto-basics` interop test vector.
+    #[derive(Debug, serde::Serialize, serde::Deserialize)]
+    struct DeriveSecretTestCase {
+        /// Input secret (hex-encoded in the JSON).
+        #[serde(with = "hex::serde")]
+        secret: Vec<u8>,
+        /// Label passed to `kdf_derive_secret`.
+        label: String,
+        /// Expected output (hex-encoded in the JSON).
+        #[serde(with = "hex::serde")]
+        out: Vec<u8>,
+    }
+
+    /// A per-cipher-suite entry of the `crypto-basics` interop test vector.
+    /// Field names mirror the JSON keys and must not be renamed.
+    #[derive(Debug, serde::Serialize, serde::Deserialize)]
+    pub struct InteropTestCase {
+        /// Numeric cipher suite identifier.
+        cipher_suite: u16,
+        expand_with_label: ExpandWithLabelTestCase,
+        derive_secret: DeriveSecretTestCase,
+    }
+
+    /// Checks `kdf_expand_with_label` and `kdf_derive_secret` against the
+    /// `crypto-basics` interop vector, skipping unsupported cipher suites.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_basic_crypto_test_vectors() {
+        // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/crypto-basics.json
+        let test_cases: Vec<InteropTestCase> =
+            load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());
+
+        for case in test_cases {
+            let Some(cs) = try_test_cipher_suite_provider(case.cipher_suite) else {
+                continue;
+            };
+
+            let expand = &case.expand_with_label;
+
+            let expanded = kdf_expand_with_label(
+                &cs,
+                &expand.secret,
+                expand.label.as_bytes(),
+                &expand.context,
+                Some(expand.length),
+            )
+            .await
+            .unwrap();
+
+            assert_eq!(&expanded.to_vec(), &expand.out);
+
+            let derive = &case.derive_secret;
+
+            let derived = kdf_derive_secret(&cs, &derive.secret, derive.label.as_bytes())
+                .await
+                .unwrap();
+
+            assert_eq!(&derived.to_vec(), &derive.out);
+        }
+    }
+}
diff --git a/src/group/membership_tag.rs b/src/group/membership_tag.rs
new file mode 100644
index 0000000..b28edea
--- /dev/null
+++ b/src/group/membership_tag.rs
@@ -0,0 +1,163 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::crypto::CipherSuiteProvider;
+use crate::group::message_signature::{AuthenticatedContentTBS, FramedContentAuthData};
+use crate::group::GroupContext;
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+
+use super::message_signature::AuthenticatedContent;
+
+/// Input to the membership-tag MAC: the to-be-signed form of the content
+/// (bound to a group context) together with the content's authentication data.
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)]
+struct AuthenticatedContentTBM<'a> {
+    content_tbs: AuthenticatedContentTBS<'a>,
+    auth: &'a FramedContentAuthData,
+}
+
+impl<'a> AuthenticatedContentTBM<'a> {
+    /// Borrow `auth_content` into a to-be-MACed view bound to `group_context`,
+    /// using the context's protocol version.
+    pub fn from_authenticated_content(
+        auth_content: &'a AuthenticatedContent,
+        group_context: &'a GroupContext,
+    ) -> AuthenticatedContentTBM<'a> {
+        let content_tbs = AuthenticatedContentTBS::from_authenticated_content(
+            auth_content,
+            Some(group_context),
+            group_context.protocol_version,
+        );
+
+        Self {
+            content_tbs,
+            auth: &auth_content.auth,
+        }
+    }
+}
+
+/// MAC computed over an `AuthenticatedContentTBM` with a membership key (see
+/// `MembershipTag::create`). Encoded on the wire as an MLS byte vector.
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub struct MembershipTag(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>);
+
+impl Debug for MembershipTag {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Render the raw tag bytes with the shared pretty-printer, labelled with the type name.
+        let pretty = mls_rs_core::debug::pretty_bytes(&self.0).named("MembershipTag");
+        pretty.fmt(f)
+    }
+}
+
+/// Exposes the raw tag bytes.
+impl Deref for MembershipTag {
+    type Target = Vec<u8>;
+
+    fn deref(&self) -> &Vec<u8> {
+        let MembershipTag(bytes) = self;
+        bytes
+    }
+}
+
+/// Wraps already-computed tag bytes without validation.
+impl From<Vec<u8>> for MembershipTag {
+    fn from(bytes: Vec<u8>) -> Self {
+        MembershipTag(bytes)
+    }
+}
+
+impl MembershipTag {
+    /// Compute the membership tag for `authenticated_content`.
+    ///
+    /// The content and `group_context` are combined into an
+    /// `AuthenticatedContentTBM`, MLS-encoded, and MACed with `membership_key`
+    /// through `cipher_suite_provider`.
+    ///
+    /// # Errors
+    ///
+    /// Propagates encoding errors. A MAC failure is reported as
+    /// `MlsError::CryptoProviderError`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn create<P: CipherSuiteProvider>(
+        authenticated_content: &AuthenticatedContent,
+        group_context: &GroupContext,
+        membership_key: &[u8],
+        cipher_suite_provider: &P,
+    ) -> Result<Self, MlsError> {
+        let tbm_bytes =
+            AuthenticatedContentTBM::from_authenticated_content(authenticated_content, group_context)
+                .mls_encode_to_vec()?;
+
+        cipher_suite_provider
+            .mac(membership_key, &tbm_bytes)
+            .await
+            .map(MembershipTag)
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider};
+    use crate::group::{
+        framing::test_utils::get_test_auth_content, test_utils::get_test_group_context,
+    };
+
+    // Only the synchronous test-vector generator enumerates every supported suite.
+    #[cfg(not(mls_build_async))]
+    use crate::crypto::test_utils::TestCryptoProvider;
+
+    #[cfg(target_arch = "wasm32")]
+    use wasm_bindgen_test::wasm_bindgen_test as test;
+
+    /// One membership-tag vector entry: the expected tag for a cipher suite.
+    #[derive(Debug, serde::Serialize, serde::Deserialize)]
+    struct TestCase {
+        cipher_suite: u16,
+        #[serde(with = "hex::serde")]
+        tag: Vec<u8>,
+    }
+
+    /// Generate expected tags for every supported cipher suite, using the test
+    /// auth content, `get_test_group_context(1, ..)` and a constant key.
+    #[cfg(not(mls_build_async))]
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn generate_test_cases() -> Vec<TestCase> {
+        let mut test_cases = Vec::new();
+
+        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+            // Synchronous build only, so the helpers are called without `.await`.
+            let tag = MembershipTag::create(
+                &get_test_auth_content(),
+                &get_test_group_context(1, cipher_suite),
+                b"membership_key".as_ref(),
+                &test_cipher_suite_provider(cipher_suite),
+            )
+            .unwrap();
+
+            test_cases.push(TestCase {
+                cipher_suite: cipher_suite.into(),
+                tag: tag.to_vec(),
+            });
+        }
+
+        test_cases
+    }
+
+    /// Async builds cannot generate vectors; they must use the stored JSON.
+    #[cfg(mls_build_async)]
+    fn generate_test_cases() -> Vec<TestCase> {
+        panic!("Tests cannot be generated in async mode");
+    }
+
+    /// Load the stored `membership_tag` vectors. `generate_test_cases` is the
+    /// fallback argument; presumably it is used only when no stored file exists.
+    fn load_test_cases() -> Vec<TestCase> {
+        load_test_case_json!(membership_tag, generate_test_cases())
+    }
+
+    /// Recompute each stored tag and compare. Suites the test provider does
+    /// not support are skipped.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_membership_tag() {
+        for case in load_test_cases() {
+            let Some(cs_provider) = try_test_cipher_suite_provider(case.cipher_suite) else {
+                continue;
+            };
+
+            let tag = MembershipTag::create(
+                &get_test_auth_content(),
+                &get_test_group_context(1, cs_provider.cipher_suite()).await,
+                b"membership_key".as_ref(),
+                &test_cipher_suite_provider(cs_provider.cipher_suite()),
+            )
+            .await
+            .unwrap();
+
+            assert_eq!(**tag, case.tag);
+        }
+    }
+}
diff --git a/src/group/message_processor.rs b/src/group/message_processor.rs
new file mode 100644
index 0000000..8084a58
--- /dev/null
+++ b/src/group/message_processor.rs
@@ -0,0 +1,1039 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::{
+ commit_sender,
+ confirmation_tag::ConfirmationTag,
+ framing::{
+ ApplicationData, Content, ContentType, MlsMessage, MlsMessagePayload, PublicMessage, Sender,
+ },
+ message_signature::AuthenticatedContent,
+ mls_rules::{CommitDirection, MlsRules},
+ proposal_filter::ProposalBundle,
+ state::GroupState,
+ transcript_hash::InterimTranscriptHash,
+ transcript_hashes, validate_group_info_member, GroupContext, GroupInfo, Welcome,
+};
+use crate::{
+ client::MlsError,
+ key_package::validate_key_package_properties,
+ time::MlsTime,
+ tree_kem::{
+ leaf_node_validator::{LeafNodeValidator, ValidationContext},
+ node::LeafIndex,
+ path_secret::PathSecret,
+ validate_update_path, TreeKemPrivate, TreeKemPublic, ValidatedUpdatePath,
+ },
+ CipherSuiteProvider, KeyPackage,
+};
+#[cfg(mls_build_async)]
+use alloc::boxed::Box;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_core::{
+ identity::IdentityProvider, protocol_version::ProtocolVersion, psk::PreSharedKeyStorage,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use super::proposal_ref::ProposalRef;
+
+#[cfg(not(feature = "by_ref_proposal"))]
+use crate::group::proposal_cache::resolve_for_commit;
+
+#[cfg(feature = "by_ref_proposal")]
+use super::proposal::Proposal;
+
+#[cfg(feature = "custom_proposal")]
+use super::proposal_filter::ProposalInfo;
+
+#[cfg(feature = "state_update")]
+use mls_rs_core::{
+ crypto::CipherSuite,
+ group::{MemberUpdate, RosterUpdate},
+};
+
+#[cfg(all(feature = "state_update", feature = "psk"))]
+use mls_rs_core::psk::ExternalPskId;
+
+#[cfg(feature = "state_update")]
+use crate::tree_kem::UpdatePath;
+
+#[cfg(feature = "state_update")]
+use super::{member_from_key_package, member_from_leaf_node};
+
+#[cfg(all(feature = "state_update", feature = "custom_proposal"))]
+use super::proposal::CustomProposal;
+
+#[cfg(feature = "private_message")]
+use crate::group::framing::PrivateMessage;
+
+#[cfg(feature = "by_ref_proposal")]
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+/// Group state produced by applying a commit's proposals, before the new
+/// epoch's key schedule has been computed.
+#[derive(Debug)]
+pub(crate) struct ProvisionalState {
+    /// Ratchet tree after the proposals have been applied.
+    pub(crate) public_tree: TreeKemPublic,
+    /// Proposals that were actually applied.
+    pub(crate) applied_proposals: ProposalBundle,
+    /// Group context for the new epoch. Commit processing fills in the
+    /// transcript and tree hashes afterwards.
+    pub(crate) group_context: GroupContext,
+    /// NOTE(review): presumably the leaf taken by an external joiner when the
+    /// commit contains an ExternalInit proposal; confirm against `apply_resolved`.
+    pub(crate) external_init_index: Option<LeafIndex>,
+    /// Leaf indexes assigned to added key packages, in the same order as
+    /// `applied_proposals.additions`.
+    pub(crate) indexes_of_added_kpkgs: Vec<LeafIndex>,
+    /// Proposals received in the prior epoch but not covered by the commit.
+    #[cfg(feature = "by_ref_proposal")]
+    pub(crate) unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
+}
+
+/// Returns `true` when a commit covering `proposals` must carry an
+/// `UpdatePath`.
+///
+/// By default the `path` field of a Commit MUST be populated. It MAY be
+/// omitted only if (a) the commit covers at least one proposal and (b) none of
+/// the covered proposals is of a "path required" type. Among the proposal types
+/// defined by the MLS specification, only `add`, `psk` and `reinit` do not
+/// require a path. Here that means an empty commit, or any external init,
+/// update, group-context-extensions or remove proposal, forces a path.
+pub(crate) fn path_update_required(proposals: &ProposalBundle) -> bool {
+    // An empty commit always needs a path.
+    if proposals.length() == 0 {
+        return true;
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    {
+        if !proposals.update_proposals().is_empty() {
+            return true;
+        }
+    }
+
+    !proposals.external_init_proposals().is_empty()
+        || proposals.group_context_extensions_proposal().is_some()
+        || !proposals.remove_proposals().is_empty()
+}
+
+/// Representation of changes made by a [commit](crate::Group::commit).
+#[cfg(feature = "state_update")]
+#[derive(Clone, Debug, PartialEq)]
+pub struct StateUpdate {
+    /// Members added, removed and updated by the commit.
+    pub(crate) roster_update: RosterUpdate,
+    /// External PSK ids from the PSK proposals applied by the commit.
+    #[cfg(feature = "psk")]
+    pub(crate) added_psks: Vec<ExternalPskId>,
+    /// New cipher suite requested by a committed ReInit proposal, if any.
+    pub(crate) pending_reinit: Option<CipherSuite>,
+    /// Cleared when the processing member cannot continue in the group
+    /// (e.g. it was removed) or when a ReInit was committed.
+    pub(crate) active: bool,
+    /// Epoch of the group context produced by the commit.
+    pub(crate) epoch: u64,
+    /// Custom proposals covered by the commit.
+    #[cfg(feature = "custom_proposal")]
+    pub(crate) custom_proposals: Vec<ProposalInfo<CustomProposal>>,
+    /// Proposals received in the prior epoch but not committed to.
+    #[cfg(feature = "by_ref_proposal")]
+    pub(crate) unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
+}
+
+/// Empty placeholder used when the `state_update` feature is disabled; it
+/// carries no information about the commit.
+#[cfg(not(feature = "state_update"))]
+#[non_exhaustive]
+#[derive(Clone, Debug, PartialEq)]
+pub struct StateUpdate {}
+
+#[cfg(feature = "state_update")]
+impl StateUpdate {
+    /// Roster changes (added, removed and updated members) caused by the commit.
+    pub fn roster_update(&self) -> &RosterUpdate {
+        &self.roster_update
+    }
+
+    /// External pre-shared keys that the commit added to the group.
+    #[cfg(feature = "psk")]
+    pub fn added_psks(&self) -> &[ExternalPskId] {
+        self.added_psks.as_slice()
+    }
+
+    /// `true` if the group is now waiting to be reinitialized because the
+    /// commit included a [`ReInit`](crate::group::proposal::Proposal::ReInit)
+    /// proposal.
+    pub fn is_pending_reinit(&self) -> bool {
+        matches!(self.pending_reinit, Some(_))
+    }
+
+    /// `true` while the group remains active. This is `false` when the member
+    /// processing the commit has been removed from the group.
+    pub fn is_active(&self) -> bool {
+        self.active
+    }
+
+    /// Epoch number of the group after the commit.
+    pub fn new_epoch(&self) -> u64 {
+        self.epoch
+    }
+
+    /// Custom proposals covered by the commit.
+    #[cfg(feature = "custom_proposal")]
+    pub fn custom_proposals(&self) -> &[ProposalInfo<CustomProposal>] {
+        self.custom_proposals.as_slice()
+    }
+
+    /// Proposals received during the previous epoch that the commit did not include.
+    #[cfg(feature = "by_ref_proposal")]
+    pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
+        self.unused_proposals.as_slice()
+    }
+
+    /// Cipher suite requested by the committed ReInit proposal, or `None` if
+    /// no reinitialization is pending.
+    pub fn pending_reinit_ciphersuite(&self) -> Option<CipherSuite> {
+        self.pending_reinit
+    }
+}
+
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Debug, Clone)]
+#[allow(clippy::large_enum_variant)]
+/// An event generated as a result of processing a message for a group with
+/// [`Group::process_incoming_message`](crate::group::Group::process_incoming_message).
+///
+/// Group content (application data, commits, proposals) is processed.
+/// `GroupInfo`, `Welcome` and `KeyPackage` messages are only validated and
+/// then reported back.
+pub enum ReceivedMessage {
+    /// An application message was decrypted.
+    ApplicationMessage(ApplicationMessageDescription),
+    /// A new commit was processed creating a new group state.
+    Commit(CommitMessageDescription),
+    /// A proposal was received.
+    Proposal(ProposalMessageDescription),
+    /// Validated GroupInfo object
+    GroupInfo(GroupInfo),
+    /// A welcome message was validated; the welcome itself is not retained.
+    Welcome,
+    /// Validated key package
+    KeyPackage(KeyPackage),
+}
+
+/// Wraps a decrypted application message. This conversion never fails; it is
+/// expressed as `TryFrom` to match the `MessageProcessor::OutputType` bound.
+impl TryFrom<ApplicationMessageDescription> for ReceivedMessage {
+    type Error = MlsError;
+
+    fn try_from(description: ApplicationMessageDescription) -> Result<Self, Self::Error> {
+        let message = Self::ApplicationMessage(description);
+        Ok(message)
+    }
+}
+
+impl From<CommitMessageDescription> for ReceivedMessage {
+    fn from(description: CommitMessageDescription) -> Self {
+        Self::Commit(description)
+    }
+}
+
+impl From<ProposalMessageDescription> for ReceivedMessage {
+    fn from(description: ProposalMessageDescription) -> Self {
+        Self::Proposal(description)
+    }
+}
+
+impl From<GroupInfo> for ReceivedMessage {
+    fn from(group_info: GroupInfo) -> Self {
+        Self::GroupInfo(group_info)
+    }
+}
+
+/// The welcome payload is dropped; only the fact that it validated is reported.
+impl From<Welcome> for ReceivedMessage {
+    fn from(_welcome: Welcome) -> Self {
+        Self::Welcome
+    }
+}
+
+impl From<KeyPackage> for ReceivedMessage {
+    fn from(key_package: KeyPackage) -> Self {
+        Self::KeyPackage(key_package)
+    }
+}
+
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, PartialEq, Eq)]
+/// Description of an MLS application message.
+pub struct ApplicationMessageDescription {
+    /// Leaf index, in the group state, of the member that sent this message.
+    pub sender_index: u32,
+    /// Received application data (exposed through `data()`).
+    data: ApplicationData,
+    /// Plaintext authenticated data in the received MLS packet.
+    pub authenticated_data: Vec<u8>,
+}
+
+impl Debug for ApplicationMessageDescription {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Authenticated data is opaque bytes; show it with the shared byte pretty-printer.
+        let authenticated_data = mls_rs_core::debug::pretty_bytes(&self.authenticated_data);
+
+        f.debug_struct("ApplicationMessageDescription")
+            .field("sender_index", &self.sender_index)
+            .field("data", &self.data)
+            .field("authenticated_data", &authenticated_data)
+            .finish()
+    }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl ApplicationMessageDescription {
+    /// Received application data as a byte slice.
+    pub fn data(&self) -> &[u8] {
+        self.data.as_bytes()
+    }
+}
+
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, PartialEq)]
+#[non_exhaustive]
+/// Description of a processed MLS commit message.
+pub struct CommitMessageDescription {
+    /// True if this is the result of an external commit (sender is a new
+    /// member committing itself into the group).
+    pub is_external: bool,
+    /// The leaf index in the group state of the member who performed this commit.
+    pub committer: u32,
+    /// A full description of group state changes as a result of this commit.
+    pub state_update: StateUpdate,
+    /// Plaintext authenticated data in the received MLS packet.
+    pub authenticated_data: Vec<u8>,
+}
+
+impl Debug for CommitMessageDescription {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Authenticated data is opaque bytes; show it with the shared byte pretty-printer.
+        let authenticated_data = mls_rs_core::debug::pretty_bytes(&self.authenticated_data);
+
+        f.debug_struct("CommitMessageDescription")
+            .field("is_external", &self.is_external)
+            .field("committer", &self.committer)
+            .field("state_update", &self.state_update)
+            .field("authenticated_data", &authenticated_data)
+            .finish()
+    }
+}
+
+/// Proposal sender type.
+///
+/// Unlike the framing-level `Sender`, this has no variant for a new member's
+/// external commit, because such a sender cannot send proposals.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ProposalSender {
+    /// A current member of the group by index in the group state.
+    Member(u32),
+    /// An external entity by index within an
+    /// [`ExternalSendersExt`](crate::extension::built_in::ExternalSendersExt).
+    External(u32),
+    /// A new member proposing their addition to the group.
+    NewMember,
+}
+
+/// Maps a framing-level sender to a proposal sender.
+///
+/// # Errors
+///
+/// `MlsError::InvalidSender` for `Sender::NewMemberCommit`.
+impl TryFrom<Sender> for ProposalSender {
+    type Error = MlsError;
+
+    fn try_from(value: Sender) -> Result<Self, Self::Error> {
+        let converted = match value {
+            Sender::Member(index) => Self::Member(index),
+            #[cfg(feature = "by_ref_proposal")]
+            Sender::External(index) => Self::External(index),
+            #[cfg(feature = "by_ref_proposal")]
+            Sender::NewMemberProposal => Self::NewMember,
+            // A new member joining by external commit is never a proposal sender.
+            Sender::NewMemberCommit => return Err(MlsError::InvalidSender),
+        };
+
+        Ok(converted)
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone)]
+#[non_exhaustive]
+/// Description of a processed MLS proposal message.
+pub struct ProposalMessageDescription {
+    /// Sender of the proposal.
+    pub sender: ProposalSender,
+    /// Proposal content.
+    pub proposal: Proposal,
+    /// Plaintext authenticated data in the received MLS packet.
+    pub authenticated_data: Vec<u8>,
+    /// Proposal reference, computed from the authenticated content.
+    pub proposal_ref: ProposalRef,
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl Debug for ProposalMessageDescription {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Authenticated data is opaque bytes; show it with the shared byte pretty-printer.
+        let authenticated_data = mls_rs_core::debug::pretty_bytes(&self.authenticated_data);
+
+        f.debug_struct("ProposalMessageDescription")
+            .field("sender", &self.sender)
+            .field("proposal", &self.proposal)
+            .field("authenticated_data", &authenticated_data)
+            .field("proposal_ref", &self.proposal_ref)
+            .finish()
+    }
+}
+
+/// A received proposal in a self-contained, MLS-encodable form that can be
+/// stored and restored (see `to_bytes` / `from_bytes`).
+#[cfg(feature = "by_ref_proposal")]
+#[derive(MlsSize, MlsEncode, MlsDecode)]
+pub struct CachedProposal {
+    /// The proposal itself.
+    pub(crate) proposal: Proposal,
+    /// Reference identifying the proposal.
+    pub(crate) proposal_ref: ProposalRef,
+    /// Framing-level sender of the proposal.
+    pub(crate) sender: Sender,
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl CachedProposal {
+    /// Deserialize a cached proposal from its MLS encoding.
+    ///
+    /// # Errors
+    ///
+    /// Returns the decoding error converted into `MlsError`.
+    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
+        let mut reader = bytes;
+        Self::mls_decode(&mut reader).map_err(MlsError::from)
+    }
+
+    /// Serialize this cached proposal with the MLS encoding.
+    ///
+    /// # Errors
+    ///
+    /// Returns the encoding error converted into `MlsError`.
+    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
+        self.mls_encode_to_vec().map_err(MlsError::from)
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl ProposalMessageDescription {
+    /// Convert into a [`CachedProposal`], dropping the authenticated data and
+    /// mapping the proposal sender back to a framing-level `Sender`.
+    pub fn cached_proposal(self) -> CachedProposal {
+        let Self {
+            sender,
+            proposal,
+            proposal_ref,
+            ..
+        } = self;
+
+        let sender = match sender {
+            ProposalSender::Member(index) => Sender::Member(index),
+            ProposalSender::External(index) => Sender::External(index),
+            ProposalSender::NewMember => Sender::NewMemberProposal,
+        };
+
+        CachedProposal {
+            proposal,
+            proposal_ref,
+            sender,
+        }
+    }
+
+    /// The proposal reference as an owned byte vector.
+    pub fn proposal_ref(&self) -> Vec<u8> {
+        self.proposal_ref.to_vec()
+    }
+}
+
+/// Empty placeholder used when the `by_ref_proposal` feature is disabled.
+#[cfg(not(feature = "by_ref_proposal"))]
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Debug, Clone)]
+/// Description of a processed MLS proposal message.
+pub struct ProposalMessageDescription {}
+
+/// Intermediate result of reading an incoming message: either a finished
+/// event (e.g. a validated GroupInfo, Welcome or KeyPackage) or authenticated
+/// content that still has to be processed.
+#[allow(clippy::large_enum_variant)]
+pub(crate) enum EventOrContent<E> {
+    #[cfg_attr(
+        not(all(feature = "private_message", feature = "external_client")),
+        allow(dead_code)
+    )]
+    Event(E),
+    Content(AuthenticatedContent),
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
+#[cfg_attr(
+ all(not(target_arch = "wasm32"), mls_build_async),
+ maybe_async::must_be_async
+)]
+pub(crate) trait MessageProcessor: Send + Sync {
+ type OutputType: TryFrom<ApplicationMessageDescription, Error = MlsError>
+ + From<CommitMessageDescription>
+ + From<ProposalMessageDescription>
+ + From<GroupInfo>
+ + From<Welcome>
+ + From<KeyPackage>
+ + Send;
+
+ type MlsRules: MlsRules;
+ type IdentityProvider: IdentityProvider;
+ type CipherSuiteProvider: CipherSuiteProvider;
+ type PreSharedKeyStorage: PreSharedKeyStorage;
+
+    /// Process an incoming message with no sender-supplied timestamp.
+    ///
+    /// Same as `process_incoming_message_with_time` with `time_sent = None`.
+    async fn process_incoming_message(
+        &mut self,
+        message: MlsMessage,
+        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
+    ) -> Result<Self::OutputType, MlsError> {
+        let time_sent = None;
+
+        self.process_incoming_message_with_time(
+            message,
+            #[cfg(feature = "by_ref_proposal")]
+            cache_proposal,
+            time_sent,
+        )
+        .await
+    }
+
+    /// Process an incoming message, optionally with the time it was sent.
+    ///
+    /// The message is first authenticated, decrypted or validated, and the
+    /// result is then turned into an output event.
+    async fn process_incoming_message_with_time(
+        &mut self,
+        message: MlsMessage,
+        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
+        time_sent: Option<MlsTime>,
+    ) -> Result<Self::OutputType, MlsError> {
+        let incoming = self.get_event_from_incoming_message(message).await?;
+
+        self.process_event_or_content(
+            incoming,
+            #[cfg(feature = "by_ref_proposal")]
+            cache_proposal,
+            time_sent,
+        )
+        .await
+    }
+
+    /// Check the message metadata against the group state, then dispatch on
+    /// the payload kind:
+    ///
+    /// * plaintext is authenticated and ciphertext is decrypted, yielding content;
+    /// * GroupInfo, Welcome and KeyPackage payloads are validated and returned
+    ///   directly as events.
+    async fn get_event_from_incoming_message(
+        &mut self,
+        message: MlsMessage,
+    ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+        self.check_metadata(&message)?;
+
+        let version = message.version;
+
+        match message.payload {
+            MlsMessagePayload::Plain(plaintext) => {
+                self.verify_plaintext_authentication(plaintext).await
+            }
+            #[cfg(feature = "private_message")]
+            MlsMessagePayload::Cipher(cipher_text) => self.process_ciphertext(&cipher_text).await,
+            MlsMessagePayload::GroupInfo(group_info) => {
+                validate_group_info_member(
+                    self.group_state(),
+                    version,
+                    &group_info,
+                    self.cipher_suite_provider(),
+                )
+                .await?;
+
+                Ok(EventOrContent::Event(group_info.into()))
+            }
+            MlsMessagePayload::Welcome(welcome) => {
+                self.validate_welcome(&welcome, version)?;
+
+                Ok(EventOrContent::Event(welcome.into()))
+            }
+            MlsMessagePayload::KeyPackage(key_package) => {
+                self.validate_key_package(&key_package, version).await?;
+
+                Ok(EventOrContent::Event(key_package.into()))
+            }
+        }
+    }
+
+    /// Pass finished events through unchanged, and process authenticated
+    /// content into an output event.
+    async fn process_event_or_content(
+        &mut self,
+        event_or_content: EventOrContent<Self::OutputType>,
+        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
+        time_sent: Option<MlsTime>,
+    ) -> Result<Self::OutputType, MlsError> {
+        match event_or_content {
+            EventOrContent::Event(event) => Ok(event),
+            EventOrContent::Content(content) => {
+                self.process_auth_content(
+                    content,
+                    #[cfg(feature = "by_ref_proposal")]
+                    cache_proposal,
+                    time_sent,
+                )
+                .await
+            }
+        }
+    }
+
+    /// Route authenticated content to the application, commit or proposal
+    /// handler and convert the handler's description into `Self::OutputType`.
+    async fn process_auth_content(
+        &mut self,
+        auth_content: AuthenticatedContent,
+        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
+        time_sent: Option<MlsTime>,
+    ) -> Result<Self::OutputType, MlsError> {
+        match auth_content.content.content {
+            #[cfg(feature = "private_message")]
+            Content::Application(data) => self
+                .process_application_message(
+                    data,
+                    auth_content.content.sender,
+                    auth_content.content.authenticated_data,
+                )
+                .and_then(Self::OutputType::try_from),
+            Content::Commit(_) => self
+                .process_commit(auth_content, time_sent)
+                .await
+                .map(Self::OutputType::from),
+            #[cfg(feature = "by_ref_proposal")]
+            Content::Proposal(ref proposal) => self
+                .process_proposal(&auth_content, proposal, cache_proposal)
+                .await
+                .map(Self::OutputType::from),
+        }
+    }
+
+    /// Describe an application message.
+    ///
+    /// # Errors
+    ///
+    /// `MlsError::InvalidSender` unless the sender is a group member.
+    #[cfg(feature = "private_message")]
+    fn process_application_message(
+        &self,
+        data: ApplicationData,
+        sender: Sender,
+        authenticated_data: Vec<u8>,
+    ) -> Result<ApplicationMessageDescription, MlsError> {
+        match sender {
+            Sender::Member(sender_index) => Ok(ApplicationMessageDescription {
+                sender_index,
+                data,
+                authenticated_data,
+            }),
+            _ => Err(MlsError::InvalidSender),
+        }
+    }
+
+    /// Describe a received proposal and, if `cache_proposal` is set, store it
+    /// in the group's proposal cache under its reference.
+    ///
+    /// # Errors
+    ///
+    /// Propagates failures computing the proposal reference. Returns
+    /// `MlsError::InvalidSender` if the sender cannot send proposals; in that
+    /// case nothing is cached.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn process_proposal(
+        &mut self,
+        auth_content: &AuthenticatedContent,
+        proposal: &Proposal,
+        cache_proposal: bool,
+    ) -> Result<ProposalMessageDescription, MlsError> {
+        let proposal_ref =
+            ProposalRef::from_content(self.cipher_suite_provider(), auth_content).await?;
+
+        // Convert the sender before touching the cache, so a rejected sender
+        // does not leave a stale entry behind in the group state.
+        let sender: ProposalSender = auth_content.content.sender.try_into()?;
+
+        if cache_proposal {
+            self.group_state_mut().proposals.insert(
+                proposal_ref.clone(),
+                proposal.clone(),
+                auth_content.content.sender,
+            );
+        }
+
+        Ok(ProposalMessageDescription {
+            authenticated_data: auth_content.content.authenticated_data.clone(),
+            proposal: proposal.clone(),
+            sender,
+            proposal_ref,
+        })
+    }
+
+    /// Summarize how a commit changes the group: added, removed and updated
+    /// members, added external PSKs, any pending ReInit, the new epoch, and
+    /// (feature-gated) custom and unused proposals.
+    ///
+    /// Removed and previously-existing members are read from the current
+    /// (pre-commit) tree; updated members are read from the provisional tree.
+    #[cfg(feature = "state_update")]
+    async fn make_state_update(
+        &self,
+        provisional: &ProvisionalState,
+        path: Option<&UpdatePath>,
+        sender: LeafIndex,
+    ) -> Result<StateUpdate, MlsError> {
+        let applied = &provisional.applied_proposals;
+        let old_tree = &self.group_state().public_tree;
+
+        // Each Add proposal is paired with the leaf its key package was placed in.
+        let mut added = applied
+            .additions
+            .iter()
+            .zip(&provisional.indexes_of_added_kpkgs)
+            .map(|(add, leaf)| member_from_key_package(&add.proposal.key_package, *leaf))
+            .collect::<Vec<_>>();
+
+        let removed = applied
+            .removals
+            .iter()
+            .map(|remove| {
+                let leaf = remove.proposal.to_remove;
+                let node = old_tree.nodes.borrow_as_leaf(leaf)?;
+                Ok(member_from_leaf_node(node, leaf))
+            })
+            .collect::<Result<_, MlsError>>()?;
+
+        #[cfg(feature = "by_ref_proposal")]
+        let mut updated = applied
+            .update_senders
+            .iter()
+            .map(|leaf| {
+                let prior = old_tree
+                    .get_leaf_node(*leaf)
+                    .map(|node| member_from_leaf_node(node, *leaf))?;
+
+                let new = provisional
+                    .public_tree
+                    .get_leaf_node(*leaf)
+                    .map(|node| member_from_leaf_node(node, *leaf))?;
+
+                Ok::<_, MlsError>(MemberUpdate::new(prior, new))
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+
+        #[cfg(not(feature = "by_ref_proposal"))]
+        let mut updated = Vec::new();
+
+        // The committer's new leaf node is either a brand-new member (external
+        // commit) or an update of the committer's existing leaf.
+        if let Some(path) = path {
+            let committer = member_from_leaf_node(&path.leaf_node, sender);
+
+            if applied.external_initializations.is_empty() {
+                let prior = old_tree
+                    .get_leaf_node(sender)
+                    .map(|node| member_from_leaf_node(node, sender))?;
+
+                updated.push(MemberUpdate::new(prior, committer))
+            } else {
+                added.push(committer)
+            }
+        }
+
+        #[cfg(feature = "psk")]
+        let psks = applied
+            .psks
+            .iter()
+            .filter_map(|psk| psk.proposal.external_psk_id().cloned())
+            .collect::<Vec<_>>();
+
+        Ok(StateUpdate {
+            roster_update: RosterUpdate::new(added, removed, updated),
+            #[cfg(feature = "psk")]
+            added_psks: psks,
+            pending_reinit: applied
+                .reinitializations
+                .first()
+                .map(|reinit| reinit.proposal.new_cipher_suite()),
+            active: true,
+            epoch: provisional.group_context.epoch,
+            #[cfg(feature = "custom_proposal")]
+            custom_proposals: applied.custom_proposals.clone(),
+            #[cfg(feature = "by_ref_proposal")]
+            unused_proposals: provisional.unused_proposals.clone(),
+        })
+    }
+
+    /// Validate and apply a commit, advancing the group to the next epoch.
+    ///
+    /// If `can_continue_processing` returns `false` (for example because this
+    /// member was removed), a description marked inactive is returned early
+    /// and the key schedule is not updated.
+    ///
+    /// A committed ReInit proposal is recorded on the group state only after
+    /// `update_key_schedule`, which receives the confirmation tag, has
+    /// succeeded. A rejected commit therefore never leaves the group stuck in
+    /// a pending-reinit state.
+    ///
+    /// # Errors
+    ///
+    /// * `MlsError::GroupUsedAfterReInit` if a ReInit is already pending.
+    /// * `MlsError::CommitMissingPath` if the applied proposals require an
+    ///   update path and the commit carries none.
+    /// * `MlsError::InvalidConfirmationTag` if the commit has no confirmation tag.
+    /// * Errors from proposal resolution, path validation and
+    ///   `update_key_schedule` are propagated.
+    async fn process_commit(
+        &mut self,
+        auth_content: AuthenticatedContent,
+        time_sent: Option<MlsTime>,
+    ) -> Result<CommitMessageDescription, MlsError> {
+        if self.group_state().pending_reinit.is_some() {
+            return Err(MlsError::GroupUsedAfterReInit);
+        }
+
+        // Update the new GroupContext's confirmed and interim transcript hashes using the new Commit.
+        let (interim_transcript_hash, confirmed_transcript_hash) = transcript_hashes(
+            self.cipher_suite_provider(),
+            &self.group_state().interim_transcript_hash,
+            &auth_content,
+        )
+        .await?;
+
+        #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
+        let commit = match auth_content.content.content {
+            Content::Commit(commit) => Ok(commit),
+            _ => Err(MlsError::UnexpectedMessageType),
+        }?;
+
+        #[cfg(not(any(feature = "private_message", feature = "by_ref_proposal")))]
+        let Content::Commit(commit) = auth_content.content.content;
+
+        let group_state = self.group_state();
+        let id_provider = self.identity_provider();
+
+        #[cfg(feature = "by_ref_proposal")]
+        let proposals = group_state
+            .proposals
+            .resolve_for_commit(auth_content.content.sender, commit.proposals)?;
+
+        #[cfg(not(feature = "by_ref_proposal"))]
+        let proposals = resolve_for_commit(auth_content.content.sender, commit.proposals)?;
+
+        let mut provisional_state = group_state
+            .apply_resolved(
+                auth_content.content.sender,
+                proposals,
+                commit.path.as_ref().map(|path| &path.leaf_node),
+                &id_provider,
+                self.cipher_suite_provider(),
+                &self.psk_storage(),
+                &self.mls_rules(),
+                time_sent,
+                CommitDirection::Receive,
+            )
+            .await?;
+
+        let sender = commit_sender(&auth_content.content.sender, &provisional_state)?;
+
+        #[cfg(feature = "state_update")]
+        let mut state_update = self
+            .make_state_update(&provisional_state, commit.path.as_ref(), sender)
+            .await?;
+
+        #[cfg(not(feature = "state_update"))]
+        let state_update = StateUpdate {};
+
+        // Verify that the path value is populated if the proposals vector contains any Update
+        // or Remove proposals, or if it's empty. Otherwise, the path value MAY be omitted.
+        if path_update_required(&provisional_state.applied_proposals) && commit.path.is_none() {
+            return Err(MlsError::CommitMissingPath);
+        }
+
+        if !self.can_continue_processing(&provisional_state) {
+            #[cfg(feature = "state_update")]
+            {
+                state_update.active = false;
+            }
+
+            return Ok(CommitMessageDescription {
+                is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
+                authenticated_data: auth_content.content.authenticated_data,
+                committer: *sender,
+                state_update,
+            });
+        }
+
+        let update_path = match commit.path {
+            Some(update_path) => Some(
+                validate_update_path(
+                    &self.identity_provider(),
+                    self.cipher_suite_provider(),
+                    update_path,
+                    &provisional_state,
+                    sender,
+                    time_sent,
+                )
+                .await?,
+            ),
+            None => None,
+        };
+
+        let new_secrets = match update_path {
+            Some(update_path) => {
+                self.apply_update_path(sender, &update_path, &mut provisional_state)
+                    .await
+            }
+            None => Ok(None),
+        }?;
+
+        // Update the transcript hash to get the new context.
+        provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
+
+        // Update the parent hashes in the new context
+        provisional_state
+            .public_tree
+            .update_hashes(&[sender], self.cipher_suite_provider())
+            .await?;
+
+        // Update the tree hash in the new context
+        provisional_state.group_context.tree_hash = provisional_state
+            .public_tree
+            .tree_hash(self.cipher_suite_provider())
+            .await?;
+
+        // Detach any committed ReInit now, because `update_key_schedule`
+        // consumes `provisional_state`. It is written to the group state only
+        // after the key schedule update succeeds.
+        let pending_reinit = provisional_state
+            .applied_proposals
+            .reinitializations
+            .pop()
+            .map(|reinit| reinit.proposal);
+
+        #[cfg(feature = "state_update")]
+        {
+            if pending_reinit.is_some() {
+                state_update.active = false;
+            }
+        }
+
+        let Some(confirmation_tag) = &auth_content.auth.confirmation_tag else {
+            return Err(MlsError::InvalidConfirmationTag);
+        };
+
+        // Update the key schedule to calculate new private keys
+        self.update_key_schedule(
+            new_secrets,
+            interim_transcript_hash,
+            confirmation_tag,
+            provisional_state,
+        )
+        .await?;
+
+        if let Some(reinit) = pending_reinit {
+            self.group_state_mut().pending_reinit = Some(reinit);
+        }
+
+        Ok(CommitMessageDescription {
+            is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
+            authenticated_data: auth_content.content.authenticated_data,
+            committer: *sender,
+            state_update,
+        })
+    }
+
+    /// Read-only access to the current group state (context, tree, ...).
+    fn group_state(&self) -> &GroupState;
+    /// Mutable access to the current group state; commit processing uses it
+    /// e.g. to record a pending reinit proposal.
+    fn group_state_mut(&mut self) -> &mut GroupState;
+    /// Leaf index of the local member. Presumably `None` when the processor
+    /// is not itself a group member — confirm with implementors.
+    #[cfg(feature = "private_message")]
+    fn self_index(&self) -> Option<LeafIndex>;
+    /// Rules handed to proposal resolution when applying a commit.
+    fn mls_rules(&self) -> Self::MlsRules;
+    /// Identity provider used to validate key packages and update paths.
+    fn identity_provider(&self) -> Self::IdentityProvider;
+    /// Cipher suite provider for hashing, signature and key operations.
+    fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider;
+    /// Pre-shared key storage consulted when applying PSK proposals.
+    fn psk_storage(&self) -> Self::PreSharedKeyStorage;
+    /// Whether commit processing should continue past proposal application.
+    /// When this returns `false`, the commit-processing default returns early
+    /// with an inactive state update instead of applying the update path.
+    fn can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool;
+
+    /// Oldest epoch for which application messages are still accepted;
+    /// `None` disables the lower bound in `check_metadata`.
+    #[cfg(feature = "private_message")]
+    fn min_epoch_available(&self) -> Option<u64>;
+
+    /// Checks an incoming message's framing metadata against the current
+    /// group context, before any cryptographic processing.
+    ///
+    /// The protocol version must always match. For plaintext (and, with
+    /// `private_message`, ciphertext) payloads the group id must also match:
+    /// - commits and by-reference proposals must be for the current epoch;
+    /// - application messages must not predate `min_epoch_available`;
+    /// - application messages must be encrypted.
+    ///
+    /// Other payload kinds only get the version check.
+    ///
+    /// # Errors
+    /// `ProtocolVersionMismatch`, `GroupIdMismatch`, `InvalidEpoch` or
+    /// `UnencryptedApplicationMessage`.
+    fn check_metadata(&self, message: &MlsMessage) -> Result<(), MlsError> {
+        let context = &self.group_state().context;
+
+        if message.version != context.protocol_version {
+            return Err(MlsError::ProtocolVersionMismatch);
+        }
+
+        if let Some((group_id, epoch, content_type)) = match &message.payload {
+            MlsMessagePayload::Plain(plaintext) => Some((
+                &plaintext.content.group_id,
+                plaintext.content.epoch,
+                plaintext.content.content_type(),
+            )),
+            #[cfg(feature = "private_message")]
+            MlsMessagePayload::Cipher(ciphertext) => Some((
+                &ciphertext.group_id,
+                ciphertext.epoch,
+                ciphertext.content_type,
+            )),
+            _ => None,
+        } {
+            if group_id != &context.group_id {
+                return Err(MlsError::GroupIdMismatch);
+            }
+
+            // Single place where epochs are validated: proposal and commit
+            // messages must be sent in the current epoch, while application
+            // messages only need an epoch whose secrets are still available.
+            match content_type {
+                ContentType::Commit => {
+                    if context.epoch == epoch {
+                        Ok(())
+                    } else {
+                        Err(MlsError::InvalidEpoch)
+                    }
+                }
+                #[cfg(feature = "by_ref_proposal")]
+                ContentType::Proposal => {
+                    if context.epoch == epoch {
+                        Ok(())
+                    } else {
+                        Err(MlsError::InvalidEpoch)
+                    }
+                }
+                #[cfg(feature = "private_message")]
+                ContentType::Application => match self.min_epoch_available() {
+                    Some(min) if epoch < min => Err(MlsError::InvalidEpoch),
+                    _ => Ok(()),
+                },
+            }?;
+
+            // Unencrypted application messages are not allowed
+            #[cfg(feature = "private_message")]
+            if !matches!(&message.payload, MlsMessagePayload::Cipher(_))
+                && content_type == ContentType::Application
+            {
+                return Err(MlsError::UnencryptedApplicationMessage);
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Accepts a Welcome only if it uses this group's cipher suite and
+    /// `version` equals the group's protocol version.
+    ///
+    /// # Errors
+    /// `InvalidWelcomeMessage` when either value differs.
+    fn validate_welcome(
+        &self,
+        welcome: &Welcome,
+        version: ProtocolVersion,
+    ) -> Result<(), MlsError> {
+        let context = &self.group_state().context;
+
+        if welcome.cipher_suite == context.cipher_suite && version == context.protocol_version {
+            Ok(())
+        } else {
+            Err(MlsError::InvalidWelcomeMessage)
+        }
+    }
+
+    /// Validates `key_package` for `version` with this group's cipher suite
+    /// and identity providers by delegating to the free
+    /// `validate_key_package` function.
+    async fn validate_key_package(
+        &self,
+        key_package: &KeyPackage,
+        version: ProtocolVersion,
+    ) -> Result<(), MlsError> {
+        validate_key_package(
+            key_package,
+            version,
+            self.cipher_suite_provider(),
+            &self.identity_provider(),
+        )
+        .await
+    }
+
+    /// Required: processes a received `PrivateMessage` (presumably decrypting
+    /// it first — confirm with implementors) into an event or content.
+    #[cfg(feature = "private_message")]
+    async fn process_ciphertext(
+        &mut self,
+        cipher_text: &PrivateMessage,
+    ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
+
+    /// Required: authenticates a received `PublicMessage` and turns it into
+    /// an event or content. NOTE(review): implementors presumably delegate
+    /// to `message_verifier::verify_plaintext_authentication` — confirm.
+    async fn verify_plaintext_authentication(
+        &self,
+        message: PublicMessage,
+    ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
+
+    /// Default implementation: applies the validated update path sent by
+    /// `sender` to the provisional public tree only.
+    ///
+    /// No private path secrets are derived here, so success always yields
+    /// `Ok(None)`. Implementors holding private key material presumably
+    /// override this to return `(TreeKemPrivate, PathSecret)` — confirm.
+    async fn apply_update_path(
+        &mut self,
+        sender: LeafIndex,
+        update_path: &ValidatedUpdatePath,
+        provisional_state: &mut ProvisionalState,
+    ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
+        let extensions = &provisional_state.group_context.extensions;
+
+        provisional_state
+            .public_tree
+            .apply_update_path(
+                sender,
+                update_path,
+                extensions,
+                self.identity_provider(),
+                self.cipher_suite_provider(),
+            )
+            .await?;
+
+        Ok(None)
+    }
+
+    /// Required: advances the key schedule after a commit has been applied,
+    /// consuming the provisional state.
+    ///
+    /// `secrets` is the optional private result of `apply_update_path`.
+    /// `confirmation_tag` is the tag carried by the commit. NOTE(review):
+    /// implementors presumably verify this tag against the new epoch — confirm.
+    async fn update_key_schedule(
+        &mut self,
+        secrets: Option<(TreeKemPrivate, PathSecret)>,
+        interim_transcript_hash: InterimTranscriptHash,
+        confirmation_tag: &ConfirmationTag,
+        provisional_public_state: ProvisionalState,
+    ) -> Result<(), MlsError>;
+}
+
+/// Validates a key package for protocol `version`.
+///
+/// The leaf node is checked by a `LeafNodeValidator` in the `Add` context.
+/// With `std`, the current time is supplied; otherwise no time is given.
+/// The key-package-level properties are then validated.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn validate_key_package<C: CipherSuiteProvider, I: IdentityProvider>(
+    key_package: &KeyPackage,
+    version: ProtocolVersion,
+    cs: &C,
+    id: &I,
+) -> Result<(), MlsError> {
+    #[cfg(feature = "std")]
+    let now = Some(MlsTime::now());
+
+    #[cfg(not(feature = "std"))]
+    let now = None;
+
+    LeafNodeValidator::new(cs, id, None)
+        .check_if_valid(&key_package.leaf_node, ValidationContext::Add(now))
+        .await?;
+
+    validate_key_package_properties(key_package, version, cs).await?;
+
+    Ok(())
+}
diff --git a/src/group/message_signature.rs b/src/group/message_signature.rs
new file mode 100644
index 0000000..3c08935
--- /dev/null
+++ b/src/group/message_signature.rs
@@ -0,0 +1,274 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::framing::Content;
+use crate::client::MlsError;
+use crate::crypto::SignatureSecretKey;
+use crate::group::framing::{ContentType, FramedContent, PublicMessage, Sender, WireFormat};
+use crate::group::{ConfirmationTag, GroupContext};
+use crate::signer::Signable;
+use crate::CipherSuiteProvider;
+use alloc::vec;
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::protocol_version::ProtocolVersion;
+
+/// Authentication data attached to framed content: the sender's signature
+/// and, for commits, the confirmation tag.
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct FramedContentAuthData {
+    /// Signature produced with the `FramedContentTBS` label.
+    pub signature: MessageSignature,
+    /// Encoded without a presence marker. It is decoded only for commit
+    /// content, so it should be `Some` exactly for commits.
+    pub confirmation_tag: Option<ConfirmationTag>,
+}
+
+impl MlsSize for FramedContentAuthData {
+    /// Encoded signature length plus the confirmation tag length, if any.
+    fn mls_encoded_len(&self) -> usize {
+        let tag_len = match &self.confirmation_tag {
+            Some(tag) => tag.mls_encoded_len(),
+            None => 0,
+        };
+
+        self.signature.mls_encoded_len() + tag_len
+    }
+}
+
+impl MlsEncode for FramedContentAuthData {
+    /// Writes the signature, then the confirmation tag if one is present.
+    /// Nothing marks the tag's presence; decoding relies on the content type.
+    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+        self.signature.mls_encode(writer)?;
+
+        match &self.confirmation_tag {
+            Some(tag) => tag.mls_encode(writer),
+            None => Ok(()),
+        }
+    }
+}
+
+impl FramedContentAuthData {
+    /// Decodes auth data whose layout depends on `content_type`. A signature
+    /// is always read; a confirmation tag follows only for commits.
+    pub(crate) fn mls_decode(
+        reader: &mut &[u8],
+        content_type: ContentType,
+    ) -> Result<Self, mls_rs_codec::Error> {
+        let signature = MessageSignature::mls_decode(reader)?;
+
+        let confirmation_tag = match content_type {
+            ContentType::Commit => Some(ConfirmationTag::mls_decode(reader)?),
+            #[cfg(feature = "private_message")]
+            ContentType::Application => None,
+            #[cfg(feature = "by_ref_proposal")]
+            ContentType::Proposal => None,
+        };
+
+        Ok(Self {
+            signature,
+            confirmation_tag,
+        })
+    }
+}
+
+/// Framed content together with its wire format and authentication data.
+/// It is encoded in field order: wire format, content, auth data.
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct AuthenticatedContent {
+    pub(crate) wire_format: WireFormat,
+    pub(crate) content: FramedContent,
+    pub(crate) auth: FramedContentAuthData,
+}
+
+impl From<PublicMessage> for AuthenticatedContent {
+    /// Reinterprets a public message as authenticated content with the
+    /// `PublicMessage` wire format; the membership tag is dropped.
+    fn from(p: PublicMessage) -> Self {
+        let PublicMessage { content, auth, .. } = p;
+
+        Self {
+            wire_format: WireFormat::PublicMessage,
+            content,
+            auth,
+        }
+    }
+}
+
+impl AuthenticatedContent {
+    /// Frames `content` for the group's current id and epoch. The result is
+    /// unsigned: an empty signature and no confirmation tag.
+    pub(crate) fn new(
+        context: &GroupContext,
+        sender: Sender,
+        content: Content,
+        authenticated_data: Vec<u8>,
+        wire_format: WireFormat,
+    ) -> AuthenticatedContent {
+        let framed = FramedContent {
+            group_id: context.group_id.clone(),
+            epoch: context.epoch,
+            sender,
+            authenticated_data,
+            content,
+        };
+
+        let unsigned_auth = FramedContentAuthData {
+            signature: MessageSignature::empty(),
+            confirmation_tag: None,
+        };
+
+        AuthenticatedContent {
+            wire_format,
+            content: framed,
+            auth: unsigned_auth,
+        }
+    }
+
+    /// Builds the content like `AuthenticatedContent::new` and signs it with
+    /// `signer`. `context` and its protocol version form the signing context.
+    ///
+    /// Note the argument order differs from `new`: `wire_format` comes
+    /// before `authenticated_data`.
+    #[inline(never)]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn new_signed<P: CipherSuiteProvider>(
+        signature_provider: &P,
+        context: &GroupContext,
+        sender: Sender,
+        content: Content,
+        signer: &SignatureSecretKey,
+        wire_format: WireFormat,
+        authenticated_data: Vec<u8>,
+    ) -> Result<AuthenticatedContent, MlsError> {
+        let signing_context = MessageSigningContext {
+            group_context: Some(context),
+            protocol_version: context.protocol_version,
+        };
+
+        let mut auth_content =
+            Self::new(context, sender, content, authenticated_data, wire_format);
+
+        // The signature covers the content together with the current epoch's
+        // group context.
+        auth_content
+            .sign(signature_provider, signer, &signing_context)
+            .await?;
+
+        Ok(auth_content)
+    }
+}
+
+impl MlsDecode for AuthenticatedContent {
+    /// Decodes wire format, then framed content, then auth data. The shape
+    /// of the auth data is chosen by the decoded content type.
+    fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
+        let wire_format = WireFormat::mls_decode(reader)?;
+        let content = FramedContent::mls_decode(reader)?;
+        let content_type = content.content_type();
+        let auth = FramedContentAuthData::mls_decode(reader, content_type)?;
+
+        Ok(Self {
+            wire_format,
+            content,
+            auth,
+        })
+    }
+}
+
+/// Borrowed "to be signed" view of authenticated content.
+///
+/// It encodes the protocol version, the wire format, the content and,
+/// when present, the group context.
+#[derive(Clone, Debug, PartialEq)]
+pub(crate) struct AuthenticatedContentTBS<'a> {
+    pub(crate) protocol_version: ProtocolVersion,
+    pub(crate) wire_format: WireFormat,
+    pub(crate) content: &'a FramedContent,
+    /// Only set for `Member` and `NewMemberCommit` senders
+    /// (see `from_authenticated_content`).
+    pub(crate) context: Option<&'a GroupContext>,
+}
+
+impl<'a> MlsSize for AuthenticatedContentTBS<'a> {
+    /// Sum of the encoded field lengths; an absent group context adds nothing.
+    fn mls_encoded_len(&self) -> usize {
+        let context_len = self.context.map_or(0, |ctx| ctx.mls_encoded_len());
+
+        self.protocol_version.mls_encoded_len()
+            + self.wire_format.mls_encoded_len()
+            + self.content.mls_encoded_len()
+            + context_len
+    }
+}
+
+impl<'a> MlsEncode for AuthenticatedContentTBS<'a> {
+    /// Writes version, wire format and content, then the group context if set.
+    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+        self.protocol_version.mls_encode(writer)?;
+        self.wire_format.mls_encode(writer)?;
+        self.content.mls_encode(writer)?;
+
+        match self.context {
+            Some(context) => context.mls_encode(writer),
+            None => Ok(()),
+        }
+    }
+}
+
+impl<'a> AuthenticatedContentTBS<'a> {
+    /// Builds the to-be-signed view of `auth_content`.
+    ///
+    /// `group_context` is kept only for `Member` and `NewMemberCommit`
+    /// senders. It is ignored for external and new-member-proposal senders.
+    /// Callers must therefore pass `Some` for the former two sender kinds.
+    pub(crate) fn from_authenticated_content(
+        auth_content: &'a AuthenticatedContent,
+        group_context: Option<&'a GroupContext>,
+        protocol_version: ProtocolVersion,
+    ) -> Self {
+        let context = match auth_content.content.sender {
+            Sender::Member(_) | Sender::NewMemberCommit => group_context,
+            #[cfg(feature = "by_ref_proposal")]
+            Sender::External(_) => None,
+            #[cfg(feature = "by_ref_proposal")]
+            Sender::NewMemberProposal => None,
+        };
+
+        Self {
+            protocol_version,
+            wire_format: auth_content.wire_format,
+            content: &auth_content.content,
+            context,
+        }
+    }
+}
+
+/// Context needed to produce or check a signature over `AuthenticatedContent`.
+#[derive(Debug)]
+pub(crate) struct MessageSigningContext<'a> {
+    /// Group context mixed into the signed data for `Member` and
+    /// `NewMemberCommit` senders; ignored for other senders.
+    pub group_context: Option<&'a GroupContext>,
+    /// Protocol version written at the start of the signed data.
+    pub protocol_version: ProtocolVersion,
+}
+
+impl<'a> Signable<'a> for AuthenticatedContent {
+    const SIGN_LABEL: &'static str = "FramedContentTBS";
+
+    type SigningContext = MessageSigningContext<'a>;
+
+    /// Raw signature bytes stored in the auth data.
+    fn signature(&self) -> &[u8] {
+        self.auth.signature.as_slice()
+    }
+
+    /// Encodes the `AuthenticatedContentTBS` view of this content for the
+    /// given signing context.
+    fn signable_content(
+        &self,
+        context: &MessageSigningContext,
+    ) -> Result<Vec<u8>, mls_rs_codec::Error> {
+        let tbs = AuthenticatedContentTBS::from_authenticated_content(
+            self,
+            context.group_context,
+            context.protocol_version,
+        );
+
+        tbs.mls_encode_to_vec()
+    }
+
+    /// Replaces the stored signature with `signature`.
+    fn write_signature(&mut self, signature: Vec<u8>) {
+        self.auth.signature = signature.into();
+    }
+}
+
+/// Opaque signature bytes, encoded as a length-prefixed byte vector
+/// (`mls_rs_codec::byte_vec`).
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct MessageSignature(
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+    Vec<u8>,
+);
+
+impl Debug for MessageSignature {
+    /// Formats the bytes with `mls_rs_core::debug::pretty_bytes`, labelled
+    /// `MessageSignature`.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let labelled = mls_rs_core::debug::pretty_bytes(&self.0).named("MessageSignature");
+        labelled.fmt(f)
+    }
+}
+
+impl MessageSignature {
+    /// Zero-length signature; `AuthenticatedContent::new` uses it as the
+    /// placeholder before the content is signed.
+    pub(crate) fn empty() -> Self {
+        Self(vec![])
+    }
+}
+
+impl Deref for MessageSignature {
+    type Target = Vec<u8>;
+
+    /// Exposes the underlying signature bytes.
+    fn deref(&self) -> &Self::Target {
+        let Self(bytes) = self;
+        bytes
+    }
+}
+
+impl From<Vec<u8>> for MessageSignature {
+    /// Wraps raw signature bytes as-is, without any validation.
+    fn from(bytes: Vec<u8>) -> Self {
+        Self(bytes)
+    }
+}
diff --git a/src/group/message_verifier.rs b/src/group/message_verifier.rs
new file mode 100644
index 0000000..7a2bc59
--- /dev/null
+++ b/src/group/message_verifier.rs
@@ -0,0 +1,680 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+#[cfg(feature = "by_ref_proposal")]
+use alloc::{vec, vec::Vec};
+
+use crate::{
+ client::MlsError,
+ crypto::SignaturePublicKey,
+ group::{GroupContext, PublicMessage, Sender},
+ signer::Signable,
+ tree_kem::{node::LeafIndex, TreeKemPublic},
+ CipherSuiteProvider,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::{extension::ExternalSendersExt, identity::SigningIdentity};
+
+use super::{
+ key_schedule::KeySchedule,
+ message_signature::{AuthenticatedContent, MessageSigningContext},
+ state::GroupState,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use super::proposal::Proposal;
+
+/// Where member signature public keys are looked up when verifying messages.
+#[derive(Debug)]
+pub(crate) enum SignaturePublicKeysContainer<'a> {
+    /// Use the signing identity in the leaf node at the sender's leaf index.
+    RatchetTree(&'a TreeKemPublic),
+    /// Keys indexed by leaf index; a missing or `None` entry yields
+    /// `MlsError::LeafNotFound`.
+    #[cfg(feature = "private_message")]
+    List(&'a [Option<SignaturePublicKey>]),
+}
+
+/// Authenticates a received `PublicMessage` and returns it as
+/// `AuthenticatedContent`.
+///
+/// For `Member` senders:
+/// - the membership tag is checked whenever `key_schedule` is supplied;
+/// - messages from `self_index` are rejected.
+///
+/// Every other sender must not carry a membership tag. Finally, the signature
+/// is verified with the sender's key, taken from the ratchet tree or (with
+/// `by_ref_proposal`) the group's external senders extension.
+///
+/// # Errors
+/// `InvalidMembershipTag`, `CantProcessMessageFromSelf`,
+/// `MembershipTagForNonMember`, or any signature verification error.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn verify_plaintext_authentication<P: CipherSuiteProvider>(
+    cipher_suite_provider: &P,
+    plaintext: PublicMessage,
+    key_schedule: Option<&KeySchedule>,
+    self_index: Option<LeafIndex>,
+    state: &GroupState,
+) -> Result<AuthenticatedContent, MlsError> {
+    let membership_tag = plaintext.membership_tag.clone();
+    let auth_content = AuthenticatedContent::from(plaintext);
+    let context = &state.context;
+
+    #[cfg(feature = "by_ref_proposal")]
+    let external_signers = external_signers(context);
+
+    if let Sender::Member(index) = &auth_content.content.sender {
+        if let Some(key_schedule) = key_schedule {
+            let expected_tag = key_schedule
+                .get_membership_tag(&auth_content, context, cipher_suite_provider)
+                .await?;
+
+            // A missing tag and a mismatching tag are reported the same way.
+            match &membership_tag {
+                Some(received_tag) if expected_tag == *received_tag => {}
+                _ => return Err(MlsError::InvalidMembershipTag),
+            }
+        }
+
+        if self_index == Some(LeafIndex(*index)) {
+            return Err(MlsError::CantProcessMessageFromSelf);
+        }
+    } else if membership_tag.is_some() {
+        return Err(MlsError::MembershipTagForNonMember);
+    }
+
+    // The signature must verify with the public key of the credential stored
+    // at the sender's leaf (or the sender-specific source for non-members).
+    verify_auth_content_signature(
+        cipher_suite_provider,
+        SignaturePublicKeysContainer::RatchetTree(&state.public_tree),
+        context,
+        &auth_content,
+        #[cfg(feature = "by_ref_proposal")]
+        &external_signers,
+    )
+    .await?;
+
+    Ok(auth_content)
+}
+
+/// Signing identities allowed to send external proposals, read from the
+/// group's `ExternalSendersExt` extension.
+///
+/// Yields an empty list when the extension is absent or `get_as` returns
+/// an error (the error is intentionally not propagated).
+#[cfg(feature = "by_ref_proposal")]
+fn external_signers(context: &GroupContext) -> Vec<SigningIdentity> {
+    let senders_ext = context
+        .extensions
+        .get_as::<ExternalSendersExt>()
+        .unwrap_or(None);
+
+    match senders_ext {
+        Some(ext) => ext.allowed_senders,
+        None => vec![],
+    }
+}
+
+/// Verifies the signature on `auth_content` with the key of its sender.
+///
+/// The key is resolved per sender kind by `signing_identity_for_sender`.
+/// `context` is used as the signing group context; the to-be-signed encoding
+/// only includes it for `Member` and `NewMemberCommit` senders.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn verify_auth_content_signature<P: CipherSuiteProvider>(
+    cipher_suite_provider: &P,
+    signature_keys_container: SignaturePublicKeysContainer<'_>,
+    context: &GroupContext,
+    auth_content: &AuthenticatedContent,
+    #[cfg(feature = "by_ref_proposal")] external_signers: &[SigningIdentity],
+) -> Result<(), MlsError> {
+    let framed = &auth_content.content;
+
+    let sender_public_key = signing_identity_for_sender(
+        signature_keys_container,
+        &framed.sender,
+        &framed.content,
+        #[cfg(feature = "by_ref_proposal")]
+        external_signers,
+    )?;
+
+    let signing_context = MessageSigningContext {
+        group_context: Some(context),
+        protocol_version: context.protocol_version,
+    };
+
+    auth_content
+        .verify(cipher_suite_provider, &sender_public_key, &signing_context)
+        .await?;
+
+    Ok(())
+}
+
+/// Resolves the signature public key for `sender` by dispatching on the
+/// sender kind.
+fn signing_identity_for_sender(
+    signature_keys_container: SignaturePublicKeysContainer,
+    sender: &Sender,
+    content: &super::framing::Content,
+    #[cfg(feature = "by_ref_proposal")] external_signers: &[SigningIdentity],
+) -> Result<SignaturePublicKey, MlsError> {
+    match *sender {
+        Sender::Member(index) => {
+            signing_identity_for_member(signature_keys_container, LeafIndex(index))
+        }
+        #[cfg(feature = "by_ref_proposal")]
+        Sender::External(index) => signing_identity_for_external(index, external_signers),
+        Sender::NewMemberCommit => signing_identity_for_new_member_commit(content),
+        #[cfg(feature = "by_ref_proposal")]
+        Sender::NewMemberProposal => signing_identity_for_new_member_proposal(content),
+    }
+}
+
+/// Looks up the signature key of the member at `leaf_index`.
+///
+/// # Errors
+/// - For `RatchetTree`, propagates the tree's leaf lookup error.
+/// - For `List`, returns `LeafNotFound` when the entry is missing or `None`.
+fn signing_identity_for_member(
+    signature_keys_container: SignaturePublicKeysContainer,
+    leaf_index: LeafIndex,
+) -> Result<SignaturePublicKey, MlsError> {
+    match signature_keys_container {
+        SignaturePublicKeysContainer::RatchetTree(tree) => {
+            let leaf = tree.get_leaf_node(leaf_index)?;
+            // TODO: We can probably get rid of this clone
+            Ok(leaf.signing_identity.signature_key.clone())
+        }
+        #[cfg(feature = "private_message")]
+        SignaturePublicKeysContainer::List(list) => list
+            .get(leaf_index.0 as usize)
+            .and_then(|entry| entry.clone())
+            .ok_or(MlsError::LeafNotFound(*leaf_index)),
+    }
+}
+
+/// Returns the signature key of the external sender at `index`.
+///
+/// # Errors
+/// `UnknownSigningIdentityForExternalSender` if `index` is out of range.
+#[cfg(feature = "by_ref_proposal")]
+fn signing_identity_for_external(
+    index: u32,
+    external_signers: &[SigningIdentity],
+) -> Result<SignaturePublicKey, MlsError> {
+    match external_signers.get(index as usize) {
+        Some(identity) => Ok(identity.signature_key.clone()),
+        None => Err(MlsError::UnknownSigningIdentityForExternalSender),
+    }
+}
+
+/// Returns the signature key in the leaf node of a new member's commit path.
+///
+/// # Errors
+/// - `CommitMissingPath` if the commit carries no update path.
+/// - `ExpectedCommitForNewMemberCommit` if the content is not a commit.
+fn signing_identity_for_new_member_commit(
+    content: &super::framing::Content,
+) -> Result<SignaturePublicKey, MlsError> {
+    match content {
+        super::framing::Content::Commit(commit) => commit
+            .path
+            .as_ref()
+            .map(|path| path.leaf_node.signing_identity.signature_key.clone())
+            .ok_or(MlsError::CommitMissingPath),
+        #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
+        _ => Err(MlsError::ExpectedCommitForNewMemberCommit),
+    }
+}
+
+/// Returns the signature key of the key package in a new member's Add
+/// proposal.
+///
+/// # Errors
+/// `ExpectedAddProposalForNewMemberProposal` if the content is not an Add
+/// proposal.
+#[cfg(feature = "by_ref_proposal")]
+fn signing_identity_for_new_member_proposal(
+    content: &super::framing::Content,
+) -> Result<SignaturePublicKey, MlsError> {
+    if let super::framing::Content::Proposal(proposal) = content {
+        if let Proposal::Add(add) = proposal.as_ref() {
+            return Ok(add
+                .key_package
+                .leaf_node
+                .signing_identity
+                .signature_key
+                .clone());
+        }
+    }
+
+    Err(MlsError::ExpectedAddProposalForNewMemberProposal)
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{
+        client::{
+            test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+            MlsError,
+        },
+        client_builder::test_utils::TestClientConfig,
+        crypto::test_utils::test_cipher_suite_provider,
+        group::{
+            membership_tag::MembershipTag,
+            message_signature::{AuthenticatedContent, MessageSignature},
+            test_utils::{test_group_custom, TestGroup},
+            Group, PublicMessage,
+        },
+        tree_kem::node::LeafIndex,
+    };
+    use alloc::vec;
+    use assert_matches::assert_matches;
+
+    #[cfg(feature = "by_ref_proposal")]
+    use crate::{extension::ExternalSendersExt, ExtensionList};
+
+    #[cfg(feature = "by_ref_proposal")]
+    use crate::{
+        crypto::SignatureSecretKey,
+        group::{
+            message_signature::MessageSigningContext,
+            proposal::{AddProposal, Proposal, RemoveProposal},
+            Content,
+        },
+        key_package::KeyPackageGeneration,
+        signer::Signable,
+        WireFormat,
+    };
+
+    #[cfg(feature = "by_ref_proposal")]
+    use alloc::boxed::Box;
+
+    use crate::group::{
+        test_utils::{test_group, test_member},
+        Sender,
+    };
+
+    #[cfg(feature = "by_ref_proposal")]
+    use crate::identity::test_utils::get_test_signing_identity;
+
+    use super::{verify_auth_content_signature, verify_plaintext_authentication};
+
+    /// Has `group` create an empty commit and returns it as a signed
+    /// `PublicMessage`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn make_signed_plaintext(group: &mut Group<TestClientConfig>) -> PublicMessage {
+        group
+            .commit(vec![])
+            .await
+            .unwrap()
+            .commit_message
+            .into_plaintext()
+            .unwrap()
+    }
+
+    /// Two-member fixture: Alice creates the group and adds Bob, who joins
+    /// from the resulting Welcome message.
+    struct TestEnv {
+        alice: TestGroup,
+        bob: TestGroup,
+    }
+
+    impl TestEnv {
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        async fn new() -> Self {
+            let mut alice = test_group_custom(
+                TEST_PROTOCOL_VERSION,
+                TEST_CIPHER_SUITE,
+                Default::default(),
+                None,
+                None,
+            )
+            .await;
+
+            let (bob_client, bob_key_pkg) =
+                test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+            let commit_output = alice
+                .group
+                .commit_builder()
+                .add_member(bob_key_pkg)
+                .unwrap()
+                .build()
+                .await
+                .unwrap();
+
+            alice.group.apply_pending_commit().await.unwrap();
+
+            let (bob, _) = Group::join(
+                &commit_output.welcome_messages[0],
+                None,
+                bob_client.config,
+                bob_client.signer.unwrap(),
+            )
+            .await
+            .unwrap();
+
+            TestEnv {
+                alice,
+                bob: TestGroup { group: bob },
+            }
+        }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn valid_plaintext_is_verified() {
+        let mut env = TestEnv::new().await;
+
+        let message = make_signed_plaintext(&mut env.alice.group).await;
+
+        verify_plaintext_authentication(
+            &env.bob.group.cipher_suite_provider,
+            message,
+            Some(&env.bob.group.key_schedule),
+            None,
+            &env.bob.group.state,
+        )
+        .await
+        .unwrap();
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn valid_auth_content_is_verified() {
+        let mut env = TestEnv::new().await;
+
+        let message = AuthenticatedContent::from(make_signed_plaintext(&mut env.alice.group).await);
+
+        verify_auth_content_signature(
+            &env.bob.group.cipher_suite_provider,
+            super::SignaturePublicKeysContainer::RatchetTree(&env.bob.group.state.public_tree),
+            env.bob.group.context(),
+            &message,
+            #[cfg(feature = "by_ref_proposal")]
+            &[],
+        )
+        .await
+        .unwrap();
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn invalid_plaintext_is_not_verified() {
+        let mut env = TestEnv::new().await;
+        let mut message = make_signed_plaintext(&mut env.alice.group).await;
+        message.auth.signature = MessageSignature::from(b"test".to_vec());
+
+        // Recompute a valid membership tag so that only the signature is wrong.
+        message.membership_tag = env
+            .alice
+            .group
+            .key_schedule
+            .get_membership_tag(
+                &AuthenticatedContent::from(message.clone()),
+                env.alice.group.context(),
+                &test_cipher_suite_provider(env.alice.group.cipher_suite()),
+            )
+            .await
+            .unwrap()
+            .into();
+
+        let res = verify_plaintext_authentication(
+            &env.bob.group.cipher_suite_provider,
+            message,
+            Some(&env.bob.group.key_schedule),
+            None,
+            &env.bob.group.state,
+        )
+        .await;
+
+        assert_matches!(res, Err(MlsError::InvalidSignature));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn plaintext_from_member_requires_membership_tag() {
+        let mut env = TestEnv::new().await;
+        let mut message = make_signed_plaintext(&mut env.alice.group).await;
+        message.membership_tag = None;
+
+        let res = verify_plaintext_authentication(
+            &env.bob.group.cipher_suite_provider,
+            message,
+            Some(&env.bob.group.key_schedule),
+            None,
+            &env.bob.group.state,
+        )
+        .await;
+
+        assert_matches!(res, Err(MlsError::InvalidMembershipTag));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn plaintext_fails_with_invalid_membership_tag() {
+        let mut env = TestEnv::new().await;
+        let mut message = make_signed_plaintext(&mut env.alice.group).await;
+        message.membership_tag = Some(MembershipTag::from(b"test".to_vec()));
+
+        let res = verify_plaintext_authentication(
+            &env.bob.group.cipher_suite_provider,
+            message,
+            Some(&env.bob.group.key_schedule),
+            None,
+            &env.bob.group.state,
+        )
+        .await;
+
+        assert_matches!(res, Err(MlsError::InvalidMembershipTag));
+    }
+
+    /// Builds a `NewMemberProposal` Add proposal for `key_pkg_gen`, signed by
+    /// `signer`. `edit` may then modify the content, and the content is
+    /// re-signed so that its signature matches the edited content.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn test_new_member_proposal<F>(
+        key_pkg_gen: KeyPackageGeneration,
+        signer: &SignatureSecretKey,
+        test_group: &TestGroup,
+        mut edit: F,
+    ) -> PublicMessage
+    where
+        F: FnMut(&mut AuthenticatedContent),
+    {
+        let mut content = AuthenticatedContent::new_signed(
+            &test_group.group.cipher_suite_provider,
+            test_group.group.context(),
+            Sender::NewMemberProposal,
+            Content::Proposal(Box::new(Proposal::Add(Box::new(AddProposal {
+                key_package: key_pkg_gen.key_package,
+            })))),
+            signer,
+            WireFormat::PublicMessage,
+            vec![],
+        )
+        .await
+        .unwrap();
+
+        edit(&mut content);
+
+        let signing_context = MessageSigningContext {
+            group_context: Some(test_group.group.context()),
+            protocol_version: test_group.group.protocol_version(),
+        };
+
+        content
+            .sign(
+                &test_group.group.cipher_suite_provider,
+                signer,
+                &signing_context,
+            )
+            .await
+            .unwrap();
+
+        PublicMessage {
+            content: content.content,
+            auth: content.auth,
+            membership_tag: None,
+        }
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn valid_proposal_from_new_member_is_verified() {
+        let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let (key_pkg_gen, signer) =
+            test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+        let message = test_new_member_proposal(key_pkg_gen, &signer, &test_group, |_| {}).await;
+
+        verify_plaintext_authentication(
+            &test_group.group.cipher_suite_provider,
+            message,
+            Some(&test_group.group.key_schedule),
+            None,
+            &test_group.group.state,
+        )
+        .await
+        .unwrap();
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn proposal_from_new_member_must_not_have_membership_tag() {
+        let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let (key_pkg_gen, signer) =
+            test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+
+        let mut message = test_new_member_proposal(key_pkg_gen, &signer, &test_group, |_| {}).await;
+        message.membership_tag = Some(MembershipTag::from(vec![]));
+
+        let res = verify_plaintext_authentication(
+            &test_group.group.cipher_suite_provider,
+            message,
+            Some(&test_group.group.key_schedule),
+            None,
+            &test_group.group.state,
+        )
+        .await;
+
+        assert_matches!(res, Err(MlsError::MembershipTagForNonMember));
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_proposal_sender_must_be_add_proposal() {
+        let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let (key_pkg_gen, signer) =
+            test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+
+        let message = test_new_member_proposal(key_pkg_gen, &signer, &test_group, |msg| {
+            msg.content.content = Content::Proposal(Box::new(Proposal::Remove(RemoveProposal {
+                to_remove: LeafIndex(0),
+            })))
+        })
+        .await;
+
+        let res: Result<AuthenticatedContent, MlsError> = verify_plaintext_authentication(
+            &test_group.group.cipher_suite_provider,
+            message,
+            Some(&test_group.group.key_schedule),
+            None,
+            &test_group.group.state,
+        )
+        .await;
+
+        assert_matches!(res, Err(MlsError::ExpectedAddProposalForNewMemberProposal));
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_commit_must_be_external_commit() {
+        let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let (key_pkg_gen, signer) =
+            test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+
+        let message = test_new_member_proposal(key_pkg_gen, &signer, &test_group, |msg| {
+            msg.content.sender = Sender::NewMemberCommit;
+        })
+        .await;
+
+        let res = verify_plaintext_authentication(
+            &test_group.group.cipher_suite_provider,
+            message,
+            Some(&test_group.group.key_schedule),
+            None,
+            &test_group.group.state,
+        )
+        .await;
+
+        assert_matches!(res, Err(MlsError::ExpectedCommitForNewMemberCommit));
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn valid_proposal_from_external_is_verified() {
+        let (bob_key_pkg_gen, _) =
+            test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+
+        let (ted_signing, ted_secret) = get_test_signing_identity(TEST_CIPHER_SUITE, b"ted").await;
+
+        let mut test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let mut extensions = ExtensionList::default();
+
+        extensions
+            .set_from(ExternalSendersExt {
+                allowed_senders: vec![ted_signing],
+            })
+            .unwrap();
+
+        test_group
+            .group
+            .commit_builder()
+            .set_group_context_ext(extensions)
+            .unwrap()
+            .build()
+            .await
+            .unwrap();
+
+        test_group.group.apply_pending_commit().await.unwrap();
+
+        let message = test_new_member_proposal(bob_key_pkg_gen, &ted_secret, &test_group, |msg| {
+            msg.content.sender = Sender::External(0)
+        })
+        .await;
+
+        verify_plaintext_authentication(
+            &test_group.group.cipher_suite_provider,
+            message,
+            Some(&test_group.group.key_schedule),
+            None,
+            &test_group.group.state,
+        )
+        .await
+        .unwrap();
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_proposal_must_be_from_valid_sender() {
+        let (bob_key_pkg_gen, _) =
+            test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+        let (_, ted_secret) = get_test_signing_identity(TEST_CIPHER_SUITE, b"ted").await;
+        let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let message = test_new_member_proposal(bob_key_pkg_gen, &ted_secret, &test_group, |msg| {
+            msg.content.sender = Sender::External(0)
+        })
+        .await;
+
+        let res = verify_plaintext_authentication(
+            &test_group.group.cipher_suite_provider,
+            message,
+            Some(&test_group.group.key_schedule),
+            None,
+            &test_group.group.state,
+        )
+        .await;
+
+        assert_matches!(res, Err(MlsError::UnknownSigningIdentityForExternalSender));
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn proposal_from_external_sender_must_not_have_membership_tag() {
+        let (bob_key_pkg_gen, _) =
+            test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+
+        let (_, ted_secret) = get_test_signing_identity(TEST_CIPHER_SUITE, b"ted").await;
+
+        let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        // The sender must actually be `External`. Keeping the default
+        // `NewMemberProposal` sender would only re-test the new-member case.
+        let mut message =
+            test_new_member_proposal(bob_key_pkg_gen, &ted_secret, &test_group, |msg| {
+                msg.content.sender = Sender::External(0)
+            })
+            .await;
+
+        message.membership_tag = Some(MembershipTag::from(vec![]));
+
+        let res = verify_plaintext_authentication(
+            &test_group.group.cipher_suite_provider,
+            message,
+            Some(&test_group.group.key_schedule),
+            None,
+            &test_group.group.state,
+        )
+        .await;
+
+        assert_matches!(res, Err(MlsError::MembershipTagForNonMember));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn plaintext_from_self_fails_verification() {
+        let mut env = TestEnv::new().await;
+
+        let message = make_signed_plaintext(&mut env.alice.group).await;
+
+        let res = verify_plaintext_authentication(
+            &env.alice.group.cipher_suite_provider,
+            message,
+            Some(&env.alice.group.key_schedule),
+            Some(LeafIndex::new(env.alice.group.current_member_index())),
+            &env.alice.group.state,
+        )
+        .await;
+
+        assert_matches!(res, Err(MlsError::CantProcessMessageFromSelf))
+    }
+}
diff --git a/src/group/mls_rules.rs b/src/group/mls_rules.rs
new file mode 100644
index 0000000..98b1dac
--- /dev/null
+++ b/src/group/mls_rules.rs
@@ -0,0 +1,283 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::group::{proposal_filter::ProposalBundle, Roster};
+
+#[cfg(feature = "private_message")]
+use crate::{
+ group::{padding::PaddingMode, Sender},
+ WireFormat,
+};
+
+use alloc::boxed::Box;
+use core::convert::Infallible;
+use mls_rs_core::{
+ error::IntoAnyError, extension::ExtensionList, group::Member, identity::SigningIdentity,
+};
+
+/// Whether a commit is being handled on the sending or the receiving side.
+///
+/// Passed to [`MlsRules::filter_proposals`], which is called both when
+/// preparing a commit and when receiving one.
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub enum CommitDirection {
+ /// The commit is being prepared by the local client.
+ Send,
+ /// The commit was received and is being processed.
+ Receive,
+}
+
+/// The source of the commit: either a current member or a new member joining
+/// via external commit.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum CommitSource {
+ /// The committer is a current member of the group, described by its [`Member`] record.
+ ExistingMember(Member),
+ /// The committer is joining via external commit and is described only by
+ /// the signing identity it presents.
+ NewMember(SigningIdentity),
+}
+
+/// Options controlling commit generation
+///
+/// Returned by [`MlsRules::commit_options`] when a commit is being prepared.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+#[non_exhaustive]
+pub struct CommitOptions {
+ /// Include an update path even when MLS does not mandate one.
+ pub path_required: bool,
+ /// Include the ratchet tree in the Welcome message (relevant when the
+ /// commit adds members).
+ pub ratchet_tree_extension: bool,
+ /// Generate a single Welcome message for all added members instead of one
+ /// Welcome message per added member.
+ pub single_welcome_message: bool,
+ /// NOTE(review): not described by the `MlsRules` docs; presumably controls
+ /// whether the commit output enables external commits — confirm at use site.
+ pub allow_external_commit: bool,
+}
+
+impl Default for CommitOptions {
+ /// Default commit options: `path_required = false`,
+ /// `ratchet_tree_extension = true`, `single_welcome_message = true` and
+ /// `allow_external_commit = false`.
+ fn default() -> Self {
+ Self {
+ allow_external_commit: false,
+ path_required: false,
+ single_welcome_message: true,
+ ratchet_tree_extension: true,
+ }
+ }
+}
+
+impl CommitOptions {
+ // Every builder method consumes `self` and returns the updated options.
+ // `#[must_use]` makes the compiler warn when that return value is dropped,
+ // which would otherwise silently discard the requested setting.
+
+ /// Creates the default commit options (same as [`CommitOptions::default`]).
+ #[must_use]
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ /// Sets whether to include an update path even when MLS does not mandate one.
+ #[must_use]
+ pub fn with_path_required(self, path_required: bool) -> Self {
+ Self {
+ path_required,
+ ..self
+ }
+ }
+
+ /// Sets whether the ratchet tree is included in the Welcome message
+ /// (relevant when the commit adds members).
+ #[must_use]
+ pub fn with_ratchet_tree_extension(self, ratchet_tree_extension: bool) -> Self {
+ Self {
+ ratchet_tree_extension,
+ ..self
+ }
+ }
+
+ /// Sets whether a single Welcome message is generated for all added
+ /// members, rather than one Welcome message per added member.
+ #[must_use]
+ pub fn with_single_welcome_message(self, single_welcome_message: bool) -> Self {
+ Self {
+ single_welcome_message,
+ ..self
+ }
+ }
+
+ /// Sets the `allow_external_commit` option.
+ #[must_use]
+ pub fn with_allow_external_commit(self, allow_external_commit: bool) -> Self {
+ Self {
+ allow_external_commit,
+ ..self
+ }
+ }
+}
+
+/// Options controlling encryption of control and application messages
+///
+/// Returned by [`MlsRules::encryption_options`]. Without the `private_message`
+/// feature this struct has no fields.
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
+#[non_exhaustive]
+pub struct EncryptionOptions {
+ /// When `true`, control messages (proposals and commits) sent by a member
+ /// use the `PrivateMessage` wire format; otherwise `PublicMessage` is used.
+ #[cfg(feature = "private_message")]
+ pub encrypt_control_messages: bool,
+ /// Padding mode applied to encrypted messages.
+ #[cfg(feature = "private_message")]
+ pub padding_mode: PaddingMode,
+}
+
+#[cfg(feature = "private_message")]
+impl EncryptionOptions {
+ /// Builds encryption options from an explicit control-message policy and
+ /// padding mode.
+ pub fn new(encrypt_control_messages: bool, padding_mode: PaddingMode) -> Self {
+ Self {
+ padding_mode,
+ encrypt_control_messages,
+ }
+ }
+
+ /// Chooses the wire format for a control message from `sender`.
+ ///
+ /// Only member senders can use `PrivateMessage`, and only when
+ /// `encrypt_control_messages` is enabled. Every other case uses `PublicMessage`.
+ pub(crate) fn control_wire_format(&self, sender: Sender) -> WireFormat {
+ let from_member = matches!(sender, Sender::Member(_));
+
+ if from_member && self.encrypt_control_messages {
+ WireFormat::PrivateMessage
+ } else {
+ WireFormat::PublicMessage
+ }
+ }
+}
+
+/// A set of user controlled rules that customize the behavior of MLS.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+pub trait MlsRules: Send + Sync {
+ /// Error type returned by these rules; must be convertible via [`IntoAnyError`].
+ type Error: IntoAnyError;
+
+ /// This is called when preparing or receiving a commit to pre-process the set of committed
+ /// proposals.
+ ///
+ /// Both the proposals received during the current epoch and the proposals supplied at the
+ /// time of commit are presented, as a raw list, for validation and filtering. The standard
+ /// MLS rules are applied internally to the bundle returned by this method.
+ ///
+ /// Each member of a group MUST apply the same proposal rules in order to
+ /// maintain a working group.
+ ///
+ /// Typically, any invalid proposal should result in an error. The exceptions are invalid
+ /// by-reference proposals processed when _preparing_ a commit, which should be filtered
+ /// out instead. This avoids a deadlock in which no commit can be generated
+ /// after receiving an invalid set of proposal messages.
+ ///
+ /// `ProposalBundle` can be arbitrarily modified. For example, a Remove proposal that
+ /// removes a moderator can result in adding a GroupContextExtensions proposal that updates
+ /// the moderator list in the group context. The resulting `ProposalBundle` is validated
+ /// by the library.
+ async fn filter_proposals(
+ &self,
+ direction: CommitDirection,
+ source: CommitSource,
+ current_roster: &Roster,
+ extension_list: &ExtensionList,
+ proposals: ProposalBundle,
+ ) -> Result<ProposalBundle, Self::Error>;
+
+ /// This is called when preparing a commit to determine various options: whether to enforce an update
+ /// path in case it is not mandated by MLS, whether to include the ratchet tree in the welcome
+ /// message (if the commit adds members) and whether to generate a single welcome message, or one
+ /// welcome message for each added member.
+ ///
+ /// The `new_roster` and `new_extension_list` describe the group state after the commit.
+ fn commit_options(
+ &self,
+ new_roster: &Roster,
+ new_extension_list: &ExtensionList,
+ proposals: &ProposalBundle,
+ ) -> Result<CommitOptions, Self::Error>;
+
+ /// This is called when sending any packet. For proposals and commits, this determines whether to
+ /// encrypt them. For any encrypted packet, this determines the padding mode used.
+ ///
+ /// Note that for commits, the `current_roster` and `current_extension_list` describe the group state
+ /// before the commit, unlike in [commit_options](MlsRules::commit_options).
+ fn encryption_options(
+ &self,
+ current_roster: &Roster,
+ current_extension_list: &ExtensionList,
+ ) -> Result<EncryptionOptions, Self::Error>;
+}
+
+// Implements `MlsRules` for a pointer-like wrapper type `$implementer` around any
+// `T: MlsRules + ?Sized`. Each method dereferences (`**self`) and forwards to the
+// inner rules unchanged, and the wrapper reuses `T::Error`.
+macro_rules! delegate_mls_rules {
+ ($implementer:ty) => {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl<T: MlsRules + ?Sized> MlsRules for $implementer {
+ type Error = T::Error;
+
+ // NOTE(review): this method-level `must_be_sync` appears to duplicate the
+ // impl-level attribute above; kept as-is.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn filter_proposals(
+ &self,
+ direction: CommitDirection,
+ source: CommitSource,
+ current_roster: &Roster,
+ extension_list: &ExtensionList,
+ proposals: ProposalBundle,
+ ) -> Result<ProposalBundle, Self::Error> {
+ (**self)
+ .filter_proposals(direction, source, current_roster, extension_list, proposals)
+ .await
+ }
+
+ fn commit_options(
+ &self,
+ roster: &Roster,
+ extension_list: &ExtensionList,
+ proposals: &ProposalBundle,
+ ) -> Result<CommitOptions, Self::Error> {
+ (**self).commit_options(roster, extension_list, proposals)
+ }
+
+ fn encryption_options(
+ &self,
+ roster: &Roster,
+ extension_list: &ExtensionList,
+ ) -> Result<EncryptionOptions, Self::Error> {
+ (**self).encryption_options(roster, extension_list)
+ }
+ }
+ };
+}
+
+// Boxed and borrowed rule sets (including unsized `T`) can be used wherever
+// `MlsRules` is expected.
+delegate_mls_rules!(Box<T>);
+delegate_mls_rules!(&T);
+
+#[derive(Clone, Debug, Default)]
+#[non_exhaustive]
+/// Default MLS rules with pass-through proposal filter and customizable options.
+pub struct DefaultMlsRules {
+ /// Returned unchanged from `MlsRules::commit_options` for every commit.
+ pub commit_options: CommitOptions,
+ /// Returned unchanged from `MlsRules::encryption_options` for every message.
+ pub encryption_options: EncryptionOptions,
+}
+
+impl DefaultMlsRules {
+ // Builder methods consume `self`; `#[must_use]` warns if the updated
+ // rules are dropped instead of being used.
+
+ /// Create new MLS rules with default settings: do not enforce path and do
+ /// put the ratchet tree in the extension.
+ #[must_use]
+ pub fn new() -> Self {
+ Default::default()
+ }
+
+ /// Set commit options, keeping all other settings.
+ #[must_use]
+ pub fn with_commit_options(self, commit_options: CommitOptions) -> Self {
+ Self {
+ commit_options,
+ ..self
+ }
+ }
+
+ /// Set encryption options, keeping all other settings.
+ #[must_use]
+ pub fn with_encryption_options(self, encryption_options: EncryptionOptions) -> Self {
+ Self {
+ encryption_options,
+ ..self
+ }
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+impl MlsRules for DefaultMlsRules {
+ type Error = Infallible;
+
+ /// Pass-through filter: the bundle is returned exactly as received.
+ async fn filter_proposals(
+ &self,
+ _direction: CommitDirection,
+ _source: CommitSource,
+ _current_roster: &Roster,
+ _extension_list: &ExtensionList,
+ proposals: ProposalBundle,
+ ) -> Result<ProposalBundle, Self::Error> {
+ Ok(proposals)
+ }
+
+ /// Always returns the configured `commit_options`, regardless of the group state.
+ fn commit_options(
+ &self,
+ _new_roster: &Roster,
+ _new_extension_list: &ExtensionList,
+ _proposals: &ProposalBundle,
+ ) -> Result<CommitOptions, Self::Error> {
+ let Self { commit_options, .. } = self;
+ Ok(*commit_options)
+ }
+
+ /// Always returns the configured `encryption_options`, regardless of the group state.
+ fn encryption_options(
+ &self,
+ _current_roster: &Roster,
+ _current_extension_list: &ExtensionList,
+ ) -> Result<EncryptionOptions, Self::Error> {
+ let Self {
+ encryption_options, ..
+ } = self;
+ Ok(*encryption_options)
+ }
+}
diff --git a/src/group/mod.rs b/src/group/mod.rs
new file mode 100644
index 0000000..0d84a84
--- /dev/null
+++ b/src/group/mod.rs
@@ -0,0 +1,4236 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+use mls_rs_core::secret::Secret;
+use mls_rs_core::time::MlsTime;
+
+use crate::cipher_suite::CipherSuite;
+use crate::client::MlsError;
+use crate::client_config::ClientConfig;
+use crate::crypto::{HpkeCiphertext, SignatureSecretKey};
+use crate::extension::RatchetTreeExt;
+use crate::identity::SigningIdentity;
+use crate::key_package::{KeyPackage, KeyPackageRef};
+use crate::protocol_version::ProtocolVersion;
+use crate::psk::secret::PskSecret;
+use crate::psk::PreSharedKeyID;
+use crate::signer::Signable;
+use crate::tree_kem::hpke_encryption::HpkeEncryptable;
+use crate::tree_kem::kem::TreeKem;
+use crate::tree_kem::node::LeafIndex;
+use crate::tree_kem::path_secret::PathSecret;
+pub use crate::tree_kem::Capabilities;
+use crate::tree_kem::{
+ leaf_node::LeafNode,
+ leaf_node_validator::{LeafNodeValidator, ValidationContext},
+};
+use crate::tree_kem::{math as tree_math, ValidatedUpdatePath};
+use crate::tree_kem::{TreeKemPrivate, TreeKemPublic};
+use crate::{CipherSuiteProvider, CryptoProvider};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::crypto::{HpkePublicKey, HpkeSecretKey};
+
+use crate::extension::ExternalPubExt;
+
+#[cfg(feature = "private_message")]
+use self::mls_rules::{EncryptionOptions, MlsRules};
+
+#[cfg(feature = "psk")]
+pub use self::resumption::ReinitClient;
+
+#[cfg(feature = "psk")]
+use crate::psk::{
+ resolver::PskResolver, secret::PskSecretInput, ExternalPskId, JustPreSharedKeyID, PskGroupId,
+ ResumptionPSKUsage, ResumptionPsk,
+};
+
+#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+use std::collections::HashMap;
+
+#[cfg(feature = "private_message")]
+use ciphertext_processor::*;
+
+use confirmation_tag::*;
+use framing::*;
+use key_schedule::*;
+use membership_tag::*;
+use message_signature::*;
+use message_verifier::*;
+use proposal::*;
+#[cfg(feature = "by_ref_proposal")]
+use proposal_cache::*;
+use state::*;
+use transcript_hash::*;
+
+#[cfg(test)]
+pub(crate) use self::commit::test_utils::CommitModifiers;
+
+#[cfg(all(test, feature = "private_message"))]
+pub use self::framing::PrivateMessage;
+
+#[cfg(feature = "psk")]
+use self::proposal_filter::ProposalInfo;
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+use secret_tree::*;
+
+#[cfg(feature = "prior_epoch")]
+use self::epoch::PriorEpoch;
+
+use self::epoch::EpochSecrets;
+pub use self::message_processor::{
+ ApplicationMessageDescription, CommitMessageDescription, ProposalMessageDescription,
+ ProposalSender, ReceivedMessage, StateUpdate,
+};
+use self::message_processor::{EventOrContent, MessageProcessor, ProvisionalState};
+#[cfg(feature = "by_ref_proposal")]
+use self::proposal_ref::ProposalRef;
+use self::state_repo::GroupStateRepository;
+pub use group_info::GroupInfo;
+
+pub use self::framing::{ContentType, Sender};
+pub use commit::*;
+pub use context::GroupContext;
+pub use roster::*;
+
+pub(crate) use transcript_hash::ConfirmedTranscriptHash;
+pub(crate) use util::*;
+
+#[cfg(all(feature = "by_ref_proposal", feature = "external_client"))]
+pub use self::message_processor::CachedProposal;
+
+#[cfg(feature = "private_message")]
+mod ciphertext_processor;
+
+mod commit;
+pub(crate) mod confirmation_tag;
+mod context;
+pub(crate) mod epoch;
+pub(crate) mod framing;
+mod group_info;
+pub(crate) mod key_schedule;
+mod membership_tag;
+pub(crate) mod message_processor;
+pub(crate) mod message_signature;
+pub(crate) mod message_verifier;
+pub mod mls_rules;
+#[cfg(feature = "private_message")]
+pub(crate) mod padding;
+/// Proposals to evolve a MLS [`Group`]
+pub mod proposal;
+mod proposal_cache;
+pub(crate) mod proposal_filter;
+#[cfg(feature = "by_ref_proposal")]
+pub(crate) mod proposal_ref;
+#[cfg(feature = "psk")]
+mod resumption;
+mod roster;
+pub(crate) mod snapshot;
+pub(crate) mod state;
+
+#[cfg(feature = "prior_epoch")]
+pub(crate) mod state_repo;
+#[cfg(not(feature = "prior_epoch"))]
+pub(crate) mod state_repo_light;
+#[cfg(not(feature = "prior_epoch"))]
+pub(crate) use state_repo_light as state_repo;
+
+pub(crate) mod transcript_hash;
+mod util;
+
+/// External commit building.
+pub mod external_commit;
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+pub(crate) mod secret_tree;
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+pub use secret_tree::MessageKeyData as MessageKey;
+
+#[cfg(all(test, feature = "rfc_compliant"))]
+mod interop_test_vectors;
+
+mod exported_tree;
+
+pub use exported_tree::ExportedTree;
+
+/// Secrets delivered to a new member inside a Welcome message. They are
+/// HPKE-encrypted per recipient with the label `"Welcome"`.
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+struct GroupSecrets {
+ /// Used by the joiner to derive the welcome secret and the epoch key schedule.
+ joiner_secret: JoinerSecret,
+ /// Path secret for the joiner, present when the committer supplied path
+ /// secrets; used to fill in the joiner's private tree.
+ path_secret: Option<PathSecret>,
+ /// IDs of the pre-shared keys the joiner must resolve to compute the PSK secret.
+ psks: Vec<PreSharedKeyID>,
+}
+
+impl HpkeEncryptable for GroupSecrets {
+ const ENCRYPT_LABEL: &'static str = "Welcome";
+
+ /// Decodes `GroupSecrets` from its MLS wire encoding.
+ fn from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError> {
+ let mut reader: &[u8] = &bytes;
+ Ok(Self::mls_decode(&mut reader)?)
+ }
+
+ /// Encodes `GroupSecrets` into its MLS wire encoding.
+ fn get_bytes(&self) -> Result<Vec<u8>, MlsError> {
+ Ok(self.mls_encode_to_vec()?)
+ }
+}
+
+/// Per-recipient entry of a `Welcome`. It pairs the HPKE-encrypted `GroupSecrets`
+/// with a reference to the key package they were encrypted to.
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub(crate) struct EncryptedGroupSecrets {
+ /// Reference of the recipient's key package. A joiner uses it to find the
+ /// matching key package generation.
+ pub new_member: KeyPackageRef,
+ /// `GroupSecrets` encrypted to the key package's HPKE init key.
+ pub encrypted_group_secrets: HpkeCiphertext,
+}
+
+/// Welcome message used by new members to join a group.
+#[derive(Clone, Eq, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub(crate) struct Welcome {
+ /// Cipher suite of the group being joined.
+ pub cipher_suite: CipherSuite,
+ /// One encrypted `GroupSecrets` entry per recipient key package.
+ pub secrets: Vec<EncryptedGroupSecrets>,
+ /// `GroupInfo`, encrypted under the welcome secret derived from the joiner secret.
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub encrypted_group_info: Vec<u8>,
+}
+
+impl Debug for Welcome {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ // The encrypted group info is rendered through `pretty_bytes` instead of
+ // the default `Vec<u8>` formatting.
+ let group_info_bytes = mls_rs_core::debug::pretty_bytes(&self.encrypted_group_info);
+
+ let mut out = f.debug_struct("Welcome");
+ out.field("cipher_suite", &self.cipher_suite);
+ out.field("secrets", &self.secrets);
+ out.field("encrypted_group_info", &group_info_bytes);
+ out.finish()
+ }
+}
+
+#[derive(Clone, Debug)]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[non_exhaustive]
+/// Information provided to new members upon joining a group.
+///
+/// Produced together with the joined group when processing a Welcome message.
+pub struct NewMemberInfo {
+ /// Group info extensions found within the Welcome message used to join
+ /// the group.
+ pub group_info_extensions: ExtensionList,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl NewMemberInfo {
+ /// Wraps the group info extensions of a Welcome message and runs
+ /// `ungrease` on the result before returning it.
+ pub(crate) fn new(group_info_extensions: ExtensionList) -> Self {
+ let mut info = Self {
+ group_info_extensions,
+ };
+ // NOTE(review): `ungrease` is defined elsewhere; presumably removes GREASE values.
+ info.ungrease();
+ info
+ }
+
+ /// Group info extensions found within the Welcome message used to join
+ /// the group.
+ #[cfg(feature = "ffi")]
+ pub fn group_info_extensions(&self) -> &ExtensionList {
+ let Self {
+ group_info_extensions,
+ } = self;
+ group_info_extensions
+ }
+}
+
+/// An MLS end-to-end encrypted group.
+///
+/// # Group Evolution
+///
+/// MLS Groups are evolved via a propose-then-commit system. Each group state
+/// produced by a commit is called an epoch and can produce and consume
+/// application, proposal, and commit messages. A [commit](Group::commit) is used
+/// to advance to the next epoch by applying existing proposals sent in
+/// the current epoch by-reference along with an optional set of proposals
+/// that are included by-value using a [`CommitBuilder`].
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))]
+#[derive(Clone)]
+pub struct Group<C>
+where
+ C: ClientConfig,
+{
+ // Client configuration; supplies the crypto, identity and storage providers.
+ config: C,
+ // Provider for this group's cipher suite.
+ cipher_suite_provider: <C::CryptoProvider as CryptoProvider>::CipherSuiteProvider,
+ // Built from `config.group_state_storage()` and `config.key_package_repo()`.
+ state_repo: GroupStateRepository<C::GroupStateStorage, C::KeyPackageRepository>,
+ // Public group state: context, public ratchet tree, interim transcript hash,
+ // confirmation tag and the proposal cache.
+ pub(crate) state: GroupState,
+ // Secrets produced by the key schedule for the current epoch.
+ epoch_secrets: EpochSecrets,
+ // Local member's private tree state (own leaf index and path secret keys).
+ private_tree: TreeKemPrivate,
+ key_schedule: KeySchedule,
+ // Key material for update proposals sent by this member, keyed by the new
+ // leaf's HPKE public key: (leaf secret key, optional new signer).
+ #[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+ pending_updates: HashMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>, // Map of leaf node hpke public key to (secret key, optional signer)
+ // Same data as a list when `HashMap` is unavailable (no `std`).
+ #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
+ pending_updates: Vec<(HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>))>,
+ // NOTE(review): presumably a commit generated locally and not yet applied — confirm.
+ pending_commit: Option<CommitGeneration>,
+ // NOTE(review): presumably a PSK carried over from a prior group (resumption) — confirm.
+ #[cfg(feature = "psk")]
+ previous_psk: Option<PskSecretInput>,
+ // Test-only hooks for altering generated commits.
+ #[cfg(test)]
+ pub(crate) commit_modifiers: CommitModifiers,
+ // Key used to sign messages sent by the local member.
+ pub(crate) signer: SignatureSecretKey,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl<C> Group<C>
+where
+ C: ClientConfig + Clone,
+{
+ /// Creates a new group whose ratchet tree is derived from a single, freshly
+ /// generated leaf node for `signing_identity`.
+ ///
+ /// The leaf node is validated against `group_context_extensions` before use.
+ /// If `group_id` is `None`, a random ID of `kdf_extract_size()` bytes is
+ /// generated. Epoch secrets come from a random epoch secret. The initial
+ /// confirmation tag and interim transcript hash are computed over an empty
+ /// confirmed transcript hash.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn new(
+ config: C,
+ group_id: Option<Vec<u8>>,
+ cipher_suite: CipherSuite,
+ protocol_version: ProtocolVersion,
+ signing_identity: SigningIdentity,
+ group_context_extensions: ExtensionList,
+ signer: SignatureSecretKey,
+ ) -> Result<Self, MlsError> {
+ let cipher_suite_provider = cipher_suite_provider(config.crypto_provider(), cipher_suite)?;
+
+ let (leaf_node, leaf_node_secret) = LeafNode::generate(
+ &cipher_suite_provider,
+ config.leaf_properties(),
+ signing_identity,
+ &signer,
+ config.lifetime(),
+ )
+ .await?;
+
+ let identity_provider = config.identity_provider();
+
+ let leaf_node_validator = LeafNodeValidator::new(
+ &cipher_suite_provider,
+ &identity_provider,
+ Some(&group_context_extensions),
+ );
+
+ // Validate our own leaf in the `Add` context before building the tree from it.
+ leaf_node_validator
+ .check_if_valid(&leaf_node, ValidationContext::Add(None))
+ .await?;
+
+ let (mut public_tree, private_tree) = TreeKemPublic::derive(
+ leaf_node,
+ leaf_node_secret,
+ &config.identity_provider(),
+ &group_context_extensions,
+ )
+ .await?;
+
+ let tree_hash = public_tree.tree_hash(&cipher_suite_provider).await?;
+
+ // Use the caller's group ID, or generate a random one of KDF-extract length.
+ let group_id = group_id.map(Ok).unwrap_or_else(|| {
+ cipher_suite_provider
+ .random_bytes_vec(cipher_suite_provider.kdf_extract_size())
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ })?;
+
+ let context = GroupContext::new_group(
+ protocol_version,
+ cipher_suite,
+ group_id,
+ tree_hash,
+ group_context_extensions,
+ );
+
+ // `None`: no key package was consumed to create this group.
+ let state_repo = GroupStateRepository::new(
+ #[cfg(feature = "prior_epoch")]
+ context.group_id.clone(),
+ config.group_state_storage(),
+ config.key_package_repo(),
+ None,
+ )?;
+
+ let key_schedule_result = KeySchedule::from_random_epoch_secret(
+ &cipher_suite_provider,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ public_tree.total_leaf_count(),
+ )
+ .await?;
+
+ // The first epoch has an empty confirmed transcript hash.
+ let confirmation_tag = ConfirmationTag::create(
+ &key_schedule_result.confirmation_key,
+ &vec![].into(),
+ &cipher_suite_provider,
+ )
+ .await?;
+
+ let interim_hash = InterimTranscriptHash::create(
+ &cipher_suite_provider,
+ &vec![].into(),
+ &confirmation_tag,
+ )
+ .await?;
+
+ Ok(Self {
+ config,
+ state: GroupState::new(context, public_tree, interim_hash, confirmation_tag),
+ private_tree,
+ key_schedule: key_schedule_result.key_schedule,
+ #[cfg(feature = "by_ref_proposal")]
+ pending_updates: Default::default(),
+ pending_commit: None,
+ #[cfg(test)]
+ commit_modifiers: Default::default(),
+ epoch_secrets: key_schedule_result.epoch_secrets,
+ state_repo,
+ cipher_suite_provider,
+ #[cfg(feature = "psk")]
+ previous_psk: None,
+ signer,
+ })
+ }
+
+ /// Joins a group from a Welcome message without any extra PSK input.
+ ///
+ /// Returns the joined group and the `NewMemberInfo` extracted from the
+ /// Welcome's group info.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn join(
+ welcome: &MlsMessage,
+ tree_data: Option<ExportedTree<'_>>,
+ config: C,
+ signer: SignatureSecretKey,
+ ) -> Result<(Self, NewMemberInfo), MlsError> {
+ // A plain join never supplies an additional PSK.
+ #[cfg(feature = "psk")]
+ let additional_psk = None;
+
+ let (group, new_member_info) = Self::from_welcome_message(
+ welcome,
+ tree_data,
+ config,
+ signer,
+ #[cfg(feature = "psk")]
+ additional_psk,
+ )
+ .await?;
+
+ Ok((group, new_member_info))
+ }
+
+ /// Processes a Welcome `MlsMessage` and builds the joined group.
+ ///
+ /// Steps:
+ /// 1. Check that the message's protocol version is supported.
+ /// 2. Find the local key package generation referenced by the Welcome's
+ /// secrets, and check its version.
+ /// 3. Decrypt the `GroupSecrets`.
+ /// 4. Compute the PSK secret.
+ /// 5. Derive the welcome secret and decrypt and validate the `GroupInfo`.
+ /// 6. Locate our own leaf and derive the key schedule.
+ /// 7. Verify the confirmation tag, then hand off to `join_with`.
+ ///
+ /// With the `psk` feature, a provided `additional_psk` replaces normal PSK
+ /// resolution. The first PSK ID in `GroupSecrets` must then be a resumption
+ /// PSK whose usage is not `Application`, and its nonce is copied into
+ /// `additional_psk`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn from_welcome_message(
+ welcome: &MlsMessage,
+ tree_data: Option<ExportedTree<'_>>,
+ config: C,
+ signer: SignatureSecretKey,
+ #[cfg(feature = "psk")] additional_psk: Option<PskSecretInput>,
+ ) -> Result<(Self, NewMemberInfo), MlsError> {
+ let protocol_version = welcome.version;
+
+ if !config.version_supported(protocol_version) {
+ return Err(MlsError::UnsupportedProtocolVersion(protocol_version));
+ }
+
+ let MlsMessagePayload::Welcome(welcome) = &welcome.payload else {
+ return Err(MlsError::UnexpectedMessageType);
+ };
+
+ let cipher_suite_provider =
+ cipher_suite_provider(config.crypto_provider(), welcome.cipher_suite)?;
+
+ // Find which of the Welcome's encrypted secrets targets one of our stored key packages.
+ let (encrypted_group_secrets, key_package_generation) =
+ find_key_package_generation(&config.key_package_repo(), &welcome.secrets).await?;
+
+ let key_package_version = key_package_generation.key_package.version;
+
+ if key_package_version != protocol_version {
+ return Err(MlsError::ProtocolVersionMismatch);
+ }
+
+ // Decrypt the encrypted_group_secrets using HPKE with the algorithms indicated by the
+ // cipher suite and the HPKE init private key of the matching key package generation.
+ // If a PreSharedKeyID is part of the GroupSecrets and the client is not in possession
+ // of the corresponding PSK, the PSK step below returns an error.
+ let group_secrets = GroupSecrets::decrypt(
+ &cipher_suite_provider,
+ &key_package_generation.init_secret_key,
+ &key_package_generation.key_package.hpke_init_key,
+ &welcome.encrypted_group_info,
+ &encrypted_group_secrets.encrypted_group_secrets,
+ )
+ .await?;
+
+ // NOTE(review): on the `additional_psk` path only the first PSK ID in
+ // `group_secrets.psks` is consulted; any further IDs are ignored.
+ #[cfg(feature = "psk")]
+ let psk_secret = if let Some(psk) = additional_psk {
+ let psk_id = group_secrets
+ .psks
+ .first()
+ .ok_or(MlsError::UnexpectedPskId)?;
+
+ match &psk_id.key_id {
+ JustPreSharedKeyID::Resumption(r) if r.usage != ResumptionPSKUsage::Application => {
+ Ok(())
+ }
+ _ => Err(MlsError::UnexpectedPskId),
+ }?;
+
+ let mut psk = psk;
+ psk.id.psk_nonce = psk_id.psk_nonce.clone();
+ PskSecret::calculate(&[psk], &cipher_suite_provider).await?
+ } else {
+ // Resolve PSKs from the configured secret store only. There is no
+ // group context or epoch history before joining.
+ PskResolver::<
+ <C as ClientConfig>::GroupStateStorage,
+ <C as ClientConfig>::KeyPackageRepository,
+ <C as ClientConfig>::PskStore,
+ > {
+ group_context: None,
+ current_epoch: None,
+ prior_epochs: None,
+ psk_store: &config.secret_store(),
+ }
+ .resolve_to_secret(&group_secrets.psks, &cipher_suite_provider)
+ .await?
+ };
+
+ #[cfg(not(feature = "psk"))]
+ let psk_secret = PskSecret::new(&cipher_suite_provider);
+
+ // From the joiner_secret in the decrypted GroupSecrets object and the PSKs specified in
+ // the GroupSecrets, derive the welcome_secret and using that the welcome_key and
+ // welcome_nonce.
+ let welcome_secret = WelcomeSecret::from_joiner_secret(
+ &cipher_suite_provider,
+ &group_secrets.joiner_secret,
+ &psk_secret,
+ )
+ .await?;
+
+ // Use the key and nonce to decrypt the encrypted_group_info field.
+ let decrypted_group_info = welcome_secret
+ .decrypt(&welcome.encrypted_group_info)
+ .await?;
+
+ let group_info = GroupInfo::mls_decode(&mut &**decrypted_group_info)?;
+
+ // Validate the GroupInfo and obtain the public tree, taken from
+ // `tree_data` or from the group info.
+ let public_tree = validate_group_info_joiner(
+ protocol_version,
+ &group_info,
+ tree_data,
+ &config.identity_provider(),
+ &cipher_suite_provider,
+ )
+ .await?;
+
+ // Identify a leaf in the tree array (any even-numbered node) whose leaf_node is identical
+ // to the leaf_node field of the KeyPackage. If no such field exists, return an error. Let
+ // index represent the index of this node among the leaves in the tree, namely the index of
+ // the node in the tree array divided by two.
+ let self_index = public_tree
+ .find_leaf_node(&key_package_generation.key_package.leaf_node)
+ .ok_or(MlsError::WelcomeKeyPackageNotFound)?;
+
+ let used_key_package_ref = key_package_generation.reference;
+
+ let mut private_tree =
+ TreeKemPrivate::new_self_leaf(self_index, key_package_generation.leaf_node_secret_key);
+
+ // If the path_secret value is set in the GroupSecrets object
+ if let Some(path_secret) = group_secrets.path_secret {
+ private_tree
+ .update_secrets(
+ &cipher_suite_provider,
+ group_info.signer,
+ path_secret,
+ &public_tree,
+ )
+ .await?;
+ }
+
+ // Use the joiner_secret from the GroupSecrets object to generate the epoch secret and
+ // other derived secrets for the current epoch.
+ let key_schedule_result = KeySchedule::from_joiner(
+ &cipher_suite_provider,
+ &group_secrets.joiner_secret,
+ &group_info.group_context,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ public_tree.total_leaf_count(),
+ &psk_secret,
+ )
+ .await?;
+
+ // Verify the confirmation tag in the GroupInfo using the derived confirmation key and the
+ // confirmed_transcript_hash from the GroupInfo.
+ if !group_info
+ .confirmation_tag
+ .matches(
+ &key_schedule_result.confirmation_key,
+ &group_info.group_context.confirmed_transcript_hash,
+ &cipher_suite_provider,
+ )
+ .await?
+ {
+ return Err(MlsError::InvalidConfirmationTag);
+ }
+
+ Self::join_with(
+ config,
+ group_info,
+ public_tree,
+ key_schedule_result.key_schedule,
+ key_schedule_result.epoch_secrets,
+ private_tree,
+ Some(used_key_package_ref),
+ signer,
+ )
+ .await
+ }
+
+ /// Assembles a `Group` from already-validated join material.
+ ///
+ /// It resolves a cipher suite provider for the group info's cipher suite
+ /// (`MlsError::UnsupportedCipherSuite` if the crypto provider lacks it). It then
+ /// computes the interim transcript hash and builds the state repository,
+ /// passing along `used_key_package_ref`. Returns the group together with a
+ /// `NewMemberInfo` built from the group info extensions.
+ #[allow(clippy::too_many_arguments)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn join_with(
+ config: C,
+ group_info: GroupInfo,
+ public_tree: TreeKemPublic,
+ key_schedule: KeySchedule,
+ epoch_secrets: EpochSecrets,
+ private_tree: TreeKemPrivate,
+ used_key_package_ref: Option<KeyPackageRef>,
+ signer: SignatureSecretKey,
+ ) -> Result<(Self, NewMemberInfo), MlsError> {
+ let cs = group_info.group_context.cipher_suite;
+
+ // `cs` is shadowed: cipher suite ID -> cipher suite provider.
+ let cs = config
+ .crypto_provider()
+ .cipher_suite_provider(cs)
+ .ok_or(MlsError::UnsupportedCipherSuite(cs))?;
+
+ // Use the confirmed transcript hash and confirmation tag to compute the interim transcript
+ // hash in the new state.
+ let interim_transcript_hash = InterimTranscriptHash::create(
+ &cs,
+ &group_info.group_context.confirmed_transcript_hash,
+ &group_info.confirmation_tag,
+ )
+ .await?;
+
+ let state_repo = GroupStateRepository::new(
+ #[cfg(feature = "prior_epoch")]
+ group_info.group_context.group_id.clone(),
+ config.group_state_storage(),
+ config.key_package_repo(),
+ used_key_package_ref,
+ )?;
+
+ let group = Group {
+ config,
+ state: GroupState::new(
+ group_info.group_context,
+ public_tree,
+ interim_transcript_hash,
+ group_info.confirmation_tag,
+ ),
+ private_tree,
+ key_schedule,
+ #[cfg(feature = "by_ref_proposal")]
+ pending_updates: Default::default(),
+ pending_commit: None,
+ #[cfg(test)]
+ commit_modifiers: Default::default(),
+ epoch_secrets,
+ state_repo,
+ cipher_suite_provider: cs,
+ #[cfg(feature = "psk")]
+ previous_psk: None,
+ signer,
+ };
+
+ Ok((group, NewMemberInfo::new(group_info.extensions)))
+ }
+
+ /// Public ratchet tree of the current epoch.
+ #[inline(always)]
+ pub(crate) fn current_epoch_tree(&self) -> &TreeKemPublic {
+ let state = &self.state;
+ &state.public_tree
+ }
+
+ /// The group's current epoch number, taken from the group context. It is
+ /// incremented each time a [`Group::commit`] message is processed.
+ #[inline(always)]
+ pub fn current_epoch(&self) -> u64 {
+ let context = self.context();
+ context.epoch
+ }
+
+ /// Leaf index of the local group instance within the group's state.
+ ///
+ /// It matches the indexes used in content descriptions within
+ /// [`ReceivedMessage`].
+ #[inline(always)]
+ pub fn current_member_index(&self) -> u32 {
+ let LeafIndex(index) = self.private_tree.self_index;
+ index
+ }
+
+ /// Leaf node of the local member in the current epoch's tree.
+ fn current_user_leaf_node(&self) -> Result<&LeafNode, MlsError> {
+ let own_index = self.private_tree.self_index;
+ self.current_epoch_tree().get_leaf_node(own_index)
+ }
+
+ /// Signing identity currently in use by the local group instance.
+ #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn current_member_signing_identity(&self) -> Result<&SigningIdentity, MlsError> {
+ let leaf_node = self.current_user_leaf_node()?;
+ Ok(&leaf_node.signing_identity)
+ }
+
+ /// Member at a specific leaf index in the current group state.
+ ///
+ /// These indexes correspond to indexes in content descriptions within
+ /// [`ReceivedMessage`]. Returns `None` when no leaf node can be read at `index`.
+ pub fn member_at_index(&self, index: u32) -> Option<Member> {
+ let leaf_index = LeafIndex(index);
+
+ match self.current_epoch_tree().get_leaf_node(leaf_index) {
+ Ok(leaf_node) => Some(member_from_leaf_node(leaf_node, leaf_index)),
+ Err(_) => None,
+ }
+ }
+
+ /// Signs `proposal` as coming from the local member and caches it in
+ /// `state.proposals` under its `ProposalRef`. Returns the proposal formatted
+ /// for the wire.
+ ///
+ /// With `private_message`, the wire format comes from
+ /// `encryption_options()?.control_wire_format(sender)`. Otherwise the proposal
+ /// is always a `PublicMessage`.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn proposal_message(
+ &mut self,
+ proposal: Proposal,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let sender = Sender::Member(*self.private_tree.self_index);
+
+ let auth_content = AuthenticatedContent::new_signed(
+ &self.cipher_suite_provider,
+ self.context(),
+ sender,
+ Content::Proposal(alloc::boxed::Box::new(proposal.clone())),
+ &self.signer,
+ #[cfg(feature = "private_message")]
+ self.encryption_options()?.control_wire_format(sender),
+ #[cfg(not(feature = "private_message"))]
+ WireFormat::PublicMessage,
+ authenticated_data,
+ )
+ .await?;
+
+ let proposal_ref =
+ ProposalRef::from_content(&self.cipher_suite_provider, &auth_content).await?;
+
+ // Remember our own proposal so it can be committed by reference later.
+ self.state
+ .proposals
+ .insert(proposal_ref, proposal, auth_content.content.sender);
+
+ self.format_for_wire(auth_content).await
+ }
+
+ /// Unique identifier for this group, as recorded in the group context.
+ pub fn group_id(&self) -> &[u8] {
+ let context = self.context();
+ &context.group_id
+ }
+
+ /// Computes the local private tree that matches `provisional_state`.
+ ///
+ /// Starting from a clone of the current private tree, it:
+ /// - resizes `secret_keys` to the direct copath length + 1;
+ /// - clears secret keys for nodes that are blank in the provisional public tree;
+ /// - (`by_ref_proposal` only) if an applied Update proposal was sent by the
+ /// local member, replaces the leaf secret with the one stored in
+ /// `pending_updates` and returns the stored signer, if any.
+ ///
+ /// # Errors
+ /// `MlsError::UpdateErrorNoSecretKey` if our own update has no pending secret key.
+ fn provisional_private_tree(
+ &self,
+ provisional_state: &ProvisionalState,
+ ) -> Result<(TreeKemPrivate, Option<SignatureSecretKey>), MlsError> {
+ let mut provisional_private_tree = self.private_tree.clone();
+ let self_index = provisional_private_tree.self_index;
+
+ // Remove secret keys for blanked nodes
+ let path = provisional_state
+ .public_tree
+ .nodes
+ .direct_copath(self_index);
+
+ // Slot 0 is skipped by the loop below; slots 1.. line up with `path` entries.
+ provisional_private_tree
+ .secret_keys
+ .resize(path.len() + 1, None);
+
+ for (i, n) in path.iter().enumerate() {
+ if provisional_state.public_tree.nodes.is_blank(n.path)? {
+ provisional_private_tree.secret_keys[i + 1] = None;
+ }
+ }
+
+ // Apply own update
+ // `new_signer` is immutable unless `by_ref_proposal` rebinds it as mutable.
+ let new_signer = None;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let mut new_signer = new_signer;
+
+ #[cfg(feature = "by_ref_proposal")]
+ for p in &provisional_state.applied_proposals.updates {
+ if p.sender == Sender::Member(*self_index) {
+ let leaf_pk = &p.proposal.leaf_node.public_key;
+
+ // Update the leaf in the private tree if this is our update
+ #[cfg(feature = "std")]
+ let new_leaf_sk_and_signer = self.pending_updates.get(leaf_pk);
+
+ #[cfg(not(feature = "std"))]
+ let new_leaf_sk_and_signer = self
+ .pending_updates
+ .iter()
+ .find_map(|(pk, sk)| (pk == leaf_pk).then_some(sk));
+
+ let new_leaf_sk = new_leaf_sk_and_signer.map(|(sk, _)| sk.clone());
+ new_signer = new_leaf_sk_and_signer.and_then(|(_, sk)| sk.clone());
+
+ provisional_private_tree
+ .update_leaf(new_leaf_sk.ok_or(MlsError::UpdateErrorNoSecretKey)?);
+
+ // Only the first update from the local member is applied.
+ break;
+ }
+ }
+
+ Ok((provisional_private_tree, new_signer))
+ }
+
+ /// Builds the `EncryptedGroupSecrets` entry of a Welcome message for one
+ /// new member at `leaf_index`.
+ ///
+ /// If `path_secrets` is given, the joiner receives the path secret stored at
+ /// position `leaf_lca_level(self, leaf_index) - 1`. The `GroupSecrets` are then
+ /// HPKE-encrypted to the key package's init key. `encrypted_group_info` is also
+ /// passed to `encrypt`, presumably as HPKE context (TODO confirm).
+ ///
+ /// # Errors
+ /// `MlsError::InvalidTreeKemPrivateKey` if the required path secret is
+ /// missing, including the degenerate case of a common-ancestor level of 0.
+ /// Crypto and encoding errors are propagated.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn encrypt_group_secrets(
+ &self,
+ key_package: &KeyPackage,
+ leaf_index: LeafIndex,
+ joiner_secret: &JoinerSecret,
+ path_secrets: Option<&Vec<Option<PathSecret>>>,
+ #[cfg(feature = "psk")] psks: Vec<PreSharedKeyID>,
+ encrypted_group_info: &[u8],
+ ) -> Result<EncryptedGroupSecrets, MlsError> {
+ let path_secret = path_secrets
+ .map(|secrets| {
+ // Path secrets are indexed by (LCA level - 1). Checked arithmetic
+ // turns a level of 0 into an error instead of an integer
+ // underflow, which would panic in debug builds.
+ let lca_level =
+ tree_math::leaf_lca_level(*self.private_tree.self_index, *leaf_index) as usize;
+
+ lca_level
+ .checked_sub(1)
+ .and_then(|index| secrets.get(index))
+ .cloned()
+ .flatten()
+ .ok_or(MlsError::InvalidTreeKemPrivateKey)
+ })
+ .transpose()?;
+
+ #[cfg(not(feature = "psk"))]
+ let psks = Vec::new();
+
+ let group_secrets = GroupSecrets {
+ joiner_secret: joiner_secret.clone(),
+ path_secret,
+ psks,
+ };
+
+ let encrypted_group_secrets = group_secrets
+ .encrypt(
+ &self.cipher_suite_provider,
+ &key_package.hpke_init_key,
+ encrypted_group_info,
+ )
+ .await?;
+
+ Ok(EncryptedGroupSecrets {
+ new_member: key_package
+ .to_reference(&self.cipher_suite_provider)
+ .await?,
+ encrypted_group_secrets,
+ })
+ }
+
+    /// Create a proposal message that adds a new member to the group.
+    ///
+    /// `key_package` must be an [`MlsMessage`] that carries a key package.
+    /// `authenticated_data` travels unencrypted next to the proposal contents.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_add(
+        &mut self,
+        key_package: MlsMessage,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let add = self.add_proposal(key_package)?;
+
+        self.proposal_message(add, authenticated_data).await
+    }
+
+    /// Build an Add proposal from a message that wraps a key package.
+    ///
+    /// Fails with `MlsError::UnexpectedMessageType` when `key_package` does not
+    /// contain a key package.
+    fn add_proposal(&self, key_package: MlsMessage) -> Result<Proposal, MlsError> {
+        let key_package = key_package
+            .into_key_package()
+            .ok_or(MlsError::UnexpectedMessageType)?;
+
+        Ok(Proposal::Add(alloc::boxed::Box::new(AddProposal { key_package })))
+    }
+
+    /// Create a proposal message that updates your own public keys.
+    ///
+    /// An update contributes extra forward secrecy and post-compromise
+    /// security to the group without the cost of computing a full
+    /// [`Group::commit`]. No new signer or signing identity is supplied.
+    ///
+    /// `authenticated_data` travels unencrypted next to the proposal contents.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_update(
+        &mut self,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let update = self.update_proposal(None, None).await?;
+
+        self.proposal_message(update, authenticated_data).await
+    }
+
+    /// Create a proposal message that updates your own public keys together
+    /// with your credential.
+    ///
+    /// An update contributes extra forward secrecy and post-compromise
+    /// security to the group without the cost of computing a full
+    /// [`Group::commit`].
+    ///
+    /// By default the group accepts identity updates as long as the new
+    /// identity is considered
+    /// [valid](crate::IdentityProvider::validate_member)
+    /// by, and matches the output of the
+    /// [identity](crate::IdentityProvider)
+    /// function of, the current
+    /// [`IdentityProvider`](crate::IdentityProvider).
+    ///
+    /// `authenticated_data` travels unencrypted next to the proposal contents.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_update_with_identity(
+        &mut self,
+        signer: SignatureSecretKey,
+        signing_identity: SigningIdentity,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let new_signer = Some(signer);
+        let new_identity = Some(signing_identity);
+
+        let update = self.update_proposal(new_signer, new_identity).await?;
+
+        self.proposal_message(update, authenticated_data).await
+    }
+
+    /// Build an Update proposal that refreshes our own leaf node.
+    ///
+    /// A copy of the current leaf is updated via `LeafNode::update`, using
+    /// `signer` when given and the group's current signer otherwise. The
+    /// returned leaf secret key is stored in `pending_updates` together with
+    /// `signer`, keyed by the new leaf public key, so that it can be looked up
+    /// if this update is later applied.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn update_proposal(
+        &mut self,
+        signer: Option<SignatureSecretKey>,
+        signing_identity: Option<SigningIdentity>,
+    ) -> Result<Proposal, MlsError> {
+        // Work on a copy; the tree itself is only changed once the update is committed.
+        let mut leaf_node = self.current_user_leaf_node()?.clone();
+
+        // Prefer the replacement signer when the caller supplied one.
+        let signing_key = signer.as_ref().unwrap_or(&self.signer);
+
+        let leaf_secret = leaf_node
+            .update(
+                &self.cipher_suite_provider,
+                self.group_id(),
+                self.current_member_index(),
+                self.config.leaf_properties(),
+                signing_identity,
+                signing_key,
+            )
+            .await?;
+
+        let public_key = leaf_node.public_key.clone();
+
+        #[cfg(feature = "std")]
+        self.pending_updates
+            .insert(public_key, (leaf_secret, signer));
+
+        #[cfg(not(feature = "std"))]
+        self.pending_updates
+            .push((public_key, (leaf_secret, signer)));
+
+        Ok(Proposal::Update(UpdateProposal { leaf_node }))
+    }
+
+    /// Create a proposal message that removes an existing member from the
+    /// group.
+    ///
+    /// `index` is the leaf index of the member to remove.
+    /// `authenticated_data` travels unencrypted next to the proposal contents.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_remove(
+        &mut self,
+        index: u32,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let removal = self.remove_proposal(index)?;
+
+        self.proposal_message(removal, authenticated_data).await
+    }
+
+    /// Build a Remove proposal for the member at leaf `index`.
+    ///
+    /// Fails with whatever error `get_leaf_node` reports when the index does
+    /// not refer to a leaf of the current epoch's tree.
+    fn remove_proposal(&self, index: u32) -> Result<Proposal, MlsError> {
+        let to_remove = LeafIndex(index);
+
+        // Only the check matters here; the leaf node itself is not needed.
+        self.current_epoch_tree().get_leaf_node(to_remove)?;
+
+        Ok(Proposal::Remove(RemoveProposal { to_remove }))
+    }
+
+    /// Create a proposal message that adds an external pre shared key to the group.
+    ///
+    /// Every member must have the PSK for
+    /// [`ExternalPskId`](mls_rs_core::psk::ExternalPskId) installed in the
+    /// [`PreSharedKeyStorage`](mls_rs_core::psk::PreSharedKeyStorage)
+    /// used by this group by the time it processes a [commit](Group::commit)
+    /// containing this proposal.
+    ///
+    /// `authenticated_data` travels unencrypted next to the proposal contents.
+    #[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_external_psk(
+        &mut self,
+        psk: ExternalPskId,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let key_id = JustPreSharedKeyID::External(psk);
+        let proposal = self.psk_proposal(key_id)?;
+
+        self.proposal_message(proposal, authenticated_data).await
+    }
+
+    /// Wrap `key_id` into a PSK proposal, creating the full `PreSharedKeyID`
+    /// with this group's cipher suite provider.
+    #[cfg(feature = "psk")]
+    fn psk_proposal(&self, key_id: JustPreSharedKeyID) -> Result<Proposal, MlsError> {
+        let psk = PreSharedKeyID::new(key_id, &self.cipher_suite_provider)?;
+
+        Ok(Proposal::Psk(PreSharedKeyProposal { psk }))
+    }
+
+    /// Create a proposal message that adds a pre shared key from a previous
+    /// epoch of this group to the current group state.
+    ///
+    /// Every member needs the secret state of `psk_epoch`. Members who
+    /// joined after `psk_epoch` therefore cannot process a commit that
+    /// contains this proposal.
+    ///
+    /// `authenticated_data` travels unencrypted next to the proposal contents.
+    #[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_resumption_psk(
+        &mut self,
+        psk_epoch: u64,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let psk_group_id = PskGroupId(self.group_id().to_vec());
+
+        let resumption = ResumptionPsk {
+            psk_epoch,
+            usage: ResumptionPSKUsage::Application,
+            psk_group_id,
+        };
+
+        let proposal = self.psk_proposal(JustPreSharedKeyID::Resumption(resumption))?;
+
+        self.proposal_message(proposal, authenticated_data).await
+    }
+
+    /// Create a proposal message asking for this group to be reinitialized.
+    ///
+    /// After a [`ReInitProposal`](proposal::ReInitProposal) has been sent,
+    /// another member can finish reinitializing the group by calling
+    /// [`Group::get_reinit_client`]. If `group_id` is `None`, a random group id
+    /// is generated for the new group.
+    ///
+    /// `authenticated_data` travels unencrypted next to the proposal contents.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_reinit(
+        &mut self,
+        group_id: Option<Vec<u8>>,
+        version: ProtocolVersion,
+        cipher_suite: CipherSuite,
+        extensions: ExtensionList,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let reinit = self.reinit_proposal(group_id, version, cipher_suite, extensions)?;
+
+        self.proposal_message(reinit, authenticated_data).await
+    }
+
+    /// Build a ReInit proposal.
+    ///
+    /// When `group_id` is `None`, a random identifier of `kdf_extract_size()`
+    /// bytes is drawn from the cipher suite provider; a failure there is
+    /// reported as `MlsError::CryptoProviderError`.
+    fn reinit_proposal(
+        &self,
+        group_id: Option<Vec<u8>>,
+        version: ProtocolVersion,
+        cipher_suite: CipherSuite,
+        extensions: ExtensionList,
+    ) -> Result<Proposal, MlsError> {
+        let group_id = match group_id {
+            Some(id) => id,
+            None => {
+                let id_len = self.cipher_suite_provider.kdf_extract_size();
+
+                self.cipher_suite_provider
+                    .random_bytes_vec(id_len)
+                    .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?
+            }
+        };
+
+        Ok(Proposal::ReInit(ReInitProposal {
+            group_id,
+            version,
+            cipher_suite,
+            extensions,
+        }))
+    }
+
+    /// Create a proposal message that sets the extensions stored in the group
+    /// state.
+    ///
+    /// # Warning
+    ///
+    /// The proposal replaces the whole extension list; it is not a diff
+    /// against the extensions currently in use. Any existing extension that
+    /// should survive must be included in `extensions`.
+    ///
+    /// `authenticated_data` travels unencrypted next to the proposal contents.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_group_context_extensions(
+        &mut self,
+        extensions: ExtensionList,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let gce = self.group_context_extensions_proposal(extensions);
+
+        self.proposal_message(gce, authenticated_data).await
+    }
+
+    /// Wrap `extensions` into a GroupContextExtensions proposal; the list is
+    /// used as-is, replacing (not merging with) the current extensions.
+    fn group_context_extensions_proposal(&self, extensions: ExtensionList) -> Proposal {
+        Proposal::GroupContextExtensions(extensions)
+    }
+
+    /// Create a proposal message carrying a custom proposal.
+    ///
+    /// `authenticated_data` travels unencrypted next to the proposal contents.
+    #[cfg(all(feature = "custom_proposal", feature = "by_ref_proposal"))]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_custom(
+        &mut self,
+        proposal: CustomProposal,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let custom = Proposal::Custom(proposal);
+
+        self.proposal_message(custom, authenticated_data).await
+    }
+
+    /// Forget every proposal, sent or received, that is cached for the next commit.
+    #[cfg(feature = "by_ref_proposal")]
+    pub fn clear_proposal_cache(&mut self) {
+        self.group_state_mut().proposals.clear();
+    }
+
+    /// Turn authenticated content into a wire message.
+    ///
+    /// With the `private_message` feature, content whose wire format is
+    /// `PrivateMessage` is encrypted and everything else is framed as a
+    /// `PublicMessage`. Without the feature, only public framing is available.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn format_for_wire(
+        &mut self,
+        content: AuthenticatedContent,
+    ) -> Result<MlsMessage, MlsError> {
+        #[cfg(feature = "private_message")]
+        let payload = if content.wire_format != WireFormat::PrivateMessage {
+            MlsMessagePayload::Plain(self.create_plaintext(content).await?)
+        } else {
+            MlsMessagePayload::Cipher(self.create_ciphertext(content).await?)
+        };
+
+        #[cfg(not(feature = "private_message"))]
+        let payload = MlsMessagePayload::Plain(self.create_plaintext(content).await?);
+
+        let version = self.protocol_version();
+
+        Ok(MlsMessage::new(version, payload))
+    }
+
+    /// Frame authenticated content as a `PublicMessage`.
+    ///
+    /// A membership tag is computed from the key schedule only when the
+    /// sender is a group member; other senders get no tag.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn create_plaintext(
+        &self,
+        auth_content: AuthenticatedContent,
+    ) -> Result<PublicMessage, MlsError> {
+        let membership_tag = match auth_content.content.sender {
+            Sender::Member(_) => Some(
+                self.key_schedule
+                    .get_membership_tag(&auth_content, self.context(), &self.cipher_suite_provider)
+                    .await?,
+            ),
+            _ => None,
+        };
+
+        Ok(PublicMessage {
+            content: auth_content.content,
+            auth: auth_content.auth,
+            membership_tag,
+        })
+    }
+
+    /// Encrypt authenticated content into a `PrivateMessage`, padded according
+    /// to the padding mode returned by the MLS rules' encryption options.
+    #[cfg(feature = "private_message")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn create_ciphertext(
+        &mut self,
+        auth_content: AuthenticatedContent,
+    ) -> Result<PrivateMessage, MlsError> {
+        let padding_mode = self.encryption_options()?.padding_mode;
+        let provider = self.cipher_suite_provider.clone();
+
+        CiphertextProcessor::new(self, provider)
+            .seal(auth_content, padding_mode)
+            .await
+    }
+
+    /// Encrypt an application message using the current group state.
+    ///
+    /// `authenticated_data` will be sent unencrypted along with the contents
+    /// of the application message.
+    ///
+    /// # Errors
+    ///
+    /// With the `by_ref_proposal` feature, returns [`MlsError::CommitRequired`]
+    /// while any proposals are cached for the current epoch.
+    #[cfg(feature = "private_message")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn encrypt_application_message(
+        &mut self,
+        message: &[u8],
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        // A group member that has observed one or more proposals within an epoch MUST
+        // send a Commit message before sending application data
+        #[cfg(feature = "by_ref_proposal")]
+        if !self.state.proposals.is_empty() {
+            return Err(MlsError::CommitRequired);
+        }
+
+        let sender = Sender::Member(*self.private_tree.self_index);
+        let content = Content::Application(message.to_vec().into());
+
+        let auth_content = AuthenticatedContent::new_signed(
+            &self.cipher_suite_provider,
+            self.context(),
+            sender,
+            content,
+            &self.signer,
+            WireFormat::PrivateMessage,
+            authenticated_data,
+        )
+        .await?;
+
+        self.format_for_wire(auth_content).await
+    }
+
+    /// Decrypt an incoming `PrivateMessage` and verify its signature.
+    ///
+    /// A message for the current epoch is opened with this group's own epoch
+    /// secrets and verified against the current ratchet tree. A message for any
+    /// other epoch is, with the `prior_epoch` feature, opened with the stored
+    /// state of that epoch and verified against the signature keys recorded for
+    /// it; without the feature it is rejected.
+    ///
+    /// # Errors
+    ///
+    /// Returns `MlsError::EpochNotFound` when the message's epoch is neither
+    /// the current one nor found in the stored epochs, and propagates
+    /// decryption and signature verification failures.
+    #[cfg(feature = "private_message")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn decrypt_incoming_ciphertext(
+        &mut self,
+        message: &PrivateMessage,
+    ) -> Result<AuthenticatedContent, MlsError> {
+        let epoch_id = message.epoch;
+
+        let auth_content = if epoch_id == self.context().epoch {
+            // Current epoch: the group itself provides the secrets to open the message.
+            let content = CiphertextProcessor::new(self, self.cipher_suite_provider.clone())
+                .open(message)
+                .await?;
+
+            verify_auth_content_signature(
+                &self.cipher_suite_provider,
+                SignaturePublicKeysContainer::RatchetTree(&self.state.public_tree),
+                self.context(),
+                &content,
+                #[cfg(feature = "by_ref_proposal")]
+                &[],
+            )
+            .await?;
+
+            Ok::<_, MlsError>(content)
+        } else {
+            // Any other epoch can only be served from the stored epoch state.
+            #[cfg(feature = "prior_epoch")]
+            {
+                let epoch = self
+                    .state_repo
+                    .get_epoch_mut(epoch_id)
+                    .await?
+                    .ok_or(MlsError::EpochNotFound)?;
+
+                let content = CiphertextProcessor::new(epoch, self.cipher_suite_provider.clone())
+                    .open(message)
+                    .await?;
+
+                // Verify against the signature keys recorded for that epoch, presumably
+                // because the current tree may no longer hold the sender's key.
+                verify_auth_content_signature(
+                    &self.cipher_suite_provider,
+                    SignaturePublicKeysContainer::List(&epoch.signature_public_keys),
+                    &epoch.context,
+                    &content,
+                    #[cfg(feature = "by_ref_proposal")]
+                    &[],
+                )
+                .await?;
+
+                Ok(content)
+            }
+
+            #[cfg(not(feature = "prior_epoch"))]
+            Err(MlsError::EpochNotFound)
+        }?;
+
+        Ok(auth_content)
+    }
+
+    /// Apply a pending commit that was created by [`Group::commit`] or
+    /// [`CommitBuilder::build`].
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MlsError::PendingCommitNotFound`] if no commit is pending, and
+    /// otherwise propagates any error raised while processing the commit.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn apply_pending_commit(&mut self) -> Result<CommitMessageDescription, MlsError> {
+        // Only the commit content is needed here. Cloning the whole pending commit
+        // would also make needless copies of its private tree and commit secret.
+        let content = self
+            .pending_commit
+            .as_ref()
+            .ok_or(MlsError::PendingCommitNotFound)?
+            .content
+            .clone();
+
+        self.process_commit(content, None).await
+    }
+
+    /// Returns true if a commit has been created but not yet applied
+    /// with [`Group::apply_pending_commit`] or cleared with [`Group::clear_pending_commit`]
+    ///
+    /// This only checks whether a pending commit is stored; it does not inspect it.
+    pub fn has_pending_commit(&self) -> bool {
+        self.pending_commit.is_some()
+    }
+
+    /// Clear the currently pending commit.
+    ///
+    /// The stored commit is dropped, so it can no longer be applied with
+    /// [`Group::apply_pending_commit`].
+    ///
+    /// This function will automatically be called in the event that a
+    /// commit message is processed using [`Group::process_incoming_message`]
+    /// before [`Group::apply_pending_commit`] is called.
+    pub fn clear_pending_commit(&mut self) {
+        self.pending_commit = None
+    }
+
+    /// Process an inbound message for this group.
+    ///
+    /// If `message` hashes to the same value as the commit stored as this
+    /// group's pending commit, that pending commit is applied instead of
+    /// running the message through the generic processor.
+    ///
+    /// # Warning
+    ///
+    /// State changes caused by processing `message` are not persisted by the
+    /// [`GroupStateStorage`](crate::GroupStateStorage)
+    /// used by this group until [`Group::write_to_storage`] is called.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    #[inline(never)]
+    pub async fn process_incoming_message(
+        &mut self,
+        message: MlsMessage,
+    ) -> Result<ReceivedMessage, MlsError> {
+        let is_own_pending_commit = match &self.pending_commit {
+            Some(pending) => {
+                CommitHash::compute(&self.cipher_suite_provider, &message).await?
+                    == pending.commit_message_hash
+            }
+            None => false,
+        };
+
+        if is_own_pending_commit {
+            let description = self.apply_pending_commit().await?;
+            return Ok(ReceivedMessage::Commit(description));
+        }
+
+        MessageProcessor::process_incoming_message(
+            self,
+            message,
+            #[cfg(feature = "by_ref_proposal")]
+            true,
+        )
+        .await
+    }
+
+    /// Process an inbound message for this group, providing additional context
+    /// with a message timestamp.
+    ///
+    /// Providing a timestamp is useful when the
+    /// [`IdentityProvider`](crate::IdentityProvider)
+    /// in use by the group can determine validity based on a timestamp.
+    /// For example, this allows for checking X.509 certificate expiration
+    /// at the time when `message` was received by a server rather than when
+    /// a specific client asynchronously received `message`.
+    ///
+    /// As with [`Group::process_incoming_message`], if `message` is the commit
+    /// stored as this group's pending commit, that commit is applied with
+    /// [`Group::apply_pending_commit`] instead; `time` is not used in that case.
+    ///
+    /// # Warning
+    ///
+    /// Changes to the group's state as a result of processing `message` will
+    /// not be persisted by the
+    /// [`GroupStateStorage`](crate::GroupStateStorage)
+    /// in use by this group until [`Group::write_to_storage`] is called.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn process_incoming_message_with_time(
+        &mut self,
+        message: MlsMessage,
+        time: MlsTime,
+    ) -> Result<ReceivedMessage, MlsError> {
+        // Mirror `process_incoming_message`. If our own commit comes back to us (for
+        // example from the delivery service), apply it from the stored pending commit
+        // rather than sending it through the generic processor as an inbound commit.
+        if let Some(pending) = &self.pending_commit {
+            let message_hash = CommitHash::compute(&self.cipher_suite_provider, &message).await?;
+
+            if message_hash == pending.commit_message_hash {
+                let message_description = self.apply_pending_commit().await?;
+
+                return Ok(ReceivedMessage::Commit(message_description));
+            }
+        }
+
+        MessageProcessor::process_incoming_message_with_time(
+            self,
+            message,
+            #[cfg(feature = "by_ref_proposal")]
+            true,
+            Some(time),
+        )
+        .await
+    }
+
+    /// Find a group member by
+    /// [identity](crate::IdentityProvider::identity)
+    ///
+    /// Identity is determined with the
+    /// [`IdentityProvider`](crate::IdentityProvider)
+    /// currently in use by the group.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MlsError::MemberNotFound`] when no leaf matches `identity`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn member_with_identity(&self, identity: &[u8]) -> Result<Member, MlsError> {
+        let public_tree = &self.state.public_tree;
+
+        // With `tree_index` the tree can answer directly; otherwise the lookup
+        // needs the identity provider and the group context extensions.
+        #[cfg(feature = "tree_index")]
+        let found = public_tree.get_leaf_node_with_identity(identity);
+
+        #[cfg(not(feature = "tree_index"))]
+        let found = public_tree
+            .get_leaf_node_with_identity(
+                identity,
+                &self.identity_provider(),
+                &self.state.context.extensions,
+            )
+            .await?;
+
+        let leaf_index = found.ok_or(MlsError::MemberNotFound)?;
+        let leaf_node = public_tree.get_leaf_node(leaf_index)?;
+
+        Ok(member_from_leaf_node(leaf_node, leaf_index))
+    }
+
+    /// Create a group info message that can be used for external proposals and commits.
+    ///
+    /// The returned `GroupInfo` carries the external public key extension and is
+    /// suitable for one external commit in the current epoch. When
+    /// `with_tree_in_extension` is true, it also contains the ratchet tree, and
+    /// therefore all information needed to join the group. Otherwise the ratchet
+    /// tree must be obtained separately, e.g. via
+    /// [`ExternalGroup::export_tree`](crate::external_client::ExternalGroup::export_tree).
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn group_info_message_allowing_ext_commit(
+        &self,
+        with_tree_in_extension: bool,
+    ) -> Result<MlsMessage, MlsError> {
+        let external_pub_ext = self
+            .key_schedule
+            .get_external_key_pair_ext(&self.cipher_suite_provider)
+            .await?;
+
+        let mut extensions = ExtensionList::new();
+        extensions.set_from(external_pub_ext)?;
+
+        self.group_info_message_internal(extensions, with_tree_in_extension)
+            .await
+    }
+
+    /// Create a group info message that can be used for external proposals.
+    ///
+    /// No extra extensions are added, apart from the ratchet tree when
+    /// `with_tree_in_extension` is true.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn group_info_message(
+        &self,
+        with_tree_in_extension: bool,
+    ) -> Result<MlsMessage, MlsError> {
+        let no_extensions = ExtensionList::new();
+
+        self.group_info_message_internal(no_extensions, with_tree_in_extension)
+            .await
+    }
+
+    /// Build and sign a `GroupInfo` message for the current epoch.
+    ///
+    /// `initial_extensions` become the GroupInfo extensions. When
+    /// `with_tree_in_extension` is set, a ratchet tree extension holding a copy
+    /// of the current public tree nodes is added to them. The GroupInfo is then
+    /// greased and signed with this group's signer.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn group_info_message_internal(
+        &self,
+        mut initial_extensions: ExtensionList,
+        with_tree_in_extension: bool,
+    ) -> Result<MlsMessage, MlsError> {
+        if with_tree_in_extension {
+            let tree_data = ExportedTree::new(self.state.public_tree.nodes.clone());
+            initial_extensions.set_from(RatchetTreeExt { tree_data })?;
+        }
+
+        let mut group_info = GroupInfo {
+            group_context: self.context().clone(),
+            extensions: initial_extensions,
+            confirmation_tag: self.state.confirmation_tag.clone(),
+            signer: self.private_tree.self_index,
+            signature: Vec::new(),
+        };
+
+        group_info.grease(self.cipher_suite_provider())?;
+
+        group_info
+            .sign(&self.cipher_suite_provider, &self.signer, &())
+            .await?;
+
+        let payload = MlsMessagePayload::GroupInfo(group_info);
+
+        Ok(MlsMessage::new(self.protocol_version(), payload))
+    }
+
+    /// The current group context, summarizing the group's epoch, cipher suite,
+    /// protocol version, extensions and related information.
+    #[inline(always)]
+    pub fn context(&self) -> &GroupContext {
+        &self.state.context
+    }
+
+    /// Get the
+    /// [epoch_authenticator](https://messaginglayersecurity.rocks/mls-protocol/draft-ietf-mls-protocol.html#name-key-schedule)
+    /// of the current epoch.
+    ///
+    /// This never fails; the `Result` is part of the public signature.
+    pub fn epoch_authenticator(&self) -> Result<Secret, MlsError> {
+        let authenticator: Secret = self.key_schedule.authentication_secret.clone().into();
+
+        Ok(authenticator)
+    }
+
+    /// Derive a `len`-byte secret from the current epoch's key schedule, bound
+    /// to `label` and `context`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn export_secret(
+        &self,
+        label: &[u8],
+        context: &[u8],
+        len: usize,
+    ) -> Result<Secret, MlsError> {
+        let exported = self
+            .key_schedule
+            .export_secret(label, context, len, &self.cipher_suite_provider)
+            .await?;
+
+        Ok(exported.into())
+    }
+
+    /// Export the current epoch's ratchet tree in serialized format.
+    ///
+    /// Use this to hand the current group tree to new members when the
+    /// `ratchet_tree_extension` is not used according to [`MlsRules::commit_options`].
+    /// The result borrows the tree nodes rather than copying them.
+    pub fn export_tree(&self) -> ExportedTree<'_> {
+        let nodes = &self.current_epoch_tree().nodes;
+
+        ExportedTree::new_borrowed(nodes)
+    }
+
+    /// The MLS protocol version used by this group, as recorded in its context.
+    pub fn protocol_version(&self) -> ProtocolVersion {
+        self.state.context.protocol_version
+    }
+
+    /// The cipher suite used by this group, as recorded in its context.
+    pub fn cipher_suite(&self) -> CipherSuite {
+        self.state.context.cipher_suite
+    }
+
+    /// The current roster, i.e. a view of the members in the public tree.
+    pub fn roster(&self) -> Roster<'_> {
+        self.state.public_tree.roster()
+    }
+
+    /// Compare the internal state of two groups.
+    ///
+    /// The groups are considered equal when their group state, key schedule
+    /// and epoch secrets all match. Mostly useful in tests.
+    pub fn equal_group_state(a: &Group<C>, b: &Group<C>) -> bool {
+        let same_state = a.state == b.state;
+
+        same_state && a.key_schedule == b.key_schedule && a.epoch_secrets == b.epoch_secrets
+    }
+
+    /// Resolve the PSK secret and PSK ids to use for the next epoch.
+    ///
+    /// When `previous_psk` is set, it is used on its own and `psks` is ignored.
+    /// Otherwise the ids from the given PSK proposals are resolved against the
+    /// current epoch, the stored prior epochs and the configured secret store.
+    #[cfg(feature = "psk")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn get_psk(
+        &self,
+        psks: &[ProposalInfo<PreSharedKeyProposal>],
+    ) -> Result<(PskSecret, Vec<PreSharedKeyID>), MlsError> {
+        match self.previous_psk.clone() {
+            Some(previous) => {
+                // TODO consider throwing error if psks not empty
+                let ids = vec![previous.id.clone()];
+
+                let secret =
+                    PskSecret::calculate(&[previous], self.cipher_suite_provider()).await?;
+
+                Ok((secret, ids))
+            }
+            None => {
+                let ids: Vec<_> = psks
+                    .iter()
+                    .map(|info| info.proposal.psk.clone())
+                    .collect();
+
+                let secret = PskResolver {
+                    group_context: Some(self.context()),
+                    current_epoch: Some(&self.epoch_secrets),
+                    prior_epochs: Some(&self.state_repo),
+                    psk_store: &self.config.secret_store(),
+                }
+                .resolve_to_secret(&ids, self.cipher_suite_provider())
+                .await?;
+
+                Ok((secret, ids))
+            }
+        }
+    }
+
+    /// Ask the configured MLS rules for the encryption options to use with the
+    /// current roster and group context extensions.
+    ///
+    /// A rules failure is reported as `MlsError::MlsRulesError`.
+    #[cfg(feature = "private_message")]
+    pub(crate) fn encryption_options(&self) -> Result<EncryptionOptions, MlsError> {
+        let rules = self.config.mls_rules();
+
+        rules
+            .encryption_options(&self.roster(), self.group_context().extensions())
+            .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))
+    }
+
+    /// Without the `psk` feature there are no PSKs to resolve; return the value
+    /// produced by `PskSecret::new` for this group's cipher suite.
+    #[cfg(not(feature = "psk"))]
+    fn get_psk(&self) -> PskSecret {
+        PskSecret::new(self.cipher_suite_provider())
+    }
+
+    /// Advance our own application ratchet in the secret tree and return the
+    /// next message key for sending.
+    #[cfg(feature = "secret_tree_access")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    #[inline(never)]
+    pub async fn next_encryption_key(&mut self) -> Result<MessageKey, MlsError> {
+        // The secret tree is addressed by node index, so convert our leaf index first.
+        let own_node = crate::tree_kem::node::NodeIndex::from(self.private_tree.self_index);
+
+        self.epoch_secrets
+            .secret_tree
+            .next_message_key(&self.cipher_suite_provider, own_node, KeyType::Application)
+            .await
+    }
+
+    /// Derive the application message key used by member `sender` for
+    /// `generation` in the current epoch.
+    ///
+    /// `sender` is the sender's leaf index (the same kind of index used by
+    /// `propose_remove`). It is converted to a secret tree node index exactly
+    /// like our own leaf index is in `next_encryption_key`.
+    #[cfg(feature = "secret_tree_access")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn derive_decryption_key(
+        &mut self,
+        sender: u32,
+        generation: u32,
+    ) -> Result<MessageKey, MlsError> {
+        // `NodeIndex::from(u32)` is an identity conversion. Wrapping the value in
+        // `LeafIndex` selects the leaf-to-node conversion, so `sender` addresses that
+        // member's leaf rather than being misread as a raw node index.
+        let sender_node = crate::tree_kem::node::NodeIndex::from(LeafIndex(sender));
+
+        self.epoch_secrets
+            .secret_tree
+            .message_key_generation(
+                &self.cipher_suite_provider,
+                sender_node,
+                KeyType::Application,
+                generation,
+            )
+            .await
+    }
+}
+
+// Gives the ciphertext processor access to this group's own leaf, context and
+// epoch secrets when encrypting or decrypting in the current epoch.
+#[cfg(feature = "private_message")]
+impl<C> GroupStateProvider for Group<C>
+where
+    C: ClientConfig + Clone,
+{
+    // Our leaf index in the current tree.
+    fn self_index(&self) -> LeafIndex {
+        self.private_tree.self_index
+    }
+
+    // The current group context.
+    fn group_context(&self) -> &GroupContext {
+        &self.state.context
+    }
+
+    // Read-only access to the current epoch secrets.
+    fn epoch_secrets(&self) -> &EpochSecrets {
+        &self.epoch_secrets
+    }
+
+    // Mutable access to the current epoch secrets (e.g. to ratchet the secret tree).
+    fn epoch_secrets_mut(&mut self) -> &mut EpochSecrets {
+        &mut self.epoch_secrets
+    }
+}
+
+// Member-side implementation of the generic message processing logic. A `Group`
+// holds its own private tree, signer and key schedule, so it can decrypt private
+// messages, decapsulate update paths and advance its state to the next epoch.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
+#[cfg_attr(
+    all(not(target_arch = "wasm32"), mls_build_async),
+    maybe_async::must_be_async
+)]
+impl<C> MessageProcessor for Group<C>
+where
+    C: ClientConfig + Clone,
+{
+    type MlsRules = C::MlsRules;
+    type IdentityProvider = C::IdentityProvider;
+    type PreSharedKeyStorage = C::PskStore;
+    type OutputType = ReceivedMessage;
+    type CipherSuiteProvider = <C::CryptoProvider as CryptoProvider>::CipherSuiteProvider;
+
+    /// Always `Some` for a member: our own leaf index.
+    #[cfg(feature = "private_message")]
+    fn self_index(&self) -> Option<LeafIndex> {
+        Some(self.private_tree.self_index)
+    }
+
+    /// Decrypt and authenticate a `PrivateMessage` with
+    /// `decrypt_incoming_ciphertext`, returning its content for further processing.
+    #[cfg(feature = "private_message")]
+    async fn process_ciphertext(
+        &mut self,
+        cipher_text: &PrivateMessage,
+    ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+        self.decrypt_incoming_ciphertext(cipher_text)
+            .await
+            .map(EventOrContent::Content)
+    }
+
+    /// Authenticate a `PublicMessage` against the current group state, passing
+    /// our key schedule and leaf index to the shared verification helper.
+    async fn verify_plaintext_authentication(
+        &self,
+        message: PublicMessage,
+    ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+        let auth_content = verify_plaintext_authentication(
+            &self.cipher_suite_provider,
+            message,
+            Some(&self.key_schedule),
+            Some(self.private_tree.self_index),
+            &self.state,
+        )
+        .await?;
+
+        Ok(EventOrContent::Content(auth_content))
+    }
+
+    /// Apply `sender`'s validated update path to the provisional state.
+    ///
+    /// Returns the private tree and commit secret for the next epoch. These are
+    /// the values stored with our pending commit when there is one; otherwise
+    /// they come from decapsulating `update_path` into a provisional private tree.
+    async fn apply_update_path(
+        &mut self,
+        sender: LeafIndex,
+        update_path: &ValidatedUpdatePath,
+        provisional_state: &mut ProvisionalState,
+    ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
+        // Update the private tree to create a provisional private tree
+        let (mut provisional_private_tree, new_signer) =
+            self.provisional_private_tree(provisional_state)?;
+
+        // Adopt the signer attached to one of our own applied Update proposals.
+        // NOTE(review): this happens before the commit is fully validated (the
+        // confirmation tag is only checked in `update_key_schedule`). A commit that is
+        // rejected later still leaves the new signer in place. Confirm whether the swap
+        // should be deferred until the epoch actually advances.
+        if let Some(signer) = new_signer {
+            self.signer = signer;
+        }
+
+        provisional_state
+            .public_tree
+            .apply_update_path(
+                sender,
+                update_path,
+                &provisional_state.group_context.extensions,
+                self.identity_provider(),
+                self.cipher_suite_provider(),
+            )
+            .await?;
+
+        // A pending commit already carries the private tree and commit secret that
+        // were computed when it was created.
+        // NOTE(review): this assumes `pending_commit` is only `Some` while our own
+        // commit is being applied. Another member's commit processed while one is
+        // stored would be given our secrets here. Confirm callers clear it first.
+        if let Some(pending) = &self.pending_commit {
+            Ok(Some((
+                pending.pending_private_tree.clone(),
+                pending.pending_commit_secret.clone(),
+            )))
+        } else {
+            // Update the tree hash to get context for decryption
+            provisional_state.group_context.tree_hash = provisional_state
+                .public_tree
+                .tree_hash(&self.cipher_suite_provider)
+                .await?;
+
+            let context_bytes = provisional_state.group_context.mls_encode_to_vec()?;
+
+            TreeKem::new(
+                &mut provisional_state.public_tree,
+                &mut provisional_private_tree,
+            )
+            .decap(
+                sender,
+                update_path,
+                &provisional_state.indexes_of_added_kpkgs,
+                &context_bytes,
+                &self.cipher_suite_provider,
+            )
+            .await
+            .map(|root_secret| Some((provisional_private_tree, root_secret)))
+        }
+    }
+
+    /// Move the group to the next epoch once a commit has been processed.
+    ///
+    /// Derives the new key schedule from the commit secret, the PSK secret and
+    /// the provisional group context, and checks `confirmation_tag`. With the
+    /// `prior_epoch` feature it records the outgoing epoch. Only then does it
+    /// replace the group's private tree, secrets and public state.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MlsError::InvalidConfirmationTag`] if the recomputed tag does not
+    /// match. In that case, as for any other error raised here, the group's
+    /// private tree and public state are left unchanged.
+    async fn update_key_schedule(
+        &mut self,
+        secrets: Option<(TreeKemPrivate, PathSecret)>,
+        interim_transcript_hash: InterimTranscriptHash,
+        confirmation_tag: &ConfirmationTag,
+        provisional_state: ProvisionalState,
+    ) -> Result<(), MlsError> {
+        // Hold the new private tree aside until the commit is fully validated. A
+        // failure below (e.g. a confirmation tag mismatch) must not leave next-epoch
+        // private keys paired with the unchanged current-epoch public state.
+        let (new_private_tree, commit_secret) = match secrets {
+            Some((private_tree, commit_secret)) => (Some(private_tree), commit_secret),
+            None => (None, PathSecret::empty(&self.cipher_suite_provider)),
+        };
+
+        // Use the commit_secret, the psk_secret, the provisional GroupContext, and the init secret
+        // from the previous epoch (or from the external init) to compute the epoch secret and
+        // derived secrets for the new epoch
+
+        // An external init from someone else's external commit derives a fresh starting
+        // key schedule from its KEM output. Otherwise (including our own pending commit)
+        // the current key schedule is the starting point.
+        let key_schedule = match provisional_state
+            .applied_proposals
+            .external_initializations
+            .first()
+            .cloned()
+        {
+            Some(ext_init) if self.pending_commit.is_none() => {
+                self.key_schedule
+                    .derive_for_external(&ext_init.proposal.kem_output, &self.cipher_suite_provider)
+                    .await?
+            }
+            _ => self.key_schedule.clone(),
+        };
+
+        #[cfg(feature = "psk")]
+        let (psk, _) = self
+            .get_psk(&provisional_state.applied_proposals.psks)
+            .await?;
+
+        #[cfg(not(feature = "psk"))]
+        let psk = self.get_psk();
+
+        let key_schedule_result = KeySchedule::from_key_schedule(
+            &key_schedule,
+            &commit_secret,
+            &provisional_state.group_context,
+            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+            provisional_state.public_tree.total_leaf_count(),
+            &psk,
+            &self.cipher_suite_provider,
+        )
+        .await?;
+
+        // Use the confirmation_key for the new epoch to compute the confirmation tag for
+        // this message, as described below, and verify that it is the same as the
+        // confirmation_tag field in the MlsPlaintext object.
+        let new_confirmation_tag = ConfirmationTag::create(
+            &key_schedule_result.confirmation_key,
+            &provisional_state.group_context.confirmed_transcript_hash,
+            &self.cipher_suite_provider,
+        )
+        .await?;
+
+        if &new_confirmation_tag != confirmation_tag {
+            return Err(MlsError::InvalidConfirmationTag);
+        }
+
+        // Snapshot the outgoing epoch (context, secrets, current signature keys) before
+        // any of that state is overwritten below.
+        #[cfg(feature = "prior_epoch")]
+        let signature_public_keys = self
+            .state
+            .public_tree
+            .leaves()
+            .map(|l| l.map(|n| n.signing_identity.signature_key.clone()))
+            .collect();
+
+        #[cfg(feature = "prior_epoch")]
+        let past_epoch = PriorEpoch {
+            context: self.context().clone(),
+            // Record the leaf index of the private tree being installed (falling back to
+            // the current one), as if the new private tree had already been swapped in.
+            self_index: new_private_tree
+                .as_ref()
+                .map_or(self.private_tree.self_index, |tree| tree.self_index),
+            secrets: self.epoch_secrets.clone(),
+            signature_public_keys,
+        };
+
+        #[cfg(feature = "prior_epoch")]
+        self.state_repo.insert(past_epoch).await?;
+
+        // Every check has passed; commit the new epoch's private and public state.
+        if let Some(private_tree) = new_private_tree {
+            self.private_tree = private_tree;
+        }
+
+        self.epoch_secrets = key_schedule_result.epoch_secrets;
+        self.state.context = provisional_state.group_context;
+        self.state.interim_transcript_hash = interim_transcript_hash;
+        self.key_schedule = key_schedule_result.key_schedule;
+        self.state.public_tree = provisional_state.public_tree;
+        self.state.confirmation_tag = new_confirmation_tag;
+
+        // Clear the proposals list
+        #[cfg(feature = "by_ref_proposal")]
+        self.state.proposals.clear();
+
+        // Clear the pending updates list
+        #[cfg(feature = "by_ref_proposal")]
+        {
+            self.pending_updates = Default::default();
+        }
+
+        self.pending_commit = None;
+
+        Ok(())
+    }
+
+    fn mls_rules(&self) -> Self::MlsRules {
+        self.config.mls_rules()
+    }
+
+    fn identity_provider(&self) -> Self::IdentityProvider {
+        self.config.identity_provider()
+    }
+
+    fn psk_storage(&self) -> Self::PreSharedKeyStorage {
+        self.config.secret_store()
+    }
+
+    fn group_state(&self) -> &GroupState {
+        &self.state
+    }
+
+    fn group_state_mut(&mut self) -> &mut GroupState {
+        &mut self.state
+    }
+
+    /// Processing stops when the commit removes us, unless it is our own
+    /// pending commit being applied.
+    fn can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool {
+        !(provisional_state
+            .applied_proposals
+            .removals
+            .iter()
+            .any(|p| p.proposal.to_remove == self.private_tree.self_index)
+            && self.pending_commit.is_none())
+    }
+
+    /// No lower bound on the epoch is imposed here.
+    #[cfg(feature = "private_message")]
+    fn min_epoch_available(&self) -> Option<u64> {
+        None
+    }
+
+    fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider {
+        &self.cipher_suite_provider
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils;
+
+#[cfg(test)]
+mod tests {
+ use crate::{
+ client::test_utils::{
+ test_client_with_key_pkg, TestClientBuilder, TEST_CIPHER_SUITE,
+ TEST_CUSTOM_PROPOSAL_TYPE, TEST_PROTOCOL_VERSION,
+ },
+ client_builder::{test_utils::TestClientConfig, ClientBuilder, MlsConfig},
+ crypto::test_utils::TestCryptoProvider,
+ group::{
+ mls_rules::{CommitDirection, CommitSource},
+ proposal_filter::ProposalBundle,
+ },
+ identity::{
+ basic::BasicIdentityProvider,
+ test_utils::{get_test_signing_identity, BasicWithCustomProvider},
+ },
+ key_package::test_utils::test_key_package_message,
+ mls_rules::CommitOptions,
+ tree_kem::{
+ leaf_node::{test_utils::get_test_capabilities, LeafNodeSource},
+ UpdatePathNode,
+ },
+ };
+
+ #[cfg(any(feature = "private_message", feature = "custom_proposal"))]
+ use crate::group::mls_rules::DefaultMlsRules;
+
+ #[cfg(feature = "prior_epoch")]
+ use crate::group::padding::PaddingMode;
+
+ use crate::{extension::RequiredCapabilitiesExt, key_package::test_utils::test_key_package};
+
+ #[cfg(all(feature = "by_ref_proposal", feature = "custom_proposal"))]
+ use super::test_utils::test_group_custom_config;
+
+ #[cfg(feature = "psk")]
+ use crate::{client::Client, psk::PreSharedKey};
+
+ #[cfg(any(feature = "by_ref_proposal", feature = "private_message"))]
+ use crate::group::test_utils::random_bytes;
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::{
+ extension::test_utils::TestExtension, identity::test_utils::get_test_basic_credential,
+ time::MlsTime,
+ };
+
+ use super::{
+ test_utils::{
+ get_test_25519_key, get_test_groups_with_features, group_extensions, process_commit,
+ test_group, test_group_custom, test_n_member_group, TestGroup, TEST_GROUP,
+ },
+ *,
+ };
+
+ use assert_matches::assert_matches;
+
+ use mls_rs_core::extension::{Extension, ExtensionType};
+ use mls_rs_core::identity::{Credential, CredentialType, CustomCredential};
+
+ #[cfg(feature = "by_ref_proposal")]
+ use mls_rs_core::identity::CertificateChain;
+
+ #[cfg(feature = "state_update")]
+ use itertools::Itertools;
+
+ #[cfg(feature = "state_update")]
+ use alloc::format;
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::{crypto::test_utils::test_cipher_suite_provider, extension::ExternalSendersExt};
+
+ #[cfg(any(feature = "private_message", feature = "state_update"))]
+ use super::test_utils::test_member;
+
+ use mls_rs_core::extension::MlsExtension;
+
+ // Creating a group succeeds for every supported (protocol version, cipher
+ // suite) pair. The new group starts at epoch 0 with the test group id and
+ // extensions, an empty confirmed transcript hash, no pending commit, and a
+ // self index equal to the reported current member index.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_create_group() {
+ for (protocol_version, cipher_suite) in ProtocolVersion::all().flat_map(|p| {
+ TestCryptoProvider::all_supported_cipher_suites()
+ .into_iter()
+ .map(move |cs| (p, cs))
+ }) {
+ let test_group = test_group(protocol_version, cipher_suite).await;
+ let group = test_group.group;
+
+ assert_eq!(group.cipher_suite(), cipher_suite);
+ assert_eq!(group.state.context.epoch, 0);
+ assert_eq!(group.state.context.group_id, TEST_GROUP.to_vec());
+ assert_eq!(group.state.context.extensions, group_extensions());
+
+ assert_eq!(
+ group.state.context.confirmed_transcript_hash,
+ ConfirmedTranscriptHash::from(vec![])
+ );
+
+ // Elsewhere in this file, `state.proposals` and `pending_updates` are
+ // only touched under `by_ref_proposal`. Gate both checks on that feature
+ // rather than on `private_message`. Gating on `private_message` breaks
+ // builds that enable it without `by_ref_proposal`, and skips the check
+ // when only `by_ref_proposal` is on.
+ #[cfg(feature = "by_ref_proposal")]
+ {
+ assert!(group.state.proposals.is_empty());
+ assert!(group.pending_updates.is_empty());
+ }
+
+ assert!(!group.has_pending_commit());
+
+ assert_eq!(
+ group.private_tree.self_index.0,
+ group.current_member_index()
+ );
+ }
+ }
+
+ // While an add proposal for bob is pending, encrypting application data fails
+ // with `CommitRequired`. After a commit is created and applied, it succeeds.
+ #[cfg(feature = "private_message")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_pending_proposals_application_data() {
+ let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ // Propose adding bob, without committing yet.
+ let (bob, _) = test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+ let add_bob = alice.group.add_proposal(bob.key_package_message()).unwrap();
+ alice.group.proposal_message(add_bob, vec![]).await.unwrap();
+
+ // Application messages are refused until a commit happens.
+ let blocked = alice
+ .group
+ .encrypt_application_message(b"test", vec![])
+ .await;
+ assert_matches!(blocked, Err(MlsError::CommitRequired));
+
+ // Commit, apply, and try again.
+ alice.group.commit(vec![]).await.unwrap();
+ assert!(alice.group.has_pending_commit());
+ alice.group.apply_pending_commit().await.unwrap();
+
+ let allowed = alice
+ .group
+ .encrypt_application_message(b"test", vec![])
+ .await;
+ assert!(allowed.is_ok());
+ }
+
+ // An update proposal carries a new public key and the same signing identity.
+ // It also carries the leaf extensions and capabilities (custom extension 42)
+ // the group was created with.
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_update_proposals() {
+ let mut extension_list = ExtensionList::default();
+ extension_list.set_from(TestExtension { foo: 10 }).unwrap();
+
+ let mut test_group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ vec![42.into()],
+ Some(extension_list.clone()),
+ None,
+ )
+ .await;
+
+ let existing_leaf = test_group.group.current_user_leaf_node().unwrap().clone();
+
+ let proposal = test_group.update_proposal().await;
+
+ let Proposal::Update(update) = proposal else {
+ panic!("non update proposal found");
+ };
+
+ assert_ne!(update.leaf_node.public_key, existing_leaf.public_key);
+ assert_eq!(update.leaf_node.signing_identity, existing_leaf.signing_identity);
+ assert_eq!(update.leaf_node.ungreased_extensions(), extension_list);
+
+ let actual_capabilities = update.leaf_node.ungreased_capabilities().sorted();
+
+ let expected_capabilities = Capabilities {
+ extensions: vec![42.into()],
+ ..get_test_capabilities()
+ }
+ .sorted();
+
+ assert_eq!(actual_capabilities, expected_capabilities);
+ }
+
+ // A member that proposes an update and then commits ends up with a leaf
+ // different from the one in its own update proposal (the committer rejects it).
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_invalid_commit_self_update() {
+ let mut test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let proposal_msg = test_group.group.propose_update(vec![]).await.unwrap();
+ let plaintext = proposal_msg.into_plaintext().unwrap();
+
+ let Content::Proposal(proposal) = plaintext.content.content else {
+ panic!("found non-proposal message");
+ };
+
+ let Proposal::Update(update) = *proposal else {
+ panic!("found proposal message that isn't an update");
+ };
+
+ let update_leaf = update.leaf_node;
+
+ test_group.group.commit(vec![]).await.unwrap();
+ test_group.group.apply_pending_commit().await.unwrap();
+
+ let current_leaf = test_group.group.current_user_leaf_node().unwrap();
+ assert_ne!(&update_leaf, current_leaf);
+ }
+
+ // An update proposal from alice whose leaf node signature is replaced with
+ // random bytes is inserted directly into bob's proposal cache. When bob then
+ // commits, the resulting commit carries no proposals.
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn update_proposal_with_bad_key_package_is_ignored_when_committing() {
+ let (mut alice_group, mut bob_group) =
+ test_two_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, true).await;
+
+ let mut proposal = alice_group.update_proposal().await;
+
+ // Overwrite the leaf signature so the update should fail validation.
+ if let Proposal::Update(ref mut update) = proposal {
+ update.leaf_node.signature = random_bytes(32);
+ } else {
+ panic!("Invalid update proposal")
+ }
+
+ let proposal_message = alice_group
+ .group
+ .proposal_message(proposal.clone(), vec![])
+ .await
+ .unwrap();
+
+ let proposal_plaintext = match proposal_message.payload {
+ MlsMessagePayload::Plain(p) => p,
+ _ => panic!("Unexpected non-plaintext message"),
+ };
+
+ // Reference for the proposal, computed with bob's cipher suite provider.
+ let proposal_ref = ProposalRef::from_content(
+ &bob_group.group.cipher_suite_provider,
+ &proposal_plaintext.clone().into(),
+ )
+ .await
+ .unwrap();
+
+ // Hack bob's receipt of the proposal
+ // (inserted straight into his cache instead of going through message processing).
+ bob_group.group.state.proposals.insert(
+ proposal_ref,
+ proposal,
+ proposal_plaintext.content.sender,
+ );
+
+ let commit_output = bob_group.group.commit(vec![]).await.unwrap();
+
+ // The commit must be a plaintext commit with an empty proposal list.
+ assert_matches!(
+ commit_output.commit_message,
+ MlsMessage {
+ payload: MlsMessagePayload::Plain(
+ PublicMessage {
+ content: FramedContent {
+ content: Content::Commit(c),
+ ..
+ },
+ ..
+ }),
+ ..
+ } if c.proposals.is_empty()
+ );
+ }
+
+ // Creates alice's group with commit options that enable or disable the
+ // ratchet tree extension (`tree_ext`), then adds bob. Checks that both end up
+ // with equal group state and returns (alice, bob).
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_two_member_group(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ tree_ext: bool,
+ ) -> (TestGroup, TestGroup) {
+ let commit_options = CommitOptions::new().with_ratchet_tree_extension(tree_ext);
+
+ let mut alice = test_group_custom(
+ protocol_version,
+ cipher_suite,
+ Default::default(),
+ None,
+ Some(commit_options),
+ )
+ .await;
+
+ let (bob, _) = alice.join("bob").await;
+
+ let states_match = Group::equal_group_state(&alice.group, &bob.group);
+ assert!(states_match);
+
+ (alice, bob)
+ }
+
+ // Two-member setup with the ratchet tree extension disabled
+ // (presumably bob joins from an exported tree supplied by the join helper).
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_welcome_processing_exported_tree() {
+ let use_tree_extension = false;
+ test_two_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, use_tree_extension).await;
+ }
+
+ // Two-member setup with the ratchet tree extension enabled.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_welcome_processing_tree_extension() {
+ let use_tree_extension = true;
+ test_two_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, use_tree_extension).await;
+ }
+
+ // With the ratchet tree extension disabled and no tree passed to
+ // `Group::join`, joining from the welcome fails with `RatchetTreeNotFound`.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_welcome_processing_missing_tree() {
+ let commit_options = CommitOptions::new().with_ratchet_tree_extension(false);
+
+ let mut alice = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(commit_options),
+ )
+ .await;
+
+ let (bob_client, bob_key_package) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ // Alice commits the addition of bob.
+ let add_bob = alice
+ .group
+ .commit_builder()
+ .add_member(bob_key_package)
+ .unwrap();
+
+ let commit_output = add_bob.build().await.unwrap();
+
+ // Bob attempts to join, passing `None` instead of a tree.
+ let welcome = &commit_output.welcome_messages[0];
+
+ let join_result = Group::join(
+ welcome,
+ None,
+ bob_client.config,
+ bob_client.signer.unwrap(),
+ )
+ .await;
+
+ assert_matches!(join_result.map(|_| ()), Err(MlsError::RatchetTreeNotFound));
+ }
+
+ // `group_context_extensions_proposal` wraps the given list in a
+ // `Proposal::GroupContextExtensions` unchanged.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_group_context_ext_proposal_create() {
+ let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let required_capabilities = RequiredCapabilitiesExt {
+ extensions: vec![42.into()],
+ proposals: vec![],
+ credentials: vec![],
+ };
+
+ let mut extension_list = ExtensionList::new();
+ extension_list.set_from(required_capabilities).unwrap();
+
+ let proposal = test_group
+ .group
+ .group_context_extensions_proposal(extension_list.clone());
+
+ assert_matches!(
+ proposal,
+ Proposal::GroupContextExtensions(ext) if ext == extension_list
+ );
+ }
+
+ // Creates a group whose leaf advertises custom extension type 42, then builds
+ // a commit that sets the group context extensions to `ext_list`. Returns the
+ // group together with either the commit message or the build error.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn group_context_extension_proposal_test(
+ ext_list: ExtensionList,
+ ) -> (TestGroup, Result<MlsMessage, MlsError>) {
+ let mut test_group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ vec![42.into()],
+ None,
+ None,
+ )
+ .await;
+
+ let builder = test_group
+ .group
+ .commit_builder()
+ .set_group_context_ext(ext_list)
+ .unwrap();
+
+ let built = builder.build().await;
+ let commit = built.map(|output| output.commit_message);
+
+ (test_group, commit)
+ }
+
+ // Setting a context extension list that requires extension type 42 (which the
+ // helper's group advertises) can be committed. Applying the pending commit
+ // installs the new list.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_group_context_ext_proposal_commit() {
+ let required_capabilities = RequiredCapabilitiesExt {
+ extensions: vec![42.into()],
+ proposals: vec![],
+ credentials: vec![],
+ };
+
+ let mut extension_list = ExtensionList::new();
+ extension_list.set_from(required_capabilities).unwrap();
+
+ let (mut test_group, _) =
+ group_context_extension_proposal_test(extension_list.clone()).await;
+
+ #[cfg(not(feature = "state_update"))]
+ test_group.group.apply_pending_commit().await.unwrap();
+
+ #[cfg(feature = "state_update")]
+ {
+ let applied = test_group.group.apply_pending_commit().await.unwrap();
+ assert!(applied.state_update.active);
+ }
+
+ assert_eq!(test_group.group.state.context.extensions, extension_list)
+ }
+
+ // Requiring extension type 999, which the group's members do not advertise,
+ // makes the commit fail with `RequiredExtensionNotFound(999)`.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_group_context_ext_proposal_invalid() {
+ let unsupported_requirement = RequiredCapabilitiesExt {
+ extensions: vec![999.into()],
+ proposals: vec![],
+ credentials: vec![],
+ };
+
+ let mut extension_list = ExtensionList::new();
+ extension_list.set_from(unsupported_requirement).unwrap();
+
+ // The list is not used afterwards, so it is moved rather than cloned.
+ let (_, commit) = group_context_extension_proposal_test(extension_list).await;
+
+ assert_matches!(
+ commit,
+ Err(MlsError::RequiredExtensionNotFound(a)) if a == 999.into()
+ );
+ }
+
+ // Creates a new group for an "alice" test client whose initial group context
+ // extensions consist solely of `required_caps`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn make_group_with_required_capabilities(
+ required_caps: RequiredCapabilitiesExt,
+ ) -> Result<Group<TestClientConfig>, MlsError> {
+ let (alice, _) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
+
+ let extension = required_caps.into_extension().unwrap();
+
+ alice
+ .create_group(core::iter::once(extension).collect())
+ .await
+ }
+
+ // Requiring the X509 credential type, which the creator does not support,
+ // makes group creation fail with `RequiredCredentialNotFound(X509)`.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn creating_group_with_member_not_supporting_required_credential_type_fails() {
+ let required_caps = RequiredCapabilitiesExt {
+ credentials: vec![CredentialType::BASIC, CredentialType::X509],
+ ..Default::default()
+ };
+
+ let outcome = make_group_with_required_capabilities(required_caps).await;
+
+ assert_matches!(
+ outcome.map(|_| ()),
+ Err(MlsError::RequiredCredentialNotFound(CredentialType::X509))
+ );
+ }
+
+ // Requiring extension type 33, which the creator does not support, makes
+ // group creation fail with `RequiredExtensionNotFound(33)`.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn creating_group_with_member_not_supporting_required_extension_type_fails() {
+ const EXTENSION_TYPE: ExtensionType = ExtensionType::new(33);
+
+ let required_caps = RequiredCapabilitiesExt {
+ extensions: vec![EXTENSION_TYPE],
+ ..Default::default()
+ };
+
+ let outcome = make_group_with_required_capabilities(required_caps).await;
+
+ assert_matches!(
+ outcome.map(|_| ()),
+ Err(MlsError::RequiredExtensionNotFound(EXTENSION_TYPE))
+ );
+ }
+
+ // Requiring proposal type 33, which the creator does not support, makes group
+ // creation fail with `RequiredProposalNotFound(33)`.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn creating_group_with_member_not_supporting_required_proposal_type_fails() {
+ const PROPOSAL_TYPE: ProposalType = ProposalType::new(33);
+
+ let required_caps = RequiredCapabilitiesExt {
+ proposals: vec![PROPOSAL_TYPE],
+ ..Default::default()
+ };
+
+ let outcome = make_group_with_required_capabilities(required_caps).await;
+
+ assert_matches!(
+ outcome.map(|_| ()),
+ Err(MlsError::RequiredProposalNotFound(PROPOSAL_TYPE))
+ );
+ }
+
+ // An external senders extension using X509 credentials makes group creation
+ // fail with `RequiredCredentialNotFound(X509)` for a creator without X509 support.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn creating_group_with_member_not_supporting_external_sender_credential_fails() {
+ let senders = make_x509_external_senders_ext().await;
+ let ext_senders = senders.into_extension().unwrap();
+
+ let (alice, _) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
+
+ let outcome = alice
+ .create_group(core::iter::once(ext_senders).collect())
+ .await;
+
+ assert_matches!(
+ outcome.map(|_| ()),
+ Err(MlsError::RequiredCredentialNotFound(CredentialType::X509))
+ );
+ }
+
+ // Encrypting a 150-byte payload under `PaddingMode::StepFunction` yields a
+ // longer encoded message than under `PaddingMode::None`.
+ #[cfg(all(not(target_arch = "wasm32"), feature = "private_message"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_group_encrypt_plaintext_padding() {
+ // This test requires a cipher suite whose signatures are not variable in length.
+ let cipher_suite = CipherSuite::CURVE25519_AES128;
+
+ let mut unpadded = test_group_custom_config(TEST_PROTOCOL_VERSION, cipher_suite, |builder| {
+ let rules = DefaultMlsRules::default();
+ builder.mls_rules(rules.with_encryption_options(EncryptionOptions::new(true, PaddingMode::None)))
+ })
+ .await;
+
+ let without_padding = unpadded
+ .group
+ .encrypt_application_message(&random_bytes(150), vec![])
+ .await
+ .unwrap();
+
+ let mut padded = test_group_custom_config(TEST_PROTOCOL_VERSION, cipher_suite, |builder| {
+ let rules = DefaultMlsRules::default();
+ builder.mls_rules(rules.with_encryption_options(EncryptionOptions::new(true, PaddingMode::StepFunction)))
+ })
+ .await;
+
+ let with_padding = padded
+ .group
+ .encrypt_application_message(&random_bytes(150), vec![])
+ .await
+ .unwrap();
+
+ let padded_len = with_padding.mls_encoded_len();
+ let unpadded_len = without_padding.mls_encoded_len();
+ assert!(padded_len > unpadded_len);
+ }
+
+ // Building an external commit from a group info produced by
+ // `group_info_message(false)` fails with `MissingExternalPubExtension`.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_commit_requires_external_pub_extension() {
+ let alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let group_info_msg = alice.group.group_info_message(false).await.unwrap();
+ let group_info = group_info_msg.into_group_info().unwrap();
+
+ let info_msg = MlsMessage::new(
+ TEST_PROTOCOL_VERSION,
+ MlsMessagePayload::GroupInfo(group_info),
+ );
+
+ let signing_identity = alice
+ .group
+ .current_member_signing_identity()
+ .unwrap()
+ .clone();
+
+ let builder = external_commit::ExternalCommitBuilder::new(
+ alice.group.signer,
+ signing_identity,
+ alice.group.config,
+ );
+
+ let res = builder.build(info_msg).await.map(|_| {});
+
+ assert_matches!(res, Err(MlsError::MissingExternalPubExtension));
+ }
+
+ // With `allow_external_commit` set, a commit's output includes an external
+ // commit group info that a new client (bob) can use to join externally.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_commit_via_commit_options_round_trip() {
+ let commit_options = CommitOptions::default().with_allow_external_commit(true);
+
+ let mut alice = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ vec![],
+ None,
+ Some(commit_options),
+ )
+ .await;
+
+ let commit_output = alice.group.commit(vec![]).await.unwrap();
+
+ let (bob, _) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let group_info = commit_output.external_commit_group_info.unwrap();
+ let builder = bob.external_commit_builder().unwrap();
+
+ builder.build(group_info).await.unwrap();
+ }
+
+ // With default commit options, an add-only commit leaves an all-zero pending
+ // commit secret. With `path_required` set, the same kind of commit produces a
+ // pending commit secret that is not all-zero.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_path_update_preference() {
+ let protocol_version = TEST_PROTOCOL_VERSION;
+ let cipher_suite = TEST_CIPHER_SUITE;
+
+ let mut default_group = test_group_custom(
+ protocol_version,
+ cipher_suite,
+ Default::default(),
+ None,
+ Some(CommitOptions::new()),
+ )
+ .await;
+
+ let test_key_package =
+ test_key_package_message(protocol_version, cipher_suite, "alice").await;
+
+ let add_default = default_group
+ .group
+ .commit_builder()
+ .add_member(test_key_package.clone())
+ .unwrap();
+
+ add_default.build().await.unwrap();
+
+ let default_pending = default_group.group.pending_commit.unwrap();
+ assert!(default_pending.pending_commit_secret.iter().all(|x| x == &0));
+
+ let mut path_group = test_group_custom(
+ protocol_version,
+ cipher_suite,
+ Default::default(),
+ None,
+ Some(CommitOptions::new().with_path_required(true)),
+ )
+ .await;
+
+ let add_with_path = path_group
+ .group
+ .commit_builder()
+ .add_member(test_key_package)
+ .unwrap();
+
+ add_with_path.build().await.unwrap();
+
+ let path_pending = path_group.group.pending_commit.unwrap();
+ assert!(!path_pending.pending_commit_secret.iter().all(|x| x == &0));
+ }
+
+ // Even with default commit options, an empty commit (no proposals) leaves a
+ // pending commit secret that is not all-zero.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_path_update_preference_override() {
+ let mut alice = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(CommitOptions::new()),
+ )
+ .await;
+
+ alice.group.commit(vec![]).await.unwrap();
+
+ let pending = alice.group.pending_commit.unwrap();
+ let secret_is_all_zero = pending.pending_commit_secret.iter().all(|x| x == &0);
+
+ assert!(!secret_is_all_zero);
+ }
+
+ // An application message sent as plaintext is rejected by the receiving
+ // member with `UnencryptedApplicationMessage`.
+ #[cfg(feature = "private_message")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn group_rejects_unencrypted_application_message() {
+ let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (mut bob, _) = alice.join("bob").await;
+
+ let application_data = Content::Application(b"hello".to_vec().into());
+ let message = alice.make_plaintext(application_data).await;
+
+ let res = bob.group.process_incoming_message(message).await;
+
+ assert_matches!(res, Err(MlsError::UnencryptedApplicationMessage));
+ }
+
+ // End-to-end check of the commit description's state update. Alice commits
+ // bob's update proposal, five external PSKs, three removals and five adds.
+ // The test checks the roster changes and PSK list she sees, and that bob's
+ // processed description of the same commit matches hers.
+ #[cfg(feature = "state_update")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_state_update() {
+ let protocol_version = TEST_PROTOCOL_VERSION;
+ let cipher_suite = TEST_CIPHER_SUITE;
+
+ // Create a group with 10 members
+ // (alice, bob and charlie0..=charlie7).
+ let mut alice = test_group(protocol_version, cipher_suite).await;
+ let (mut bob, _) = alice.join("bob").await;
+ let mut leaves = vec![];
+
+ // `leaves[i]` records charlie{i}'s leaf node; bob processes each add.
+ for i in 0..8 {
+ let (group, commit) = alice.join(&format!("charlie{i}")).await;
+ leaves.push(group.group.current_user_leaf_node().unwrap().clone());
+ bob.process_message(commit).await.unwrap();
+ }
+
+ // Create many proposals, make Alice commit them
+
+ // Bob proposes an update to his own leaf; alice receives it.
+ let update_message = bob.group.propose_update(vec![]).await.unwrap();
+
+ alice.process_message(update_message).await.unwrap();
+
+ // Register five external PSKs (ids and secrets [0]..[4]) in both alice's
+ // and bob's secret stores, and collect their ids.
+ let external_psk_ids: Vec<ExternalPskId> = (0..5)
+ .map(|i| {
+ let external_id = ExternalPskId::new(vec![i]);
+
+ alice
+ .group
+ .config
+ .secret_store()
+ .insert(ExternalPskId::new(vec![i]), PreSharedKey::from(vec![i]));
+
+ bob.group
+ .config
+ .secret_store()
+ .insert(ExternalPskId::new(vec![i]), PreSharedKey::from(vec![i]));
+
+ external_id
+ })
+ .collect();
+
+ let mut commit_builder = alice.group.commit_builder();
+
+ for external_psk in external_psk_ids {
+ commit_builder = commit_builder.add_external_psk(external_psk).unwrap();
+ }
+
+ // Remove leaves 2, 5 and 6 (charlie0, charlie3 and charlie4).
+ for index in [2, 5, 6] {
+ commit_builder = commit_builder.remove_member(index).unwrap();
+ }
+
+ // Add dave0..=dave4.
+ for i in 0..5 {
+ let (key_package, _) = test_member(
+ protocol_version,
+ cipher_suite,
+ format!("dave{i}").as_bytes(),
+ )
+ .await;
+
+ commit_builder = commit_builder
+ .add_member(key_package.key_package_message())
+ .unwrap()
+ }
+
+ let commit_output = commit_builder.build().await.unwrap();
+
+ let commit_description = alice.process_pending_commit().await.unwrap();
+
+ assert!(!commit_description.is_external);
+
+ assert_eq!(
+ commit_description.committer,
+ alice.group.current_member_index()
+ );
+
+ // Check that applying pending commit and processing commit yields correct update.
+ let state_update_alice = commit_description.state_update.clone();
+
+ // The five new members take the freed leaves 2, 5, 6 and then 10, 11.
+ assert_eq!(
+ state_update_alice
+ .roster_update
+ .added()
+ .iter()
+ .map(|m| m.index)
+ .collect::<Vec<_>>(),
+ vec![2, 5, 6, 10, 11]
+ );
+
+ // Leaf i held charlie{i - 2}, whose leaf node is `leaves[i - 2]`.
+ assert_eq!(
+ state_update_alice.roster_update.removed(),
+ vec![2, 5, 6]
+ .into_iter()
+ .map(|i| member_from_leaf_node(&leaves[i as usize - 2], LeafIndex(i)))
+ .collect::<Vec<_>>()
+ );
+
+ // Updated members are the first two roster entries (leaves 0 and 1);
+ // presumably alice via her commit and bob via his update proposal.
+ assert_eq!(
+ state_update_alice
+ .roster_update
+ .updated()
+ .iter()
+ .map(|update| update.new.clone())
+ .collect_vec()
+ .as_slice(),
+ &alice.group.roster().members()[0..2]
+ );
+
+ assert_eq!(
+ state_update_alice.added_psks,
+ (0..5)
+ .map(|i| ExternalPskId::new(vec![i]))
+ .collect::<Vec<_>>()
+ );
+
+ // Bob processes the same commit and must derive an identical description.
+ let payload = bob
+ .process_message(commit_output.commit_message)
+ .await
+ .unwrap();
+
+ let ReceivedMessage::Commit(bob_commit_description) = payload else {
+ panic!("expected commit");
+ };
+
+ assert_eq!(commit_description, bob_commit_description);
+ }
+
+ // When bob joins alice's group via an external commit, alice's commit
+ // description is marked external with committer 1. Its roster update adds
+ // bob's roster entry, and afterwards both rosters are identical.
+ //
+ // `TestClientBuilder` is already imported at module level, so no local `use` is needed.
+ #[cfg(feature = "state_update")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_description_external_commit() {
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let (bob_identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
+
+ let bob = TestClientBuilder::new_for_test()
+ .signing_identity(bob_identity, secret_key, TEST_CIPHER_SUITE)
+ .build();
+
+ let builder = bob.external_commit_builder().unwrap();
+
+ let group_info = alice_group
+ .group
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap();
+
+ let (bob_group, commit) = builder.build(group_info).await.unwrap();
+
+ let event = alice_group.process_message(commit).await.unwrap();
+
+ let ReceivedMessage::Commit(description) = event else {
+ panic!("expected commit");
+ };
+
+ assert!(description.is_external);
+ assert_eq!(description.committer, 1);
+
+ assert_eq!(
+ description.state_update.roster_update.added(),
+ &bob_group.roster().members()[1..2]
+ );
+
+ itertools::assert_equal(
+ bob_group.roster().members_iter(),
+ alice_group.group.roster().members_iter(),
+ );
+ }
+
+ // Bob joins alice's group with an external commit. The commit is built from
+ // `group_info_message_allowing_ext_commit(false)` plus alice's exported tree
+ // supplied via `with_tree_data`, and alice processes it successfully.
+ //
+ // `TestClientBuilder` is already imported at module level, so no local `use` is needed.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn can_join_new_group_externally() {
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let (bob_identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
+
+ let bob = TestClientBuilder::new_for_test()
+ .signing_identity(bob_identity, secret_key, TEST_CIPHER_SUITE)
+ .build();
+
+ let builder = bob
+ .external_commit_builder()
+ .unwrap()
+ .with_tree_data(alice_group.group.export_tree().into_owned());
+
+ let group_info = alice_group
+ .group
+ .group_info_message_allowing_ext_commit(false)
+ .await
+ .unwrap();
+
+ let (_, commit) = builder.build(group_info).await.unwrap();
+
+ alice_group.process_message(commit).await.unwrap();
+ }
+
+ // Rewriting the sender of alice's plaintext commit to `NewMemberCommit` makes
+ // bob reject it with `MembershipTagForNonMember`.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_membership_tag_from_non_member() {
+ let (mut alice_group, mut bob_group) =
+ test_two_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, true).await;
+
+ let mut commit_output = alice_group.group.commit(vec![]).await.unwrap();
+
+ let MlsMessagePayload::Plain(ref mut plaintext) = commit_output.commit_message.payload else {
+ panic!("Non plaintext message");
+ };
+
+ plaintext.content.sender = Sender::NewMemberCommit;
+
+ let res = bob_group
+ .process_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::MembershipTagForNonMember));
+ }
+
+ // Charlie's commit adding dave is processed by alice and bob, and the commit
+ // carries no update path.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_partial_commits() {
+ let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (mut bob, _) = alice.join("bob").await;
+ let (mut charlie, add_charlie) = alice.join("charlie").await;
+ bob.process_message(add_charlie).await.unwrap();
+
+ let (_, add_dave) = charlie.join("dave").await;
+
+ for member in [&mut alice, &mut bob] {
+ member.process_message(add_dave.clone()).await.unwrap();
+ }
+
+ let plaintext = add_dave.into_plaintext().unwrap();
+
+ let Content::Commit(commit) = plaintext.content.content else {
+ panic!("Expected commit")
+ };
+
+ assert!(commit.path.is_none());
+ }
+
+ // Returns a fresh test group whose MLS rules force `path_required` in commit options.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn group_with_path_required() -> TestGroup {
+ let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let commit_options = &mut alice.group.config.0.mls_rules.commit_options;
+ commit_options.path_required = true;
+
+ alice
+ }
+
+ // Alice (with path required) commits the removal of member 1. Applying her own
+ // pending commit clears `private_tree.secret_keys[1]`.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn old_hpke_secrets_are_removed() {
+ let mut alice = group_with_path_required().await;
+ alice.join("bob").await;
+ alice.join("charlie").await;
+
+ let remove_first_member = alice
+ .group
+ .commit_builder()
+ .remove_member(1)
+ .unwrap();
+
+ remove_first_member.build().await.unwrap();
+
+ assert!(alice.group.private_tree.secret_keys[1].is_some());
+
+ alice.process_pending_commit().await.unwrap();
+
+ assert!(alice.group.private_tree.secret_keys[1].is_none());
+ }
+
+ // Charlie commits the removal of member 1. Processing that commit clears
+ // alice's `private_tree.secret_keys[1]`.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn old_hpke_secrets_of_removed_are_removed() {
+ let mut alice = group_with_path_required().await;
+ alice.join("bob").await;
+ let (mut charlie, _) = alice.join("charlie").await;
+
+ let remove_first_member = charlie
+ .group
+ .commit_builder()
+ .remove_member(1)
+ .unwrap();
+
+ let commit = remove_first_member.build().await.unwrap();
+
+ assert!(alice.group.private_tree.secret_keys[1].is_some());
+
+ alice.process_message(commit.commit_message).await.unwrap();
+
+ assert!(alice.group.private_tree.secret_keys[1].is_none());
+ }
+
+ // Bob proposes an update, charlie commits it, and processing that commit
+ // clears alice's `private_tree.secret_keys[1]`.
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn old_hpke_secrets_of_updated_are_removed() {
+ let mut alice = group_with_path_required().await;
+ let (mut bob, _) = alice.join("bob").await;
+ let (mut charlie, add_charlie) = alice.join("charlie").await;
+ bob.process_message(add_charlie).await.unwrap();
+
+ let bob_update = bob.group.propose_update(vec![]).await.unwrap();
+ charlie.process_message(bob_update.clone()).await.unwrap();
+ alice.process_message(bob_update).await.unwrap();
+
+ let charlie_commit = charlie.group.commit(vec![]).await.unwrap();
+
+ assert!(alice.group.private_tree.secret_keys[1].is_some());
+
+ alice
+ .process_message(charlie_commit.commit_message)
+ .await
+ .unwrap();
+
+ assert!(alice.group.private_tree.secret_keys[1].is_none());
+ }
+
+ // Alice branches a subgroup that includes only bob (via a fresh key package
+ // for his identity). Bob can join it from the welcome, carol (also an
+ // original member) cannot, and alice and bob can exchange a commit in it.
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn only_selected_members_of_the_original_group_can_join_subgroup() {
+ let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (mut bob, _) = alice.join("bob").await;
+ let (carol, commit) = alice.join("carol").await;
+
+ // Apply the commit that adds carol
+ bob.group.process_incoming_message(commit).await.unwrap();
+
+ // Build a client from bob's existing config, signer and identity so it
+ // can generate a new key package for him.
+ let bob_identity = bob.group.current_member_signing_identity().unwrap().clone();
+ let signer = bob.group.signer.clone();
+
+ let new_key_pkg = Client::new(
+ bob.group.config.clone(),
+ Some(signer),
+ Some((bob_identity, TEST_CIPHER_SUITE)),
+ TEST_PROTOCOL_VERSION,
+ )
+ .generate_key_package_message()
+ .await
+ .unwrap();
+
+ // Alice branches into "subgroup" with bob's new key package as the only invitee.
+ let (mut alice_sub_group, welcome) = alice
+ .group
+ .branch(b"subgroup".to_vec(), vec![new_key_pkg])
+ .await
+ .unwrap();
+
+ let welcome = &welcome[0];
+
+ let (mut bob_sub_group, _) = bob.group.join_subgroup(welcome, None).await.unwrap();
+
+ // Carol can't join
+ let res = carol.group.join_subgroup(welcome, None).await.map(|_| ());
+ assert_matches!(res, Err(_));
+
+ // Alice and Bob can still talk
+ let commit_output = alice_sub_group.commit(vec![]).await.unwrap();
+
+ bob_sub_group
+ .process_incoming_message(commit_output.commit_message)
+ .await
+ .unwrap();
+ }
+
+ // Creates a test group and has a new member (named "alice") join it with a
+ // client config modified by `f`. Returns the join result.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn joining_group_fails_if_unsupported<F>(
+ f: F,
+ ) -> Result<(TestGroup, MlsMessage), MlsError>
+ where
+ F: FnMut(&mut TestClientConfig),
+ {
+ let mut existing_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ existing_group
+ .join_with_custom_config("alice", false, f)
+ .await
+ }
+
+ // A joiner with no supported protocol versions fails with
+ // `UnsupportedProtocolVersion` for the group's version.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn joining_group_fails_if_protocol_version_is_not_supported() {
+ let outcome = joining_group_fails_if_unsupported(|config| {
+ config.0.settings.protocol_versions.clear()
+ })
+ .await;
+
+ let res = outcome.map(|_| ());
+
+ assert_matches!(
+ res,
+ Err(MlsError::UnsupportedProtocolVersion(v)) if v == TEST_PROTOCOL_VERSION
+ );
+ }
+
+ // WebCrypto does not support disabling cipher suites, so this test is skipped on wasm32.
+ // A joiner whose crypto provider has the group's cipher suite disabled fails
+ // with `UnsupportedCipherSuite`.
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn joining_group_fails_if_cipher_suite_is_not_supported() {
+ let outcome = joining_group_fails_if_unsupported(|config| {
+ let enabled = &mut config.0.crypto_provider.enabled_cipher_suites;
+ enabled.retain(|&x| x != TEST_CIPHER_SUITE);
+ })
+ .await;
+
+ let res = outcome.map(|_| ());
+
+ assert_matches!(
+ res,
+ Err(MlsError::UnsupportedCipherSuite(TEST_CIPHER_SUITE))
+ );
+ }
+
+ // When alice decrypts bob's application message, the description's
+ // `sender_index` equals bob's member index.
+ #[cfg(feature = "private_message")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn member_can_see_sender_creds() {
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (mut bob_group, _) = alice_group.join("bob").await;
+
+ let msg = bob_group
+ .group
+ .encrypt_application_message(b"I'm Bob", vec![])
+ .await
+ .unwrap();
+
+ let received_by_alice = alice_group
+ .group
+ .process_incoming_message(msg)
+ .await
+ .unwrap();
+
+ let bob_index = bob_group.group.current_member_index();
+
+ assert_matches!(
+ received_by_alice,
+ ReceivedMessage::ApplicationMessage(ApplicationMessageDescription { sender_index, .. })
+ if sender_index == bob_index
+ );
+ }
+
+ // Alice and bob derive the same epoch authenticator after bob joins.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn members_of_a_group_have_identical_authentication_secrets() {
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (bob_group, _) = alice_group.join("bob").await;
+
+ let alice_authenticator = alice_group.group.epoch_authenticator().unwrap();
+ let bob_authenticator = bob_group.group.epoch_authenticator().unwrap();
+
+ assert_eq!(alice_authenticator, bob_authenticator);
+ }
+
+    // Replaying an already-processed application message must fail.
+    #[cfg(feature = "private_message")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn member_cannot_decrypt_same_message_twice() {
+        let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let (mut bob, _) = alice.join("bob").await;
+
+        let ciphertext = alice
+            .group
+            .encrypt_application_message(b"foobar", Vec::new())
+            .await
+            .unwrap();
+
+        // First delivery decrypts normally.
+        let first_delivery = bob
+            .group
+            .process_incoming_message(ciphertext.clone())
+            .await
+            .unwrap();
+
+        assert_matches!(
+            first_delivery,
+            ReceivedMessage::ApplicationMessage(m) if m.data() == b"foobar"
+        );
+
+        // Second delivery of the same ciphertext is reported as `KeyMissing(0)`
+        // (presumably because the message key was consumed by the first delivery).
+        let replay = bob.group.process_incoming_message(ciphertext).await;
+
+        assert_matches!(replay, Err(MlsError::KeyMissing(0)));
+    }
+
+ // Alice first makes extension type 17 a required capability of the group, then commits an
+ // Add for Bob together with a reset of the group context extensions to the default list.
+ // That commit must be accepted and the group must end up with two members.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn removing_requirements_allows_to_add() {
+ // Extension type 17 is passed to `test_group_custom` (presumably as an extension type
+ // Alice's client supports — TODO confirm against the helper).
+ let mut alice_group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ vec![17.into()],
+ None,
+ None,
+ )
+ .await;
+
+ // Commit 1: require extension 17 through RequiredCapabilitiesExt in the group context.
+ alice_group
+ .group
+ .commit_builder()
+ .set_group_context_ext(
+ vec![RequiredCapabilitiesExt {
+ extensions: vec![17.into()],
+ ..Default::default()
+ }
+ .into_extension()
+ .unwrap()]
+ .try_into()
+ .unwrap(),
+ )
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ alice_group.process_pending_commit().await.unwrap();
+
+ // Bob's key package comes from the generic test helper; presumably it does not
+ // advertise extension 17, which is why the requirement has to be dropped first.
+ let test_key_package =
+ test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let test_key_package = MlsMessage::new(
+ TEST_PROTOCOL_VERSION,
+ MlsMessagePayload::KeyPackage(test_key_package),
+ );
+
+ // Commit 2: add Bob and, in the same commit, replace the group context extensions with
+ // the default list, removing the RequiredCapabilitiesExt.
+ alice_group
+ .group
+ .commit_builder()
+ .add_member(test_key_package)
+ .unwrap()
+ .set_group_context_ext(Default::default())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ let state_update = alice_group
+ .process_pending_commit()
+ .await
+ .unwrap()
+ .state_update;
+
+ // With `state_update`, the roster update must report exactly one added member, at
+ // leaf index 1.
+ #[cfg(feature = "state_update")]
+ assert_eq!(
+ state_update
+ .roster_update
+ .added()
+ .iter()
+ .map(|m| m.index)
+ .collect::<Vec<_>>(),
+ vec![1]
+ );
+
+ // Without `state_update`, StateUpdate has no fields, so only equality with the empty
+ // value can be checked.
+ #[cfg(not(feature = "state_update"))]
+ assert!(state_update == StateUpdate {});
+
+ assert_eq!(alice_group.group.roster().members_iter().count(), 2);
+ }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn commit_leaf_wrong_source() {
+        // RFC, 13.4.2. "The leaf_node_source field MUST be set to commit."
+        let mut members = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 3).await;
+
+        // Member 0's committed leaf claims to originate from an Update instead of a Commit.
+        members[0].group.commit_modifiers.modify_leaf = |leaf, sk| {
+            leaf.leaf_node_source = LeafNodeSource::Update;
+            Some(sk.clone())
+        };
+
+        let commit_message = members[0].group.commit(vec![]).await.unwrap().commit_message;
+
+        // Member 2 must reject the commit.
+        let outcome = members[2].process_message(commit_message).await;
+
+        assert_matches!(outcome, Err(MlsError::InvalidLeafNodeSource));
+    }
+
+ // A committer whose new leaf keeps its current HPKE encryption key must be rejected.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_same_hpke_key() {
+ // RFC 13.4.2. "Verify that the encryption_key value in the LeafNode is different from the committer's current leaf node"
+
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 3).await;
+
+ // Group 0 starts using a fixed key for its leaf in every commit it creates.
+ groups[0].group.commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.public_key = get_test_25519_key(1u8);
+ Some(sk.clone())
+ };
+
+ // The first commit installs the fixed key; both the committer and member 2 apply it.
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+ groups[0].process_pending_commit().await.unwrap();
+ groups[2]
+ .process_message(commit_output.commit_message)
+ .await
+ .unwrap();
+
+ // Group 0 tries to use the fixed key again (the modifier is still installed), so the
+ // new leaf's key equals the committer's current one.
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+
+ let res = groups[2]
+ .process_message(commit_output.commit_message)
+ .await;
+
+ // The payload 0 presumably identifies the committer's leaf index.
+ assert_matches!(res, Err(MlsError::SameHpkeKey(0)));
+ }
+
+ // Two members must not end up with the same leaf encryption key: once member 1 uses a
+ // fixed key, a commit in which member 0 adopts the same key must be rejected.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_duplicate_hpke_key() {
+ // RFC 8.3 "Verify that the following fields are unique among the members of the group: `encryption_key`"
+
+ // The fixed key comes from `get_test_25519_key`, so the test only runs for the
+ // Curve25519 cipher suites (presumably the key is not valid for other suites).
+ if TEST_CIPHER_SUITE != CipherSuite::CURVE25519_AES128
+ && TEST_CIPHER_SUITE != CipherSuite::CURVE25519_CHACHA
+ {
+ return;
+ }
+
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 10).await;
+
+ // Group 1 uses the fixed key
+ groups[1].group.commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.public_key = get_test_25519_key(1u8);
+ Some(sk.clone())
+ };
+
+ let commit_output = groups
+ .get_mut(1)
+ .unwrap()
+ .group
+ .commit(vec![])
+ .await
+ .unwrap();
+
+ // Presumably distributes member 1's commit to the whole group (committer index 1).
+ process_commit(&mut groups, commit_output.commit_message, 1).await;
+
+ // Group 0 tries to use the fixed key too
+ groups[0].group.commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.public_key = get_test_25519_key(1u8);
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+
+ // An unrelated member (7) must detect the duplicate.
+ let res = groups[7]
+ .process_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(_)));
+ }
+
+ // Two members must not share a signature key: once member 1 re-signs with a fixed key pair,
+ // a commit in which member 0 adopts the same key pair must be rejected.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_duplicate_signature_key() {
+ // RFC 8.3 "Verify that the following fields are unique among the members of the group: `signature_key`"
+
+ // The hard-coded key pair below is only used with the Curve25519 cipher suites
+ // (presumably it is an Ed25519 pair); other suites skip the test.
+ if TEST_CIPHER_SUITE != CipherSuite::CURVE25519_AES128
+ && TEST_CIPHER_SUITE != CipherSuite::CURVE25519_CHACHA
+ {
+ return;
+ }
+
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 10).await;
+
+ // Group 1 uses the fixed key
+ // The 64-byte secret key ends with the 32-byte public key assigned below, and the
+ // returned secret key is presumably used to re-sign the modified leaf.
+ groups[1].group.commit_modifiers.modify_leaf = |leaf, _| {
+ let sk = hex!(
+ "3468b4c890255c983e3d5cbf5cb64c1ef7f6433a518f2f3151d6672f839a06ebcad4fc381fe61822af45135c82921a348e6f46643d66ddefc70483565433714b"
+ )
+ .into();
+
+ leaf.signing_identity.signature_key =
+ hex!("cad4fc381fe61822af45135c82921a348e6f46643d66ddefc70483565433714b").into();
+
+ Some(sk)
+ };
+
+ let commit_output = groups
+ .get_mut(1)
+ .unwrap()
+ .group
+ .commit(vec![])
+ .await
+ .unwrap();
+
+ // Presumably distributes member 1's commit to the whole group (committer index 1).
+ process_commit(&mut groups, commit_output.commit_message, 1).await;
+
+ // Group 0 tries to use the fixed key too
+ groups[0].group.commit_modifiers.modify_leaf = |leaf, _| {
+ let sk = hex!(
+ "3468b4c890255c983e3d5cbf5cb64c1ef7f6433a518f2f3151d6672f839a06ebcad4fc381fe61822af45135c82921a348e6f46643d66ddefc70483565433714b"
+ )
+ .into();
+
+ leaf.signing_identity.signature_key =
+ hex!("cad4fc381fe61822af45135c82921a348e6f46643d66ddefc70483565433714b").into();
+
+ Some(sk)
+ };
+
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+
+ // An unrelated member (7) must detect the duplicate.
+ let res = groups[7]
+ .process_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(_)));
+ }
+
+    // A committed leaf whose signature has been tampered with must be rejected.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn commit_leaf_incorrect_signature() {
+        let mut members = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 3).await;
+
+        // Flip one bit of the leaf signature and return no signing key (presumably so the
+        // corrupted signature is not recomputed).
+        members[0].group.commit_modifiers.modify_leaf = |leaf, _| {
+            leaf.signature[0] ^= 1;
+            None
+        };
+
+        let commit_message = members[0].group.commit(vec![]).await.unwrap().commit_message;
+        let outcome = members[2].process_message(commit_message).await;
+
+        assert_matches!(outcome, Err(MlsError::InvalidSignature));
+    }
+
+    #[cfg(not(target_arch = "wasm32"))]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn commit_leaf_not_supporting_used_context_extension() {
+        const EXT_TYPE: ExtensionType = ExtensionType::new(999);
+
+        // The new leaf of the committer doesn't support an extension set in group context
+        let context_extension = Extension::new(EXT_TYPE, vec![]);
+
+        let mut groups = get_test_groups_with_features(
+            3,
+            vec![context_extension].into(),
+            Default::default(),
+        )
+        .await;
+
+        // The committer's new leaf falls back to the stock test capabilities.
+        groups[0].commit_modifiers.modify_leaf = |leaf, sk| {
+            leaf.capabilities = get_test_capabilities();
+            Some(sk.clone())
+        };
+
+        let commit_message = groups[0].commit(vec![]).await.unwrap().commit_message;
+        let outcome = groups[1].process_incoming_message(commit_message).await;
+
+        assert_matches!(outcome, Err(MlsError::UnsupportedGroupExtension(EXT_TYPE)));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn commit_leaf_not_supporting_required_extension() {
+        // The new leaf of the committer doesn't support an extension required by group context
+        let required = RequiredCapabilitiesExt {
+            extensions: vec![999.into()],
+            proposals: vec![],
+            credentials: vec![],
+        }
+        .into_extension()
+        .unwrap();
+
+        let mut groups =
+            get_test_groups_with_features(3, vec![required].into(), Default::default()).await;
+
+        // The committer's new leaf is reset to `Capabilities::default()`.
+        groups[0].commit_modifiers.modify_leaf = |leaf, sk| {
+            leaf.capabilities = Capabilities::default();
+            Some(sk.clone())
+        };
+
+        let commit_message = groups[0].commit(vec![]).await.unwrap().commit_message;
+        let outcome = groups[2].process_incoming_message(commit_message).await;
+
+        assert!(outcome.is_err());
+    }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_has_unsupported_credential() {
+ // The new leaf of the committer has a credential unsupported by another leaf
+ let mut groups =
+ get_test_groups_with_features(3, Default::default(), Default::default()).await;
+
+ // Every member's identity provider accepts arbitrary custom credentials, presumably so
+ // the rejection below comes from capability checks rather than identity validation.
+ for group in groups.iter_mut() {
+ group.config.0.identity_provider.allow_any_custom = true;
+ }
+
+ // Swap the committer's basic credential for a custom credential of type 43 that
+ // carries the same identifier bytes.
+ groups[0].commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.signing_identity.credential = Credential::Custom(CustomCredential::new(
+ CredentialType::new(43),
+ leaf.signing_identity
+ .credential
+ .as_basic()
+ .unwrap()
+ .identifier
+ .to_vec(),
+ ));
+
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].commit(vec![]).await.unwrap();
+
+ let res = groups[2]
+ .process_incoming_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::CredentialTypeOfNewLeafIsUnsupported));
+ }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn commit_leaf_not_supporting_credential_used_in_another_leaf() {
+        // The new leaf of the committer doesn't support another leaf's credential
+        let mut groups =
+            get_test_groups_with_features(3, Default::default(), Default::default()).await;
+
+        // The committer's new leaf lists credential type 2 as its only supported type
+        // (the other members presumably use a different type, e.g. basic).
+        groups[0].commit_modifiers.modify_leaf = |leaf, sk| {
+            leaf.capabilities.credentials = vec![2.into()];
+            Some(sk.clone())
+        };
+
+        let commit_message = groups[0].commit(vec![]).await.unwrap().commit_message;
+        let outcome = groups[2].process_incoming_message(commit_message).await;
+
+        assert_matches!(outcome, Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn commit_leaf_not_supporting_required_credential() {
+        // The new leaf of the committer doesn't support a credential required by group context
+        let required = RequiredCapabilitiesExt {
+            extensions: vec![],
+            proposals: vec![],
+            credentials: vec![1.into()],
+        }
+        .into_extension()
+        .unwrap();
+
+        let mut groups =
+            get_test_groups_with_features(3, vec![required].into(), Default::default()).await;
+
+        // The committer's new leaf lists only credential type 2, so required type 1 is absent.
+        groups[0].commit_modifiers.modify_leaf = |leaf, sk| {
+            leaf.capabilities.credentials = vec![2.into()];
+            Some(sk.clone())
+        };
+
+        let commit_message = groups[0].commit(vec![]).await.unwrap().commit_message;
+        let outcome = groups[2].process_incoming_message(commit_message).await;
+
+        assert_matches!(outcome, Err(MlsError::RequiredCredentialNotFound(_)));
+    }
+
+    // Builds an external-senders extension with a single sender: a freshly generated
+    // signature public key paired with an X.509 credential whose chain holds 32 random bytes.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg(not(target_arch = "wasm32"))]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn make_x509_external_senders_ext() -> ExternalSendersExt {
+        // Only the public half of the generated key pair is kept.
+        let (_, signature_key) = test_cipher_suite_provider(TEST_CIPHER_SUITE)
+            .signature_key_generate()
+            .await
+            .unwrap();
+
+        let credential = Credential::X509(CertificateChain::from(vec![random_bytes(32)]));
+
+        ExternalSendersExt::new(vec![SigningIdentity {
+            signature_key,
+            credential,
+        }])
+    }
+
+ // A group whose external sender uses an X.509 credential must refuse a commit whose new
+ // leaf does not support X.509.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_not_supporting_external_sender_credential_leads_to_rejected_commit() {
+ let ext_senders = make_x509_external_senders_ext()
+ .await
+ .into_extension()
+ .unwrap();
+
+ // Alice's identity provider is configured with the X.509 credential type, and her group
+ // is created with the X.509 external-senders extension.
+ let mut alice = ClientBuilder::new()
+ .crypto_provider(TestCryptoProvider::new())
+ .identity_provider(
+ BasicWithCustomProvider::default().with_credential_type(CredentialType::X509),
+ )
+ .with_random_signing_identity("alice", TEST_CIPHER_SUITE)
+ .await
+ .build()
+ .create_group(core::iter::once(ext_senders).collect())
+ .await
+ .unwrap();
+
+ // New leaf supports only basic credentials (used by the group) but not X509 used by external sender
+ alice.commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.capabilities.credentials = vec![CredentialType::BASIC];
+ Some(sk.clone())
+ };
+
+ // Creating the commit succeeds; the error only surfaces when Alice applies it.
+ alice.commit(vec![]).await.unwrap();
+ let res = alice.apply_pending_commit().await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::RequiredCredentialNotFound(CredentialType::X509))
+ );
+ }
+
+ // A member whose key package does not support the external sender's X.509 credential
+ // cannot be added: building the Add commit must fail.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn node_not_supporting_external_sender_credential_cannot_join_group() {
+ let ext_senders = make_x509_external_senders_ext()
+ .await
+ .into_extension()
+ .unwrap();
+
+ let mut alice = ClientBuilder::new()
+ .crypto_provider(TestCryptoProvider::new())
+ .identity_provider(
+ BasicWithCustomProvider::default().with_credential_type(CredentialType::X509),
+ )
+ .with_random_signing_identity("alice", TEST_CIPHER_SUITE)
+ .await
+ .build()
+ .create_group(core::iter::once(ext_senders).collect())
+ .await
+ .unwrap();
+
+ // Bob's key package comes from the generic test helper and presumably does not list
+ // X.509 among its supported credential types.
+ let (_, bob_key_pkg) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let commit = alice
+ .commit_builder()
+ .add_member(bob_key_pkg)
+ .unwrap()
+ .build()
+ .await;
+
+ assert_matches!(
+ commit,
+ Err(MlsError::RequiredCredentialNotFound(CredentialType::X509))
+ );
+ }
+
+ // Adding an X.509 external sender to an existing group must fail when a current member
+ // does not support the X.509 credential type.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_senders_extension_is_rejected_if_member_does_not_support_credential_type() {
+ // Alice's group starts without any group context extensions.
+ let mut alice = ClientBuilder::new()
+ .crypto_provider(TestCryptoProvider::new())
+ .identity_provider(
+ BasicWithCustomProvider::default().with_credential_type(CredentialType::X509),
+ )
+ .with_random_signing_identity("alice", TEST_CIPHER_SUITE)
+ .await
+ .build()
+ .create_group(Default::default())
+ .await
+ .unwrap();
+
+ // Bob (presumably without X.509 support) joins successfully at this point.
+ let (_, bob_key_pkg) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ alice
+ .commit_builder()
+ .add_member(bob_key_pkg)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ alice.apply_pending_commit().await.unwrap();
+ assert_eq!(alice.roster().members_iter().count(), 2);
+
+ // Now try to install an external sender that uses an X.509 credential.
+ let ext_senders = make_x509_external_senders_ext()
+ .await
+ .into_extension()
+ .unwrap();
+
+ let res = alice
+ .commit_builder()
+ .set_group_context_ext(core::iter::once(ext_senders).collect())
+ .unwrap()
+ .build()
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::RequiredCredentialNotFound(CredentialType::X509))
+ );
+ }
+
+ /*
+ * Edge case paths
+ */
+
+    // A commit produced while one fixed public key is written into tree nodes 1 and 3
+    // (via `modify_tree`) and into the committer's leaf (via `modify_leaf`) must still be
+    // accepted by another member.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn committing_degenerate_path_succeeds() {
+        let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 10).await;
+
+        groups[0].group.commit_modifiers.modify_tree = |tree: &mut TreeKemPublic| {
+            tree.update_node(get_test_25519_key(1u8), 1).unwrap();
+            tree.update_node(get_test_25519_key(1u8), 3).unwrap();
+        };
+
+        groups[0].group.commit_modifiers.modify_leaf = |leaf, sk| {
+            leaf.public_key = get_test_25519_key(1u8);
+            Some(sk.clone())
+        };
+
+        let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+
+        // Unwrap instead of `assert!(res.is_ok())` so a failure reports the MlsError that
+        // caused it, the same way the other tests here check successful processing.
+        groups[7]
+            .process_message(commit_output.commit_message)
+            .await
+            .unwrap();
+    }
+
+ // After member 1 is removed, member 0 commits with a path that has an extra leading node
+ // (a copy of the first node carrying a fixed key); the receiver must reject the path length.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn inserting_key_in_filtered_node_fails() {
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 10).await;
+
+ // Member 0 removes member 1.
+ let commit_output = groups[0]
+ .group
+ .commit_builder()
+ .remove_member(1)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ groups[0].process_pending_commit().await.unwrap();
+
+ // Every remaining member except the committer (indices 2..) applies the removal;
+ // the removed member 1 is skipped.
+ for group in groups.iter_mut().skip(2) {
+ group
+ .process_message(commit_output.commit_message.clone())
+ .await
+ .unwrap();
+ }
+
+ // Write the fixed key into tree node 1 (presumably a node that path computation
+ // filters out after the removal — TODO confirm).
+ groups[0].group.commit_modifiers.modify_tree = |tree: &mut TreeKemPublic| {
+ tree.update_node(get_test_25519_key(1u8), 1).unwrap();
+ };
+
+ // Prepend a copy of the first path node carrying the fixed key, making the path one
+ // node longer than the unmodified one.
+ groups[0].group.commit_modifiers.modify_path = |path: Vec<UpdatePathNode>| {
+ let mut path = path;
+ let mut node = path[0].clone();
+ node.public_key = get_test_25519_key(1u8);
+ path.insert(0, node);
+ path
+ };
+
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+
+ let res = groups[7]
+ .process_message(commit_output.commit_message)
+ .await;
+
+ // We should get a path validation error, since the path is too long
+ assert_matches!(res, Err(MlsError::WrongPathLen));
+ }
+
+ // After member 1 is removed, member 0 commits with the last update-path node dropped;
+ // the receiver must reject the commit (any error is accepted).
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_with_too_short_path_fails() {
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 10).await;
+
+ // Member 0 removes member 1.
+ let commit_output = groups[0]
+ .group
+ .commit_builder()
+ .remove_member(1)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ groups[0].process_pending_commit().await.unwrap();
+
+ // Every remaining member except the committer (indices 2..) applies the removal.
+ for group in groups.iter_mut().skip(2) {
+ group
+ .process_message(commit_output.commit_message.clone())
+ .await
+ .unwrap();
+ }
+
+ // Drop the last node of the update path.
+ groups[0].group.commit_modifiers.modify_path = |path: Vec<UpdatePathNode>| {
+ let mut path = path;
+ path.pop();
+ path
+ };
+
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+
+ let res = groups[7]
+ .process_message(commit_output.commit_message)
+ .await;
+
+ assert!(res.is_err());
+ }
+
+ // An update proposal may replace the proposer's credential and signature key; once the
+ // proposal is committed, both the committer and the proposer see the new identity at
+ // leaf index 0.
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn update_proposal_can_change_credential() {
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 3).await;
+ let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, b"member").await;
+
+ // Member 0 proposes an update that switches to the "member" identity.
+ let update = groups[0]
+ .group
+ .propose_update_with_identity(secret_key, identity.clone(), vec![])
+ .await
+ .unwrap();
+
+ // Member 1 receives the proposal and commits it.
+ groups[1].process_message(update).await.unwrap();
+ let commit_output = groups[1].group.commit(vec![]).await.unwrap();
+
+ // Check that the credential was updated in the committer's state.
+ groups[1].process_pending_commit().await.unwrap();
+ let new_member = groups[1].group.roster().member_with_index(0).unwrap();
+
+ assert_eq!(
+ new_member.signing_identity.credential,
+ get_test_basic_credential(b"member".to_vec())
+ );
+
+ assert_eq!(
+ new_member.signing_identity.signature_key,
+ identity.signature_key
+ );
+
+ // Check that the credential was updated in the updater's state.
+ groups[0]
+ .process_message(commit_output.commit_message)
+ .await
+ .unwrap();
+ let new_member = groups[0].group.roster().member_with_index(0).unwrap();
+
+ assert_eq!(
+ new_member.signing_identity.credential,
+ get_test_basic_credential(b"member".to_vec())
+ );
+
+ assert_eq!(
+ new_member.signing_identity.signature_key,
+ identity.signature_key
+ );
+ }
+
+ // A commit containing an Add must be rejected with InvalidLifetime when it is processed
+ // at a time far in the future (presumably after the added key package has expired).
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_commit_with_old_adds_fails() {
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 2).await;
+
+ let key_package =
+ test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "foobar").await;
+
+ // Member 0 proposes the Add by reference, then commits its pending proposals.
+ let proposal = groups[0]
+ .group
+ .propose_add(key_package, vec![])
+ .await
+ .unwrap();
+
+ let commit = groups[0].group.commit(vec![]).await.unwrap().commit_message;
+
+ // 10 years from now
+ let future_time = MlsTime::now().seconds_since_epoch() + 10 * 365 * 24 * 3600;
+
+ let future_time =
+ MlsTime::from_duration_since_epoch(core::time::Duration::from_secs(future_time));
+
+ // Member 1 sees the proposal first (at the current time), then the commit at the
+ // future time.
+ groups[1]
+ .group
+ .process_incoming_message(proposal)
+ .await
+ .unwrap();
+ let res = groups[1]
+ .group
+ .process_incoming_message_with_time(commit, future_time)
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidLifetime));
+ }
+
+    // Creates Alice's group with the test custom proposal type enabled and has Bob join it
+    // with the same custom proposal type added to his settings. Returns `(alice, bob)`.
+    #[cfg(feature = "custom_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn custom_proposal_setup() -> (TestGroup, TestGroup) {
+        let mut alice =
+            test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |builder| {
+                builder.custom_proposal_type(TEST_CUSTOM_PROPOSAL_TYPE)
+            })
+            .await;
+
+        let (bob, _) = alice
+            .join_with_custom_config("bob", true, |config| {
+                let allowed = &mut config.0.settings.custom_proposal_types;
+                allowed.push(TEST_CUSTOM_PROPOSAL_TYPE)
+            })
+            .await
+            .unwrap();
+
+        (alice, bob)
+    }
+
+ // A custom proposal included by value in Alice's commit must reach Bob intact.
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_by_value() {
+ let (mut alice, mut bob) = custom_proposal_setup().await;
+
+ let custom_proposal = CustomProposal::new(TEST_CUSTOM_PROPOSAL_TYPE, vec![0, 1, 2]);
+
+ let commit = alice
+ .group
+ .commit_builder()
+ .custom_proposal(custom_proposal.clone())
+ .build()
+ .await
+ .unwrap()
+ .commit_message;
+
+ let res = bob.group.process_incoming_message(commit).await.unwrap();
+
+ // With `state_update`, the commit description must list exactly this custom proposal.
+ #[cfg(feature = "state_update")]
+ assert_matches!(res, ReceivedMessage::Commit(CommitMessageDescription { state_update: StateUpdate { custom_proposals, .. }, .. })
+ if custom_proposals.len() == 1 && custom_proposals[0].proposal == custom_proposal);
+
+ // Without it, only the message kind can be checked.
+ #[cfg(not(feature = "state_update"))]
+ assert_matches!(res, ReceivedMessage::Commit(_));
+ }
+
+ // A custom proposal sent by reference: Bob receives it as a proposal, commits it, and
+ // Alice sees it in Bob's commit.
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_by_reference() {
+ let (mut alice, mut bob) = custom_proposal_setup().await;
+
+ let custom_proposal = CustomProposal::new(TEST_CUSTOM_PROPOSAL_TYPE, vec![0, 1, 2]);
+
+ let proposal = alice
+ .group
+ .propose_custom(custom_proposal.clone(), vec![])
+ .await
+ .unwrap();
+
+ let recv_prop = bob.group.process_incoming_message(proposal).await.unwrap();
+
+ assert_matches!(recv_prop, ReceivedMessage::Proposal(ProposalMessageDescription { proposal: Proposal::Custom(c), ..})
+ if c == custom_proposal);
+
+ // Bob commits his pending proposals, including Alice's custom one.
+ let commit = bob.group.commit(vec![]).await.unwrap().commit_message;
+ let res = alice.group.process_incoming_message(commit).await.unwrap();
+
+ // With `state_update`, the commit description must list exactly this custom proposal.
+ #[cfg(feature = "state_update")]
+ assert_matches!(res, ReceivedMessage::Commit(CommitMessageDescription { state_update: StateUpdate { custom_proposals, .. }, .. })
+ if custom_proposals.len() == 1 && custom_proposals[0].proposal == custom_proposal);
+
+ // Without it, only the message kind can be checked.
+ #[cfg(not(feature = "state_update"))]
+ assert_matches!(res, ReceivedMessage::Commit(_));
+ }
+
+    // Bob can join from a Welcome whose commit also injects an external PSK, provided both
+    // sides hold that PSK in their secret stores.
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn can_join_with_psk() {
+        let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE)
+            .await
+            .group;
+
+        let (bob, key_pkg) =
+            test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+        // Both parties store the same PSK under the same id.
+        let psk_id = ExternalPskId::new(vec![0]);
+        let psk = PreSharedKey::from(vec![0]);
+
+        alice.config.secret_store().insert(psk_id.clone(), psk.clone());
+        bob.config.secret_store().insert(psk_id.clone(), psk);
+
+        // Alice adds Bob in a commit that also references the external PSK.
+        let commit = alice
+            .commit_builder()
+            .add_member(key_pkg)
+            .unwrap()
+            .add_external_psk(psk_id)
+            .unwrap()
+            .build()
+            .await
+            .unwrap();
+
+        let welcome = &commit.welcome_messages[0];
+        bob.join_group(None, welcome).await.unwrap();
+    }
+
+    // An update proposal whose new leaf does not support the group's required extension
+    // must be dropped from the commit without discarding the valid updates beside it.
+    //
+    // Besides `by_ref_proposal`, this test reads `state_update.roster_update`, which only
+    // exists with the `state_update` feature (without it `StateUpdate` has no fields), so
+    // it is gated on both features.
+    #[cfg(all(feature = "by_ref_proposal", feature = "state_update"))]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn invalid_update_does_not_prevent_other_updates() {
+        const EXTENSION_TYPE: ExtensionType = ExtensionType::new(33);
+
+        let group_extensions = ExtensionList::from(vec![RequiredCapabilitiesExt {
+            extensions: vec![EXTENSION_TYPE],
+            ..Default::default()
+        }
+        .into_extension()
+        .unwrap()]);
+
+        // Alice creates a group requiring support for an extension
+        let mut alice = TestClientBuilder::new_for_test()
+            .with_random_signing_identity("alice", TEST_CIPHER_SUITE)
+            .await
+            .extension_type(EXTENSION_TYPE)
+            .build()
+            .create_group(group_extensions)
+            .await
+            .unwrap();
+
+        // Bob's identity is kept so that his client can be rebuilt with other settings later.
+        let (bob_signing_identity, bob_secret_key) =
+            get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
+
+        let bob_client = TestClientBuilder::new_for_test()
+            .signing_identity(
+                bob_signing_identity.clone(),
+                bob_secret_key.clone(),
+                TEST_CIPHER_SUITE,
+            )
+            .extension_type(EXTENSION_TYPE)
+            .build();
+
+        let carol_client = TestClientBuilder::new_for_test()
+            .with_random_signing_identity("carol", TEST_CIPHER_SUITE)
+            .await
+            .extension_type(EXTENSION_TYPE)
+            .build();
+
+        let dave_client = TestClientBuilder::new_for_test()
+            .with_random_signing_identity("dave", TEST_CIPHER_SUITE)
+            .await
+            .extension_type(EXTENSION_TYPE)
+            .build();
+
+        // Alice adds Bob, Carol and Dave to the group. They all support the mandatory extension.
+        let commit = alice
+            .commit_builder()
+            .add_member(bob_client.generate_key_package_message().await.unwrap())
+            .unwrap()
+            .add_member(carol_client.generate_key_package_message().await.unwrap())
+            .unwrap()
+            .add_member(dave_client.generate_key_package_message().await.unwrap())
+            .unwrap()
+            .build()
+            .await
+            .unwrap();
+
+        alice.apply_pending_commit().await.unwrap();
+
+        let mut bob = bob_client
+            .join_group(None, &commit.welcome_messages[0])
+            .await
+            .unwrap()
+            .0;
+
+        bob.write_to_storage().await.unwrap();
+
+        // Bob reloads his group data, but with parameters that will cause his generated leaves to
+        // not support the mandatory extension.
+        let mut bob = TestClientBuilder::new_for_test()
+            .signing_identity(bob_signing_identity, bob_secret_key, TEST_CIPHER_SUITE)
+            .key_package_repo(bob.config.key_package_repo())
+            .group_state_storage(bob.config.group_state_storage())
+            .build()
+            .load_group(alice.group_id())
+            .await
+            .unwrap();
+
+        let mut carol = carol_client
+            .join_group(None, &commit.welcome_messages[0])
+            .await
+            .unwrap()
+            .0;
+
+        let mut dave = dave_client
+            .join_group(None, &commit.welcome_messages[0])
+            .await
+            .unwrap()
+            .0;
+
+        // Bob's updated leaf does not support the mandatory extension.
+        let bob_update = bob.propose_update(Vec::new()).await.unwrap();
+        let carol_update = carol.propose_update(Vec::new()).await.unwrap();
+        let dave_update = dave.propose_update(Vec::new()).await.unwrap();
+
+        // Alice receives the update proposals to be committed.
+        alice.process_incoming_message(bob_update).await.unwrap();
+        alice.process_incoming_message(carol_update).await.unwrap();
+        alice.process_incoming_message(dave_update).await.unwrap();
+
+        // Alice commits the update proposals.
+        alice.commit(Vec::new()).await.unwrap();
+        let commit_desc = alice.apply_pending_commit().await.unwrap();
+
+        // True when the commit contains an update for the member with basic identifier `id`.
+        let find_update_for = |id: &str| {
+            commit_desc
+                .state_update
+                .roster_update
+                .updated()
+                .iter()
+                .filter_map(|u| u.prior.signing_identity.credential.as_basic())
+                .any(|c| c.identifier == id.as_bytes())
+        };
+
+        // Check that all updates preserve identities.
+        let identities_are_preserved = commit_desc
+            .state_update
+            .roster_update
+            .updated()
+            .iter()
+            .filter_map(|u| {
+                let before = &u.prior.signing_identity.credential.as_basic()?.identifier;
+                let after = &u.new.signing_identity.credential.as_basic()?.identifier;
+                Some((before, after))
+            })
+            .all(|(before, after)| before == after);
+
+        assert!(identities_are_preserved);
+
+        // Carol's and Dave's updates should be part of the commit.
+        assert!(find_update_for("carol"));
+        assert!(find_update_for("dave"));
+
+        // Bob's update should be rejected.
+        assert!(!find_update_for("bob"));
+
+        // Check that all members are still in the group.
+        let all_members_are_in = alice
+            .roster()
+            .members_iter()
+            .zip(["alice", "bob", "carol", "dave"])
+            .all(|(member, id)| {
+                member
+                    .signing_identity
+                    .credential
+                    .as_basic()
+                    .unwrap()
+                    .identifier
+                    == id.as_bytes()
+            });
+
+        assert!(all_members_are_in);
+    }
+
+    // With rules that require a path for the custom proposal, Alice's leaf key must change.
+    #[cfg(feature = "custom_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn custom_proposal_may_enforce_path() {
+        let path_required_for_custom = true;
+        test_custom_proposal_mls_rules(path_required_for_custom).await;
+    }
+
+    // With rules that waive the path for the custom proposal, Alice's leaf key stays the same.
+    #[cfg(feature = "custom_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn custom_proposal_need_not_enforce_path() {
+        let path_required_for_custom = false;
+        test_custom_proposal_mls_rules(path_required_for_custom).await;
+    }
+
+ // Shared body for the two path tests above. Alice commits the test custom proposal plus an
+ // Add for Bob, under CustomMlsRules with the given `path_required_for_custom`. Her leaf
+ // public key must change when a path is required and must stay the same otherwise.
+ #[cfg(feature = "custom_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_custom_proposal_mls_rules(path_required_for_custom: bool) {
+ let mls_rules = CustomMlsRules {
+ path_required_for_custom,
+ external_joiner_can_send_custom: true,
+ };
+
+ let mut alice = client_with_custom_rules(b"alice", mls_rules.clone())
+ .await
+ .create_group(Default::default())
+ .await
+ .unwrap();
+
+ // Snapshot Alice's leaf key before the commit.
+ let alice_pub_before = alice.current_user_leaf_node().unwrap().public_key.clone();
+
+ let kp = client_with_custom_rules(b"bob", mls_rules)
+ .await
+ .generate_key_package_message()
+ .await
+ .unwrap();
+
+ alice
+ .commit_builder()
+ .custom_proposal(CustomProposal::new(TEST_CUSTOM_PROPOSAL_TYPE, vec![]))
+ .add_member(kp)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ alice.apply_pending_commit().await.unwrap();
+
+ let alice_pub_after = &alice.current_user_leaf_node().unwrap().public_key;
+
+ // A changed leaf key shows that the commit carried an update path.
+ if path_required_for_custom {
+ assert_ne!(alice_pub_after, &alice_pub_before);
+ } else {
+ assert_eq!(alice_pub_after, &alice_pub_before);
+ }
+ }
+
+    // External joiners are allowed to include the custom proposal by value.
+    #[cfg(feature = "custom_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn custom_proposal_by_value_in_external_join_may_be_allowed() {
+        let external_joiner_can_send_custom = true;
+        test_custom_proposal_by_value_in_external_join(external_joiner_can_send_custom).await;
+    }
+
+    // External joiners are not allowed to include the custom proposal by value.
+    #[cfg(feature = "custom_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn custom_proposal_by_value_in_external_join_may_not_be_allowed() {
+        let external_joiner_can_send_custom = false;
+        test_custom_proposal_by_value_in_external_join(external_joiner_can_send_custom).await;
+    }
+
+ // Shared body for the two external-join tests above. Bob builds an external commit that
+ // carries the test custom proposal by value. When the rules allow it, Alice must accept the
+ // commit; otherwise building the commit must fail with MlsRulesError (presumably wrapping
+ // the InvalidSender returned by CustomMlsRules::filter_proposals).
+ #[cfg(feature = "custom_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_custom_proposal_by_value_in_external_join(external_joiner_can_send_custom: bool) {
+ let mls_rules = CustomMlsRules {
+ path_required_for_custom: true,
+ external_joiner_can_send_custom,
+ };
+
+ let mut alice = client_with_custom_rules(b"alice", mls_rules.clone())
+ .await
+ .create_group(Default::default())
+ .await
+ .unwrap();
+
+ let group_info = alice
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap();
+
+ // Bob uses the same rules, so a disallowed proposal is already refused on his side.
+ let commit = client_with_custom_rules(b"bob", mls_rules)
+ .await
+ .external_commit_builder()
+ .unwrap()
+ .with_custom_proposal(CustomProposal::new(TEST_CUSTOM_PROPOSAL_TYPE, vec![]))
+ .build(group_info)
+ .await;
+
+ if external_joiner_can_send_custom {
+ let commit = commit.unwrap().1;
+ alice.process_incoming_message(commit).await.unwrap();
+ } else {
+ assert_matches!(commit.map(|_| ()), Err(MlsError::MlsRulesError(_)));
+ }
+ }
+
+ // An external joiner can commit a custom proposal that an existing member sent by
+ // reference, and that member must accept the resulting external commit.
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_by_ref_in_external_join() {
+ let mls_rules = CustomMlsRules {
+ path_required_for_custom: true,
+ external_joiner_can_send_custom: true,
+ };
+
+ let mut alice = client_with_custom_rules(b"alice", mls_rules.clone())
+ .await
+ .create_group(Default::default())
+ .await
+ .unwrap();
+
+ // Alice proposes the custom proposal by reference.
+ let by_ref = CustomProposal::new(TEST_CUSTOM_PROPOSAL_TYPE, vec![]);
+ let by_ref = alice.propose_custom(by_ref, vec![]).await.unwrap();
+
+ let group_info = alice
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap();
+
+ // Bob's external commit includes Alice's proposal message.
+ let (_, commit) = client_with_custom_rules(b"bob", mls_rules)
+ .await
+ .external_commit_builder()
+ .unwrap()
+ .with_received_custom_proposal(by_ref)
+ .build(group_info)
+ .await
+ .unwrap();
+
+ alice.process_incoming_message(commit).await.unwrap();
+ }
+
+ // Builds a test client named `name` that uses the given CustomMlsRules and supports the
+ // test custom proposal type, with a basic identity provider wrapped in
+ // BasicWithCustomProvider.
+ #[cfg(feature = "custom_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn client_with_custom_rules(
+ name: &[u8],
+ mls_rules: CustomMlsRules,
+ ) -> Client<impl MlsConfig> {
+ let (signing_identity, signer) = get_test_signing_identity(TEST_CIPHER_SUITE, name).await;
+
+ ClientBuilder::new()
+ .crypto_provider(TestCryptoProvider::new())
+ .identity_provider(BasicWithCustomProvider::new(BasicIdentityProvider::new()))
+ .signing_identity(signing_identity, signer, TEST_CIPHER_SUITE)
+ .custom_proposal_type(TEST_CUSTOM_PROPOSAL_TYPE)
+ .mls_rules(mls_rules)
+ .build()
+ }
+
+    // Test-only MLS rules:
+    // - `path_required_for_custom` decides whether a commit carrying the test custom proposal
+    //   must include an update path.
+    // - `external_joiner_can_send_custom` decides whether an external joiner
+    //   (`CommitSource::NewMember`) may include that proposal.
+    // Gated on `custom_proposal` like every one of its users, so the struct is not dead code
+    // when that feature is disabled.
+    #[cfg(feature = "custom_proposal")]
+    #[derive(Debug, Clone)]
+    struct CustomMlsRules {
+        path_required_for_custom: bool,
+        external_joiner_can_send_custom: bool,
+    }
+
+    #[cfg(feature = "custom_proposal")]
+    impl ProposalBundle {
+        // True when the bundle contains a custom proposal of type TEST_CUSTOM_PROPOSAL_TYPE.
+        fn has_test_custom_proposal(&self) -> bool {
+            let mut custom_types = self.custom_proposal_types();
+            custom_types.any(|proposal_type| proposal_type == TEST_CUSTOM_PROPOSAL_TYPE)
+        }
+    }
+
+    #[cfg(feature = "custom_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+    impl crate::MlsRules for CustomMlsRules {
+        type Error = MlsError;
+
+        // A path is required unless the bundle carries the test custom proposal and the
+        // rules waive the path for it.
+        fn commit_options(
+            &self,
+            _: &Roster,
+            _: &ExtensionList,
+            proposals: &ProposalBundle,
+        ) -> Result<CommitOptions, MlsError> {
+            let carries_custom = proposals.has_test_custom_proposal();
+            let path_required = self.path_required_for_custom || !carries_custom;
+
+            Ok(CommitOptions::default().with_path_required(path_required))
+        }
+
+        // Encryption options are left at their defaults.
+        fn encryption_options(
+            &self,
+            _: &Roster,
+            _: &ExtensionList,
+        ) -> Result<crate::mls_rules::EncryptionOptions, MlsError> {
+            Ok(crate::mls_rules::EncryptionOptions::default())
+        }
+
+        // Rejects the whole bundle with `InvalidSender` when an external joiner includes the
+        // test custom proposal and such joiners are not allowed to; otherwise passes it through.
+        async fn filter_proposals(
+            &self,
+            _: CommitDirection,
+            sender: CommitSource,
+            _: &Roster,
+            _: &ExtensionList,
+            proposals: ProposalBundle,
+        ) -> Result<ProposalBundle, MlsError> {
+            let from_external_joiner = matches!(sender, CommitSource::NewMember(_));
+
+            if from_external_joiner
+                && proposals.has_test_custom_proposal()
+                && !self.external_joiner_can_send_custom
+            {
+                return Err(MlsError::InvalidSender);
+            }
+
+            Ok(proposals)
+        }
+    }
+
+    // Feeding a group its own commit must yield a commit description naming itself as
+    // the committer.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn group_can_receive_commit_from_self() {
+        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE)
+            .await
+            .group;
+
+        let own_commit = group.commit(vec![]).await.unwrap().commit_message;
+
+        let received = group
+            .process_incoming_message(own_commit)
+            .await
+            .unwrap();
+
+        match received {
+            ReceivedMessage::Commit(description) => {
+                assert_eq!(description.committer, *group.private_tree.self_index);
+            }
+            _ => panic!("expected commit message"),
+        }
+    }
+}
diff --git a/src/group/padding.rs b/src/group/padding.rs
new file mode 100644
index 0000000..6320ccf
--- /dev/null
+++ b/src/group/padding.rs
@@ -0,0 +1,109 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+/// Padding used when sending an encrypted group message.
+///
+/// The default is [`PaddingMode::StepFunction`].
+// NOTE(review): `#[repr(u8)]` combined with the FFI export suggests the variant
+// order is part of the FFI ABI; confirm before reordering or inserting variants.
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
+#[repr(u8)]
+pub enum PaddingMode {
+ /// Step function based on the size of the message being sent.
+ /// The amount of padding used will increase with the size of the original
+ /// message. This is the default mode.
+ #[default]
+ StepFunction,
+ /// No padding: the padded size equals the original content size.
+ None,
+}
+
+impl PaddingMode {
+    /// Returns the padded length for a plaintext of `content_size` bytes.
+    ///
+    /// `None` returns `content_size` unchanged. `StepFunction` rounds
+    /// `content_size` up past the next multiple of a granularity equal to one
+    /// eighth of the smallest power of two strictly greater than
+    /// `content_size` (never less than 32). This gives four size buckets per
+    /// power-of-two range and always adds at least one byte.
+    pub(super) fn padded_size(&self, content_size: usize) -> usize {
+        match self {
+            PaddingMode::None => content_size,
+            PaddingMode::StepFunction => {
+                // Exponent of the size class: smallest power of two strictly
+                // greater than `content_size`, floored at 2^8.
+                let class_exponent = (content_size + 1)
+                    .next_power_of_two()
+                    .max(256)
+                    .trailing_zeros();
+
+                // Only the bits above the granularity stay visible; the bits
+                // below it are set to one and then carried over by the `+ 1`.
+                let granularity: usize = 1 << (class_exponent - 3);
+                let hidden_bits = granularity - 1;
+
+                (content_size | hidden_bits) + 1
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::PaddingMode;
+
+    use alloc::vec::Vec;
+    #[cfg(target_arch = "wasm32")]
+    use wasm_bindgen_test::wasm_bindgen_test as test;
+
+    /// One entry of the `message_padding_test_vector` JSON file.
+    #[derive(serde::Deserialize, serde::Serialize)]
+    struct TestCase {
+        input: usize,
+        output: usize,
+    }
+
+    /// Computes the expected padded size for every input in `1..1024`.
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn generate_message_padding_test_vector() -> Vec<TestCase> {
+        (1..1024)
+            .map(|input| TestCase {
+                input,
+                output: PaddingMode::StepFunction.padded_size(input),
+            })
+            .collect()
+    }
+
+    fn load_test_cases() -> Vec<TestCase> {
+        load_test_case_json!(
+            message_padding_test_vector,
+            generate_message_padding_test_vector()
+        )
+    }
+
+    #[test]
+    fn test_no_padding() {
+        for size in [0, 100, 1000, 10000] {
+            assert_eq!(PaddingMode::None.padded_size(size), size);
+        }
+    }
+
+    #[test]
+    fn test_padding_length() {
+        let expectations = [
+            (0, 32),
+            // Short
+            (63, 64),
+            (64, 96),
+            (65, 96),
+            // Almost long and almost short
+            (127, 128),
+            (128, 160),
+            (129, 160),
+            // One length from each of the 4 buckets between 256 and 512
+            (260, 320),
+            (330, 384),
+            (390, 448),
+            (490, 512),
+        ];
+
+        for (input, expected) in expectations {
+            assert_eq!(PaddingMode::StepFunction.padded_size(input), expected);
+        }
+
+        // All test cases
+        for TestCase { input, output } in load_test_cases() {
+            assert_eq!(output, PaddingMode::StepFunction.padded_size(input));
+        }
+    }
+}
diff --git a/src/group/proposal.rs b/src/group/proposal.rs
new file mode 100644
index 0000000..a31be29
--- /dev/null
+++ b/src/group/proposal.rs
@@ -0,0 +1,578 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::{boxed::Box, vec::Vec};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::tree_kem::leaf_node::LeafNode;
+
+use crate::{
+ client::MlsError, tree_kem::node::LeafIndex, CipherSuite, KeyPackage, MlsMessage,
+ ProtocolVersion,
+};
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{group::Capabilities, identity::SigningIdentity};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::group::proposal_ref::ProposalRef;
+
+pub use mls_rs_core::extension::ExtensionList;
+pub use mls_rs_core::group::ProposalType;
+
+#[cfg(feature = "psk")]
+use crate::psk::{ExternalPskId, JustPreSharedKeyID, PreSharedKeyID};
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal that adds a member to a [`Group`](crate::group::Group).
+///
+/// The MLS codec is derived, so the encoding is that of the single
+/// `key_package` field (presumably with no extra framing — confirm against
+/// the `mls_rs_codec` derive).
+pub struct AddProposal {
+ /// Key package of the member that will be added.
+ pub(crate) key_package: KeyPackage,
+}
+
+impl AddProposal {
+    /// The [`SigningIdentity`] of the [`Member`](mls_rs_core::group::Member)
+    /// that will be added by this proposal.
+    pub fn signing_identity(&self) -> &SigningIdentity {
+        self.key_package.signing_identity()
+    }
+
+    /// Client [`Capabilities`] of the [`Member`](mls_rs_core::group::Member)
+    /// that will be added by this proposal, read from its key package's leaf
+    /// node.
+    pub fn capabilities(&self) -> Capabilities {
+        self.key_package.leaf_node.ungreased_capabilities()
+    }
+
+    /// Key package extensions that are associated with the
+    /// [`Member`](mls_rs_core::group::Member) that will be added by this proposal.
+    pub fn key_package_extensions(&self) -> ExtensionList {
+        self.key_package.ungreased_extensions()
+    }
+
+    /// Leaf node extensions that will be entered into the group state for the
+    /// [`Member`](mls_rs_core::group::Member) that will be added.
+    pub fn leaf_node_extensions(&self) -> ExtensionList {
+        self.key_package.leaf_node.ungreased_extensions()
+    }
+}
+
+impl From<KeyPackage> for AddProposal {
+    /// Wraps `key_package` in an add proposal; no validation is performed.
+    fn from(key_package: KeyPackage) -> Self {
+        AddProposal { key_package }
+    }
+}
+
+impl TryFrom<MlsMessage> for AddProposal {
+    type Error = MlsError;
+
+    /// Extracts the key package carried by `value`.
+    ///
+    /// Fails with [`MlsError::UnexpectedMessageType`] when `value` does not
+    /// contain a key package.
+    fn try_from(value: MlsMessage) -> Result<Self, Self::Error> {
+        match value.into_key_package() {
+            Some(key_package) => Ok(AddProposal::from(key_package)),
+            None => Err(MlsError::UnexpectedMessageType),
+        }
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal that will update an existing [`Member`](mls_rs_core::group::Member) of a
+/// [`Group`](crate::group::Group).
+///
+/// Only available with the `by_ref_proposal` feature.
+pub struct UpdateProposal {
+ /// The new leaf node for the member being updated.
+ pub(crate) leaf_node: LeafNode,
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl UpdateProposal {
+ /// The new [`SigningIdentity`] of the [`Member`](mls_rs_core::group::Member)
+ /// that is being updated by this proposal.
+ pub fn signing_identity(&self) -> &SigningIdentity {
+ &self.leaf_node.signing_identity
+ }
+
+ /// New Client [`Capabilities`] of the [`Member`](mls_rs_core::group::Member)
+ /// that will be updated by this proposal.
+ pub fn capabilities(&self) -> Capabilities {
+ self.leaf_node.ungreased_capabilities()
+ }
+
+ /// New Leaf node extensions that will be entered into the group state for the
+ /// [`Member`](mls_rs_core::group::Member) that is being updated by this proposal.
+ pub fn leaf_node_extensions(&self) -> ExtensionList {
+ self.leaf_node.ungreased_extensions()
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal to remove an existing [`Member`](mls_rs_core::group::Member) of a
+/// [`Group`](crate::group::Group).
+pub struct RemoveProposal {
+ /// Leaf index of the member to remove.
+ pub(crate) to_remove: LeafIndex,
+}
+
+impl RemoveProposal {
+ /// The index of the [`Member`](mls_rs_core::group::Member) that will be removed by
+ /// this proposal, as the raw `u32` behind the leaf index.
+ pub fn to_remove(&self) -> u32 {
+ *self.to_remove
+ }
+}
+
+impl From<u32> for RemoveProposal {
+ /// Builds a proposal removing the member at leaf index `value`.
+ ///
+ /// The index is not checked against any group here.
+ fn from(value: u32) -> Self {
+ RemoveProposal {
+ to_remove: LeafIndex(value),
+ }
+ }
+}
+
+#[cfg(feature = "psk")]
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal to add a pre-shared key to a group.
+pub struct PreSharedKeyProposal {
+ /// Identifier of the pre-shared key; either an external or a resumption PSK.
+ pub(crate) psk: PreSharedKeyID,
+}
+
+#[cfg(feature = "psk")]
+impl PreSharedKeyProposal {
+ /// The external pre-shared key id of this proposal.
+ ///
+ /// The underlying [`PreSharedKeyID`] may identify either an external PSK or
+ /// a resumption PSK. RFC 9420 (Section 12.1.4) allows resumption PSKs in a
+ /// PreSharedKey proposal too, constraining only their usage.
+ ///
+ /// Returns `None` in the condition that the underlying psk is not external
+ /// (i.e. it is a resumption psk).
+ pub fn external_psk_id(&self) -> Option<&ExternalPskId> {
+ match self.psk.key_id {
+ JustPreSharedKeyID::External(ref ext) => Some(ext),
+ JustPreSharedKeyID::Resumption(_) => None,
+ }
+ }
+}
+
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal to reinitialize a group using new parameters.
+///
+/// `Debug` is implemented by hand so the group id is pretty-printed.
+// The MLS codec is derived; presumably fields are encoded in declaration
+// order, so do not reorder them — confirm against the `mls_rs_codec` derive.
+pub struct ReInitProposal {
+ /// Id of the new group, encoded as a length-prefixed byte vector.
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub(crate) group_id: Vec<u8>,
+ /// Protocol version of the new group.
+ pub(crate) version: ProtocolVersion,
+ /// Cipher suite of the new group.
+ pub(crate) cipher_suite: CipherSuite,
+ /// Group context extensions of the new group.
+ pub(crate) extensions: ExtensionList,
+}
+
+impl Debug for ReInitProposal {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // The group id goes through `pretty_group_id` rather than a raw byte list.
+        let group_id = mls_rs_core::debug::pretty_group_id(&self.group_id);
+
+        let mut builder = f.debug_struct("ReInitProposal");
+        builder.field("group_id", &group_id);
+        builder.field("version", &self.version);
+        builder.field("cipher_suite", &self.cipher_suite);
+        builder.field("extensions", &self.extensions);
+        builder.finish()
+    }
+}
+
+impl ReInitProposal {
+    /// Identifier of the group created by the reinitialization.
+    pub fn group_id(&self) -> &[u8] {
+        self.group_id.as_slice()
+    }
+
+    /// Protocol version the reinitialized group will use.
+    pub fn new_version(&self) -> ProtocolVersion {
+        self.version
+    }
+
+    /// Cipher suite the reinitialized group will use.
+    pub fn new_cipher_suite(&self) -> CipherSuite {
+        self.cipher_suite
+    }
+
+    /// Group context extensions the reinitialized group will start with.
+    pub fn new_group_context_extensions(&self) -> &ExtensionList {
+        &self.extensions
+    }
+}
+
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal used for external commits.
+///
+/// `Debug` is implemented by hand so the KEM output is pretty-printed.
+pub struct ExternalInit {
+ /// KEM output, encoded as a length-prefixed byte vector.
+ /// NOTE(review): presumably the encapsulation produced by the external
+ /// joiner as in RFC 9420 Section 12.1.5 — confirm.
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub(crate) kem_output: Vec<u8>,
+}
+
+impl Debug for ExternalInit {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Render the KEM output via `pretty_bytes` instead of a raw byte list.
+        let kem_output = mls_rs_core::debug::pretty_bytes(&self.kem_output);
+
+        let mut builder = f.debug_struct("ExternalInit");
+        builder.field("kem_output", &kem_output);
+        builder.finish()
+    }
+}
+
+#[cfg(feature = "custom_proposal")]
+#[derive(Clone, PartialEq, Eq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A user defined custom proposal.
+///
+/// User defined proposals are passed through the protocol as an opaque value.
+///
+/// This type has no MLS codec of its own: `Proposal`'s codec writes
+/// `proposal_type` as the outer type tag and `data` as a length-prefixed byte
+/// vector.
+pub struct CustomProposal {
+ /// Type tag written in front of the proposal on the wire.
+ proposal_type: ProposalType,
+ /// Opaque, application-defined payload.
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ data: Vec<u8>,
+}
+
+#[cfg(feature = "custom_proposal")]
+impl Debug for CustomProposal {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // The opaque payload is shown via `pretty_bytes`.
+        let data = mls_rs_core::debug::pretty_bytes(&self.data);
+
+        let mut builder = f.debug_struct("CustomProposal");
+        builder.field("proposal_type", &self.proposal_type);
+        builder.field("data", &data);
+        builder.finish()
+    }
+}
+
+#[cfg(feature = "custom_proposal")]
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl CustomProposal {
+    /// Creates a custom proposal carrying `data` under `proposal_type`.
+    ///
+    /// # Warning
+    ///
+    /// Avoid the [`ProposalType`] values that already have constants defined
+    /// by this crate. Using them in a custom proposal has unspecified
+    /// behavior; in particular, encoding a `Proposal::Custom` whose type is in
+    /// the range 0-7 fails with a codec error.
+    pub fn new(proposal_type: ProposalType, data: Vec<u8>) -> Self {
+        Self { proposal_type, data }
+    }
+
+    /// The proposal type this custom proposal is tagged with.
+    pub fn proposal_type(&self) -> ProposalType {
+        self.proposal_type
+    }
+
+    /// The opaque payload communicated by this custom proposal.
+    pub fn data(&self) -> &[u8] {
+        self.data.as_slice()
+    }
+}
+
+/// Trait to simplify creating custom proposals that are serialized with MLS
+/// encoding.
+#[cfg(feature = "custom_proposal")]
+pub trait MlsCustomProposal: MlsSize + MlsEncode + MlsDecode + Sized {
+    /// The [`ProposalType`] this proposal is carried under.
+    fn proposal_type() -> ProposalType;
+
+    /// MLS-encodes `self` and wraps the bytes in a [`CustomProposal`] tagged
+    /// with [`Self::proposal_type`].
+    ///
+    /// # Errors
+    ///
+    /// Returns any error produced while encoding `self`.
+    fn to_custom_proposal(&self) -> Result<CustomProposal, mls_rs_codec::Error> {
+        Ok(CustomProposal::new(
+            Self::proposal_type(),
+            self.mls_encode_to_vec()?,
+        ))
+    }
+
+    /// Decodes `Self` from the opaque data of `proposal`.
+    ///
+    /// # Errors
+    ///
+    /// Returns `mls_rs_codec::Error::Custom(4)` when `proposal` is tagged with
+    /// a type other than [`Self::proposal_type`], or any error produced while
+    /// decoding its data.
+    fn from_custom_proposal(proposal: &CustomProposal) -> Result<Self, mls_rs_codec::Error> {
+        if proposal.proposal_type() != Self::proposal_type() {
+            // The codec's `Custom` error is used here with a numeric code
+            // (usable without `std`); 4 identifies a proposal type mismatch.
+            return Err(mls_rs_codec::Error::Custom(4));
+        }
+
+        Self::mls_decode(&mut proposal.data())
+    }
+}
+
+#[allow(clippy::large_enum_variant)]
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u16)]
+#[non_exhaustive]
+/// An enum that represents all possible types of proposals.
+///
+/// The MLS encoding is hand-written (see the `MlsEncode`/`MlsDecode` impls for
+/// `Proposal`): the value of [`Proposal::proposal_type`] followed by the
+/// variant body. The `#[repr(u16)]` discriminants are not what is written on
+/// the wire.
+pub enum Proposal {
+ /// Add a new member. Boxed, presumably to limit the enum's size given the
+ /// `large_enum_variant` allowance above — confirm.
+ Add(alloc::boxed::Box<AddProposal>),
+ /// Update the sender's leaf (requires `by_ref_proposal`).
+ #[cfg(feature = "by_ref_proposal")]
+ Update(UpdateProposal),
+ /// Remove an existing member.
+ Remove(RemoveProposal),
+ /// Inject a pre-shared key (requires `psk`).
+ #[cfg(feature = "psk")]
+ Psk(PreSharedKeyProposal),
+ /// Reinitialize the group with new parameters.
+ ReInit(ReInitProposal),
+ /// Proposal used by external commits.
+ ExternalInit(ExternalInit),
+ /// Replace the group context extensions.
+ GroupContextExtensions(ExtensionList),
+ /// Application-defined proposal (requires `custom_proposal`).
+ #[cfg(feature = "custom_proposal")]
+ Custom(CustomProposal),
+}
+
+impl MlsSize for Proposal {
+    /// Size of the type tag plus the size of the variant body.
+    fn mls_encoded_len(&self) -> usize {
+        let type_len = self.proposal_type().mls_encoded_len();
+
+        let body_len = match self {
+            Proposal::Add(add) => add.mls_encoded_len(),
+            #[cfg(feature = "by_ref_proposal")]
+            Proposal::Update(update) => update.mls_encoded_len(),
+            Proposal::Remove(remove) => remove.mls_encoded_len(),
+            #[cfg(feature = "psk")]
+            Proposal::Psk(psk) => psk.mls_encoded_len(),
+            Proposal::ReInit(reinit) => reinit.mls_encoded_len(),
+            Proposal::ExternalInit(external) => external.mls_encoded_len(),
+            Proposal::GroupContextExtensions(extensions) => extensions.mls_encoded_len(),
+            // Custom payloads are written as a length-prefixed byte vector.
+            #[cfg(feature = "custom_proposal")]
+            Proposal::Custom(custom) => mls_rs_codec::byte_vec::mls_encoded_len(&custom.data),
+        };
+
+        type_len + body_len
+    }
+}
+
+impl MlsEncode for Proposal {
+    /// Writes the proposal type followed by the variant body.
+    ///
+    /// # Errors
+    ///
+    /// Returns `mls_rs_codec::Error::Custom(2)` for a custom proposal whose
+    /// type is in the range 0-7 (reserved or already assigned to the variants
+    /// above); nothing is written to `writer` in that case. Otherwise returns
+    /// any error produced while encoding the body.
+    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+        // Validate before touching `writer`, so a rejected proposal does not
+        // leave a dangling type tag in a buffer shared with other fields.
+        #[cfg(feature = "custom_proposal")]
+        {
+            if matches!(self, Proposal::Custom(p) if p.proposal_type.raw_value() <= 7) {
+                return Err(mls_rs_codec::Error::Custom(2));
+            }
+        }
+
+        self.proposal_type().mls_encode(writer)?;
+
+        match self {
+            Proposal::Add(p) => p.mls_encode(writer),
+            #[cfg(feature = "by_ref_proposal")]
+            Proposal::Update(p) => p.mls_encode(writer),
+            Proposal::Remove(p) => p.mls_encode(writer),
+            #[cfg(feature = "psk")]
+            Proposal::Psk(p) => p.mls_encode(writer),
+            Proposal::ReInit(p) => p.mls_encode(writer),
+            Proposal::ExternalInit(p) => p.mls_encode(writer),
+            Proposal::GroupContextExtensions(p) => p.mls_encode(writer),
+            // The custom type tag was validated above.
+            #[cfg(feature = "custom_proposal")]
+            Proposal::Custom(p) => mls_rs_codec::byte_vec::mls_encode(&p.data, writer),
+        }
+    }
+}
+
+impl MlsDecode for Proposal {
+    /// Reads a proposal type and then the body for that type.
+    ///
+    /// # Errors
+    ///
+    /// Returns `mls_rs_codec::Error::Custom(3)` for a proposal type this build
+    /// cannot decode. That covers a value in the reserved/assigned range 0-7
+    /// whose variant is not compiled in, and, without `custom_proposal`, any
+    /// unknown type. Otherwise returns any error produced while decoding the
+    /// body.
+    fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
+        let proposal_type = ProposalType::mls_decode(reader)?;
+
+        Ok(match proposal_type {
+            ProposalType::ADD => {
+                Proposal::Add(alloc::boxed::Box::new(AddProposal::mls_decode(reader)?))
+            }
+            #[cfg(feature = "by_ref_proposal")]
+            ProposalType::UPDATE => Proposal::Update(UpdateProposal::mls_decode(reader)?),
+            ProposalType::REMOVE => Proposal::Remove(RemoveProposal::mls_decode(reader)?),
+            #[cfg(feature = "psk")]
+            ProposalType::PSK => Proposal::Psk(PreSharedKeyProposal::mls_decode(reader)?),
+            ProposalType::RE_INIT => Proposal::ReInit(ReInitProposal::mls_decode(reader)?),
+            ProposalType::EXTERNAL_INIT => {
+                Proposal::ExternalInit(ExternalInit::mls_decode(reader)?)
+            }
+            ProposalType::GROUP_CONTEXT_EXTENSIONS => {
+                Proposal::GroupContextExtensions(ExtensionList::mls_decode(reader)?)
+            }
+            // Values 0-7 are reserved or assigned to the variants above, and
+            // `MlsEncode` refuses to write them for a custom proposal. Treating
+            // e.g. an `Update` (when `by_ref_proposal` is off) as opaque custom
+            // data would consume only part of its body and misparse the rest.
+            #[cfg(feature = "custom_proposal")]
+            custom if custom.raw_value() > 7 => Proposal::Custom(CustomProposal {
+                proposal_type: custom,
+                data: mls_rs_codec::byte_vec::mls_decode(reader)?,
+            }),
+            // TODO fix test dependency on openssl loading codec with default features
+            _ => return Err(mls_rs_codec::Error::Custom(3)),
+        })
+    }
+}
+
+impl Proposal {
+    /// The [`ProposalType`] written in front of this proposal on the wire.
+    pub fn proposal_type(&self) -> ProposalType {
+        // Borrowing is free, and it keeps the variant-to-type mapping in the
+        // `BorrowedProposal` impl only.
+        BorrowedProposal::from(self).proposal_type()
+    }
+}
+
+#[derive(Clone, Debug, PartialEq)]
+/// An enum that represents a borrowed version of [`Proposal`].
+///
+/// Each variant holds a reference to the matching [`Proposal`] body, so a
+/// proposal can be inspected through this type without cloning it.
+pub enum BorrowedProposal<'a> {
+ Add(&'a AddProposal),
+ #[cfg(feature = "by_ref_proposal")]
+ Update(&'a UpdateProposal),
+ Remove(&'a RemoveProposal),
+ #[cfg(feature = "psk")]
+ Psk(&'a PreSharedKeyProposal),
+ ReInit(&'a ReInitProposal),
+ ExternalInit(&'a ExternalInit),
+ GroupContextExtensions(&'a ExtensionList),
+ #[cfg(feature = "custom_proposal")]
+ Custom(&'a CustomProposal),
+}
+
+impl<'a> From<BorrowedProposal<'a>> for Proposal {
+    /// Clones the borrowed proposal body into an owned [`Proposal`].
+    fn from(value: BorrowedProposal<'a>) -> Self {
+        match value {
+            BorrowedProposal::Add(add) => Self::Add(Box::new(add.clone())),
+            #[cfg(feature = "by_ref_proposal")]
+            BorrowedProposal::Update(update) => Self::Update(update.clone()),
+            BorrowedProposal::Remove(remove) => Self::Remove(remove.clone()),
+            #[cfg(feature = "psk")]
+            BorrowedProposal::Psk(psk) => Self::Psk(psk.clone()),
+            BorrowedProposal::ReInit(reinit) => Self::ReInit(reinit.clone()),
+            BorrowedProposal::ExternalInit(external) => Self::ExternalInit(external.clone()),
+            BorrowedProposal::GroupContextExtensions(extensions) => {
+                Self::GroupContextExtensions(extensions.clone())
+            }
+            #[cfg(feature = "custom_proposal")]
+            BorrowedProposal::Custom(custom) => Self::Custom(custom.clone()),
+        }
+    }
+}
+
+impl BorrowedProposal<'_> {
+    /// The [`ProposalType`] of the referenced proposal body.
+    pub fn proposal_type(&self) -> ProposalType {
+        match self {
+            Self::Add(_) => ProposalType::ADD,
+            #[cfg(feature = "by_ref_proposal")]
+            Self::Update(_) => ProposalType::UPDATE,
+            Self::Remove(_) => ProposalType::REMOVE,
+            #[cfg(feature = "psk")]
+            Self::Psk(_) => ProposalType::PSK,
+            Self::ReInit(_) => ProposalType::RE_INIT,
+            Self::ExternalInit(_) => ProposalType::EXTERNAL_INIT,
+            Self::GroupContextExtensions(_) => ProposalType::GROUP_CONTEXT_EXTENSIONS,
+            // Custom proposals carry their own type tag.
+            #[cfg(feature = "custom_proposal")]
+            Self::Custom(custom) => custom.proposal_type(),
+        }
+    }
+}
+
+impl<'a> From<&'a Proposal> for BorrowedProposal<'a> {
+    /// Borrows the body of `proposal` without cloning it.
+    fn from(proposal: &'a Proposal) -> Self {
+        match proposal {
+            // The boxed add proposal is deref-coerced to `&AddProposal`.
+            Proposal::Add(add) => Self::Add(add),
+            #[cfg(feature = "by_ref_proposal")]
+            Proposal::Update(update) => Self::Update(update),
+            Proposal::Remove(remove) => Self::Remove(remove),
+            #[cfg(feature = "psk")]
+            Proposal::Psk(psk) => Self::Psk(psk),
+            Proposal::ReInit(reinit) => Self::ReInit(reinit),
+            Proposal::ExternalInit(external) => Self::ExternalInit(external),
+            Proposal::GroupContextExtensions(extensions) => {
+                Self::GroupContextExtensions(extensions)
+            }
+            #[cfg(feature = "custom_proposal")]
+            Proposal::Custom(custom) => Self::Custom(custom),
+        }
+    }
+}
+
+// Implements `From<&'a T> for BorrowedProposal<'a>` for each listed proposal
+// body type, mapping it to the named variant. Attributes (e.g. `#[cfg]`) given
+// before an entry are applied to the generated impl.
+macro_rules! impl_borrowed_proposal_from {
+    ($($(#[$attr:meta])* $body:ty => $variant:ident),* $(,)?) => {
+        $(
+            $(#[$attr])*
+            impl<'a> From<&'a $body> for BorrowedProposal<'a> {
+                fn from(p: &'a $body) -> Self {
+                    Self::$variant(p)
+                }
+            }
+        )*
+    };
+}
+
+impl_borrowed_proposal_from! {
+    AddProposal => Add,
+    #[cfg(feature = "by_ref_proposal")]
+    UpdateProposal => Update,
+    RemoveProposal => Remove,
+    #[cfg(feature = "psk")]
+    PreSharedKeyProposal => Psk,
+    ReInitProposal => ReInit,
+    ExternalInit => ExternalInit,
+    ExtensionList => GroupContextExtensions,
+    #[cfg(feature = "custom_proposal")]
+    CustomProposal => Custom,
+}
+
+/// A proposal listed in a commit: either carried inline (by value) or, with
+/// `by_ref_proposal`, referenced by the [`ProposalRef`] of a proposal sent
+/// earlier.
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+// NOTE(review): the explicit discriminants 1/2 are presumably the RFC 9420
+// `ProposalOrRefType` values used by the derived codec; do not renumber.
+#[repr(u8)]
+pub(crate) enum ProposalOrRef {
+ Proposal(Box<Proposal>) = 1u8,
+ #[cfg(feature = "by_ref_proposal")]
+ Reference(ProposalRef) = 2u8,
+}
+
+impl From<Proposal> for ProposalOrRef {
+    /// Wraps `proposal` for inclusion by value.
+    fn from(proposal: Proposal) -> Self {
+        ProposalOrRef::Proposal(proposal.into())
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl From<ProposalRef> for ProposalOrRef {
+    /// Wraps `reference` for inclusion by reference.
+    fn from(reference: ProposalRef) -> Self {
+        ProposalOrRef::Reference(reference)
+    }
+}
diff --git a/src/group/proposal_cache.rs b/src/group/proposal_cache.rs
new file mode 100644
index 0000000..17acf79
--- /dev/null
+++ b/src/group/proposal_cache.rs
@@ -0,0 +1,4216 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+
+use super::{
+ message_processor::ProvisionalState,
+ mls_rules::{CommitDirection, CommitSource, MlsRules},
+ GroupState, ProposalOrRef,
+};
+use crate::{
+ client::MlsError,
+ group::{
+ proposal_filter::{ProposalApplier, ProposalBundle, ProposalSource},
+ Proposal, Sender,
+ },
+ time::MlsTime,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::group::{proposal_filter::FilterStrategy, ProposalRef, ProtocolVersion};
+
+use crate::tree_kem::leaf_node::LeafNode;
+
+#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+use std::collections::HashMap;
+
+#[cfg(feature = "by_ref_proposal")]
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+use mls_rs_core::{
+ crypto::CipherSuiteProvider, error::IntoAnyError, identity::IdentityProvider,
+ psk::PreSharedKeyStorage,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use core::fmt::{self, Debug};
+
+#[cfg(feature = "by_ref_proposal")]
+#[derive(Debug, Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A by-reference proposal held in the [`ProposalCache`] together with the
+/// sender that proposed it.
+// NOTE(review): the MLS/serde derives presumably exist so the cache can be
+// persisted with the group state — confirm before changing field order.
+pub struct CachedProposal {
+ /// The cached proposal body.
+ pub(crate) proposal: Proposal,
+ /// Sender of the original proposal message.
+ pub(crate) sender: Sender,
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[derive(Clone, PartialEq)]
+/// Proposals received by reference for one group, keyed by [`ProposalRef`].
+pub(crate) struct ProposalCache {
+ protocol_version: ProtocolVersion,
+ /// Id of the group this cache belongs to.
+ group_id: Vec<u8>,
+ #[cfg(feature = "std")]
+ pub(crate) proposals: HashMap<ProposalRef, CachedProposal>,
+ // Without `std` there is no `std::collections::HashMap`, so a vector of
+ // (reference, proposal) pairs is used instead.
+ #[cfg(not(feature = "std"))]
+ pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>,
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl Debug for ProposalCache {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // The group id goes through `pretty_group_id` rather than a raw byte list.
+        let group_id = mls_rs_core::debug::pretty_group_id(&self.group_id);
+
+        let mut builder = f.debug_struct("ProposalCache");
+        builder.field("protocol_version", &self.protocol_version);
+        builder.field("group_id", &group_id);
+        builder.field("proposals", &self.proposals);
+        builder.finish()
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl ProposalCache {
+    /// Creates an empty cache for group `group_id` at `protocol_version`.
+    pub fn new(protocol_version: ProtocolVersion, group_id: Vec<u8>) -> Self {
+        Self {
+            protocol_version,
+            group_id,
+            proposals: Default::default(),
+        }
+    }
+
+    /// Rebuilds a cache from previously stored parts.
+    pub fn import(
+        protocol_version: ProtocolVersion,
+        group_id: Vec<u8>,
+        #[cfg(feature = "std")] proposals: HashMap<ProposalRef, CachedProposal>,
+        #[cfg(not(feature = "std"))] proposals: Vec<(ProposalRef, CachedProposal)>,
+    ) -> Self {
+        Self {
+            protocol_version,
+            group_id,
+            proposals,
+        }
+    }
+
+    /// Removes every cached proposal.
+    #[inline]
+    pub fn clear(&mut self) {
+        self.proposals.clear();
+    }
+
+    /// Returns `true` when no proposal is cached.
+    #[cfg(feature = "private_message")]
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.proposals.is_empty()
+    }
+
+    /// Caches `proposal`, sent by `sender`, under `proposal_ref`.
+    ///
+    /// Inserting a reference that is already cached replaces the earlier
+    /// entry, so the cache never holds the same reference twice.
+    pub fn insert(&mut self, proposal_ref: ProposalRef, proposal: Proposal, sender: Sender) {
+        let cached_proposal = CachedProposal { proposal, sender };
+
+        #[cfg(feature = "std")]
+        self.proposals.insert(proposal_ref, cached_proposal);
+
+        // Mirror `HashMap::insert` for the `Vec` store used without `std`.
+        // Otherwise a proposal received twice would be listed twice by
+        // `prepare_commit`.
+        #[cfg(not(feature = "std"))]
+        self.proposals.retain(|(existing, _)| *existing != proposal_ref);
+        #[cfg(not(feature = "std"))]
+        self.proposals.push((proposal_ref, cached_proposal));
+    }
+
+    /// Builds the bundle for a commit by `sender`.
+    ///
+    /// The bundle contains every cached proposal (by reference, attributed to
+    /// its original sender), followed by `additional_proposals` (by value,
+    /// attributed to `sender`).
+    pub fn prepare_commit(
+        &self,
+        sender: Sender,
+        additional_proposals: Vec<Proposal>,
+    ) -> ProposalBundle {
+        self.proposals
+            .iter()
+            .map(|(r, p)| {
+                (
+                    p.proposal.clone(),
+                    p.sender,
+                    ProposalSource::ByReference(r.clone()),
+                )
+            })
+            .chain(
+                additional_proposals
+                    .into_iter()
+                    .map(|p| (p, sender, ProposalSource::ByValue)),
+            )
+            .collect()
+    }
+
+    /// Resolves a commit's proposal list into a bundle.
+    ///
+    /// By-value proposals are attributed to `sender`. References are looked up
+    /// in the cache and keep the sender they were cached with.
+    ///
+    /// # Errors
+    ///
+    /// [`MlsError::ProposalNotFound`] if a referenced proposal is not cached.
+    pub fn resolve_for_commit(
+        &self,
+        sender: Sender,
+        proposal_list: Vec<ProposalOrRef>,
+    ) -> Result<ProposalBundle, MlsError> {
+        let mut proposals = ProposalBundle::default();
+
+        for p in proposal_list {
+            match p {
+                ProposalOrRef::Proposal(p) => proposals.add(*p, sender, ProposalSource::ByValue),
+                ProposalOrRef::Reference(r) => {
+                    #[cfg(feature = "std")]
+                    let p = self
+                        .proposals
+                        .get(&r)
+                        .ok_or(MlsError::ProposalNotFound)?
+                        .clone();
+                    #[cfg(not(feature = "std"))]
+                    let p = self
+                        .proposals
+                        .iter()
+                        .find_map(|(rr, p)| (rr == &r).then_some(p))
+                        .ok_or(MlsError::ProposalNotFound)?
+                        .clone();
+
+                    proposals.add(p.proposal, p.sender, ProposalSource::ByReference(r));
+                }
+            };
+        }
+
+        Ok(proposals)
+    }
+}
+
+#[cfg(not(feature = "by_ref_proposal"))]
+pub(crate) fn prepare_commit(
+    sender: Sender,
+    additional_proposals: Vec<Proposal>,
+) -> ProposalBundle {
+    // Without `by_ref_proposal` there is no cache, so the commit consists only
+    // of the committer's own proposals, all included by value.
+    additional_proposals
+        .into_iter()
+        .fold(ProposalBundle::default(), |mut bundle, proposal| {
+            bundle.add(proposal, sender, ProposalSource::ByValue);
+            bundle
+        })
+}
+
+#[cfg(not(feature = "by_ref_proposal"))]
+pub(crate) fn resolve_for_commit(
+    sender: Sender,
+    proposal_list: Vec<ProposalOrRef>,
+) -> Result<ProposalBundle, MlsError> {
+    // Without `by_ref_proposal`, `ProposalOrRef` only has the by-value
+    // variant, so the loop pattern below is irrefutable.
+    let mut bundle = ProposalBundle::default();
+
+    for ProposalOrRef::Proposal(proposal) in proposal_list {
+        bundle.add(*proposal, sender, ProposalSource::ByValue);
+    }
+
+    Ok(bundle)
+}
+
+impl GroupState {
+    /// Filters `proposals` through `user_rules`, applies the survivors to the
+    /// current group state and returns the provisional state for the next
+    /// epoch. The epoch is incremented, and the context extensions are replaced
+    /// when the applier produces new ones.
+    ///
+    /// `direction` selects the applier's filter strategy: `IgnoreByRef` when
+    /// sending, `IgnoreNone` when receiving. With `by_ref_proposal`, it also
+    /// selects the candidates for `unused_proposals`: the bundle being sent, or
+    /// every cached proposal when receiving. Candidates that were not applied
+    /// by reference are reported as unused.
+    ///
+    /// # Errors
+    ///
+    /// - [`MlsError::InvalidSender`] for `NewMemberProposal`/`External` senders.
+    /// - [`MlsError::ExternalCommitMustHaveNewLeaf`] for a `NewMemberCommit`
+    ///   without `external_leaf`.
+    /// - [`MlsError::MlsRulesError`] when `user_rules` rejects the proposals.
+    /// - Any error from resolving the committer in the roster or applying the
+    ///   proposals.
+    #[inline(never)]
+    #[allow(clippy::too_many_arguments)]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn apply_resolved<C, F, P, CSP>(
+        &self,
+        sender: Sender,
+        mut proposals: ProposalBundle,
+        external_leaf: Option<&LeafNode>,
+        identity_provider: &C,
+        cipher_suite_provider: &CSP,
+        psk_storage: &P,
+        user_rules: &F,
+        commit_time: Option<MlsTime>,
+        direction: CommitDirection,
+    ) -> Result<ProvisionalState, MlsError>
+    where
+        C: IdentityProvider,
+        F: MlsRules,
+        P: PreSharedKeyStorage,
+        CSP: CipherSuiteProvider,
+    {
+        let roster = self.public_tree.roster();
+        let group_extensions = &self.context.extensions;
+
+        let origin = match sender {
+            Sender::Member(index) => Ok::<_, MlsError>(CommitSource::ExistingMember(
+                roster.member_with_index(index)?,
+            )),
+            #[cfg(feature = "by_ref_proposal")]
+            Sender::NewMemberProposal => Err(MlsError::InvalidSender),
+            #[cfg(feature = "by_ref_proposal")]
+            Sender::External(_) => Err(MlsError::InvalidSender),
+            Sender::NewMemberCommit => Ok(CommitSource::NewMember(
+                external_leaf
+                    .map(|l| l.signing_identity.clone())
+                    .ok_or(MlsError::ExternalCommitMustHaveNewLeaf)?,
+            )),
+        }?;
+
+        // Candidates for `unused_proposals`. Only the send path needs a copy of
+        // the (still unfiltered) bundle; the receive path uses the cache and
+        // therefore no longer clones a bundle it would throw away.
+        #[cfg(feature = "by_ref_proposal")]
+        let unused_candidates: ProposalBundle = match direction {
+            CommitDirection::Send => proposals.clone(),
+            CommitDirection::Receive => self.proposals.proposals.iter().collect(),
+        };
+
+        proposals = user_rules
+            .filter_proposals(direction, origin, &roster, group_extensions, proposals)
+            .await
+            .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
+
+        let applier = ProposalApplier::new(
+            &self.public_tree,
+            self.context.protocol_version,
+            cipher_suite_provider,
+            group_extensions,
+            external_leaf,
+            identity_provider,
+            psk_storage,
+            #[cfg(feature = "by_ref_proposal")]
+            &self.context.group_id,
+        );
+
+        #[cfg(feature = "by_ref_proposal")]
+        let applier_output = match direction {
+            CommitDirection::Send => {
+                applier
+                    .apply_proposals(FilterStrategy::IgnoreByRef, &sender, proposals, commit_time)
+                    .await?
+            }
+            CommitDirection::Receive => {
+                applier
+                    .apply_proposals(FilterStrategy::IgnoreNone, &sender, proposals, commit_time)
+                    .await?
+            }
+        };
+
+        #[cfg(not(feature = "by_ref_proposal"))]
+        let applier_output = applier
+            .apply_proposals(&sender, &proposals, commit_time)
+            .await?;
+
+        #[cfg(feature = "by_ref_proposal")]
+        let unused_proposals =
+            unused_proposals(unused_candidates, &applier_output.applied_proposals);
+
+        let mut group_context = self.context.clone();
+        group_context.epoch += 1;
+
+        if let Some(ext) = applier_output.new_context_extensions {
+            group_context.extensions = ext;
+        }
+
+        #[cfg(feature = "by_ref_proposal")]
+        let proposals = applier_output.applied_proposals;
+
+        Ok(ProvisionalState {
+            public_tree: applier_output.new_tree,
+            group_context,
+            applied_proposals: proposals,
+            external_init_index: applier_output.external_init_index,
+            indexes_of_added_kpkgs: applier_output.indexes_of_added_kpkgs,
+            #[cfg(feature = "by_ref_proposal")]
+            unused_proposals,
+        })
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl Extend<(ProposalRef, CachedProposal)> for ProposalCache {
+    /// Adds every `(reference, proposal)` pair. A pair whose reference is
+    /// already cached replaces the earlier entry (last one wins), as
+    /// `HashMap::extend` does.
+    fn extend<T>(&mut self, iter: T)
+    where
+        T: IntoIterator<Item = (ProposalRef, CachedProposal)>,
+    {
+        #[cfg(feature = "std")]
+        self.proposals.extend(iter);
+
+        // The `Vec` store used without `std` would otherwise keep duplicate
+        // references. Each item costs a scan of the cache.
+        // NOTE(review): acceptable assuming proposal caches stay small.
+        #[cfg(not(feature = "std"))]
+        iter.into_iter().for_each(|(proposal_ref, cached)| {
+            self.proposals.retain(|(existing, _)| *existing != proposal_ref);
+            self.proposals.push((proposal_ref, cached));
+        });
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+fn has_ref(proposals: &ProposalBundle, reference: &ProposalRef) -> bool {
+    // True when some proposal in the bundle was included via `reference`.
+    proposals.iter_proposals().any(|info| match &info.source {
+        ProposalSource::ByReference(r) => r == reference,
+        _ => false,
+    })
+}
+
+#[cfg(feature = "by_ref_proposal")]
+fn unused_proposals(
+    all_proposals: ProposalBundle,
+    accepted_proposals: &ProposalBundle,
+) -> Vec<crate::mls_rules::ProposalInfo<Proposal>> {
+    // Report by-reference proposals that the accepted bundle does not
+    // reference; by-value proposals are never reported as unused.
+    let mut unused = Vec::new();
+
+    for info in all_proposals.into_proposals() {
+        let not_applied = match &info.source {
+            ProposalSource::ByReference(r) => !has_ref(accepted_proposals, r),
+            _ => false,
+        };
+
+        if not_applied {
+            unused.push(info);
+        }
+    }
+
+    unused
+}
+
+// TODO add tests for lite version of filtering
+#[cfg(all(feature = "by_ref_proposal", test))]
+pub(crate) mod test_utils {
+    //! Test-only helpers that drive `ProposalCache` through the commit
+    //! resolution paths against a throw-away `GroupState`.
+
+    use mls_rs_core::{
+        crypto::CipherSuiteProvider, extension::ExtensionList, identity::IdentityProvider,
+        psk::PreSharedKeyStorage,
+    };
+
+    use crate::{
+        client::test_utils::TEST_PROTOCOL_VERSION,
+        group::{
+            confirmation_tag::ConfirmationTag,
+            mls_rules::{CommitDirection, DefaultMlsRules, MlsRules},
+            proposal::{Proposal, ProposalOrRef},
+            proposal_ref::ProposalRef,
+            state::GroupState,
+            test_utils::{get_test_group_context, TEST_GROUP},
+            GroupContext, LeafIndex, LeafNode, ProvisionalState, Sender, TreeKemPublic,
+        },
+        identity::{basic::BasicIdentityProvider, test_utils::BasicWithCustomProvider},
+        psk::AlwaysFoundPskStorage,
+    };
+
+    use super::{CachedProposal, MlsError, ProposalCache};
+
+    use alloc::vec::Vec;
+
+    impl CachedProposal {
+        /// Creates a cache entry holding `proposal`, attributed to `sender`.
+        pub fn new(proposal: Proposal, sender: Sender) -> Self {
+            Self { proposal, sender }
+        }
+    }
+
+    /// Test harness that resolves a commit's proposal list the way a receiving
+    /// member would, with a configurable identity provider (`C`), MLS rules
+    /// (`F`), PSK storage (`P`) and cipher suite provider (`CSP`).
+    #[derive(Debug)]
+    pub(crate) struct CommitReceiver<'a, C, F, P, CSP> {
+        /// Public tree the commit is resolved against.
+        tree: &'a TreeKemPublic,
+        /// Sender of the commit being received.
+        sender: Sender,
+        /// Leaf index of the receiving member (stored; not read by `receive`).
+        receiver: LeafIndex,
+        /// Proposals that the commit may reference.
+        cache: ProposalCache,
+        identity_provider: C,
+        cipher_suite_provider: CSP,
+        /// Group context extensions in effect while resolving.
+        group_context_extensions: ExtensionList,
+        user_rules: F,
+        with_psk_storage: P,
+    }
+
+    impl<'a, CSP>
+        CommitReceiver<'a, BasicWithCustomProvider, DefaultMlsRules, AlwaysFoundPskStorage, CSP>
+    {
+        /// Creates a receiver for a commit from `sender` over `tree`.
+        ///
+        /// The receiver starts with a cache from `make_proposal_cache`, a
+        /// `BasicWithCustomProvider` wrapping `BasicIdentityProvider`, the rules
+        /// from `pass_through_rules`, `AlwaysFoundPskStorage` and an empty
+        /// (default) group context extension list.
+        pub fn new<S>(
+            tree: &'a TreeKemPublic,
+            sender: S,
+            receiver: LeafIndex,
+            cipher_suite_provider: CSP,
+        ) -> Self
+        where
+            S: Into<Sender>,
+        {
+            Self {
+                tree,
+                sender: sender.into(),
+                receiver,
+                cache: make_proposal_cache(),
+                identity_provider: BasicWithCustomProvider::new(BasicIdentityProvider),
+                group_context_extensions: Default::default(),
+                user_rules: pass_through_rules(),
+                with_psk_storage: AlwaysFoundPskStorage,
+                cipher_suite_provider,
+            }
+        }
+    }
+
+    // The enclosing module is only compiled with `by_ref_proposal`, so none of
+    // the builder methods below needs a feature gate of its own.
+    impl<'a, C, F, P, CSP> CommitReceiver<'a, C, F, P, CSP>
+    where
+        C: IdentityProvider,
+        F: MlsRules,
+        P: PreSharedKeyStorage,
+        CSP: CipherSuiteProvider,
+    {
+        /// Returns this receiver with its identity provider replaced by `validator`.
+        pub fn with_identity_provider<V>(self, validator: V) -> CommitReceiver<'a, V, F, P, CSP>
+        where
+            V: IdentityProvider,
+        {
+            CommitReceiver {
+                tree: self.tree,
+                sender: self.sender,
+                receiver: self.receiver,
+                cache: self.cache,
+                identity_provider: validator,
+                group_context_extensions: self.group_context_extensions,
+                user_rules: self.user_rules,
+                with_psk_storage: self.with_psk_storage,
+                cipher_suite_provider: self.cipher_suite_provider,
+            }
+        }
+
+        /// Returns this receiver with its MLS rules replaced by `f`.
+        pub fn with_user_rules<G>(self, f: G) -> CommitReceiver<'a, C, G, P, CSP>
+        where
+            G: MlsRules,
+        {
+            CommitReceiver {
+                tree: self.tree,
+                sender: self.sender,
+                receiver: self.receiver,
+                cache: self.cache,
+                identity_provider: self.identity_provider,
+                group_context_extensions: self.group_context_extensions,
+                user_rules: f,
+                with_psk_storage: self.with_psk_storage,
+                cipher_suite_provider: self.cipher_suite_provider,
+            }
+        }
+
+        /// Returns this receiver with its PSK storage replaced by `v`.
+        pub fn with_psk_storage<V>(self, v: V) -> CommitReceiver<'a, C, F, V, CSP>
+        where
+            V: PreSharedKeyStorage,
+        {
+            CommitReceiver {
+                tree: self.tree,
+                sender: self.sender,
+                receiver: self.receiver,
+                cache: self.cache,
+                identity_provider: self.identity_provider,
+                group_context_extensions: self.group_context_extensions,
+                user_rules: self.user_rules,
+                with_psk_storage: v,
+                cipher_suite_provider: self.cipher_suite_provider,
+            }
+        }
+
+        /// Returns this receiver with its group context extensions set to `extensions`.
+        pub fn with_extensions(self, extensions: ExtensionList) -> Self {
+            Self {
+                group_context_extensions: extensions,
+                ..self
+            }
+        }
+
+        /// Adds proposal `p`, referenced by `r` and sent by `proposer`, to the cache.
+        pub fn cache<S>(mut self, r: ProposalRef, p: Proposal, proposer: S) -> Self
+        where
+            S: Into<Sender>,
+        {
+            self.cache.insert(r, p, proposer.into());
+            self
+        }
+
+        /// Resolves `proposals` as a commit from this receiver's sender, without
+        /// an external leaf, and returns the resulting provisional state.
+        ///
+        /// # Errors
+        ///
+        /// Propagates any `MlsError` reported by `resolve_for_commit_default`.
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        pub async fn receive<I>(&self, proposals: I) -> Result<ProvisionalState, MlsError>
+        where
+            I: IntoIterator,
+            I::Item: Into<ProposalOrRef>,
+        {
+            self.cache
+                .resolve_for_commit_default(
+                    self.sender,
+                    proposals.into_iter().map(Into::into).collect(),
+                    None,
+                    &self.group_context_extensions,
+                    &self.identity_provider,
+                    &self.cipher_suite_provider,
+                    self.tree,
+                    &self.with_psk_storage,
+                    &self.user_rules,
+                )
+                .await
+        }
+    }
+
+    /// Returns a new proposal cache for `TEST_PROTOCOL_VERSION` and `TEST_GROUP`.
+    pub(crate) fn make_proposal_cache() -> ProposalCache {
+        ProposalCache::new(TEST_PROTOCOL_VERSION, TEST_GROUP.to_vec())
+    }
+
+    /// Returns `DefaultMlsRules::new()`, used as the rules in these tests.
+    pub fn pass_through_rules() -> DefaultMlsRules {
+        DefaultMlsRules::new()
+    }
+
+    impl ProposalCache {
+        /// Resolves `proposal_list` as a commit received from `sender`.
+        ///
+        /// Builds a throw-away `GroupState` from a test group context (epoch 123)
+        /// whose extensions are replaced by `group_extensions`, a clone of
+        /// `public_tree`, and a copy of this cache's proposals. The resolved
+        /// proposals are then applied with `CommitDirection::Receive`.
+        ///
+        /// # Errors
+        ///
+        /// Returns the `MlsError` produced by `resolve_for_commit` or `apply_resolved`.
+        // Test shim that forwards every input the resolution path needs.
+        #[allow(clippy::too_many_arguments)]
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        pub async fn resolve_for_commit_default<C, F, P, CSP>(
+            &self,
+            sender: Sender,
+            proposal_list: Vec<ProposalOrRef>,
+            external_leaf: Option<&LeafNode>,
+            group_extensions: &ExtensionList,
+            identity_provider: &C,
+            cipher_suite_provider: &CSP,
+            public_tree: &TreeKemPublic,
+            psk_storage: &P,
+            user_rules: F,
+        ) -> Result<ProvisionalState, MlsError>
+        where
+            C: IdentityProvider,
+            F: MlsRules,
+            P: PreSharedKeyStorage,
+            CSP: CipherSuiteProvider,
+        {
+            let mut context =
+                get_test_group_context(123, cipher_suite_provider.cipher_suite()).await;
+
+            context.extensions = group_extensions.clone();
+
+            let mut state = GroupState::new(
+                context,
+                public_tree.clone(),
+                Vec::new().into(),
+                ConfirmationTag::empty(cipher_suite_provider).await,
+            );
+
+            // The state's own cache must see the same proposals that the
+            // commit may reference.
+            state.proposals.proposals = self.proposals.clone();
+            let proposals = self.resolve_for_commit(sender, proposal_list)?;
+
+            state
+                .apply_resolved(
+                    sender,
+                    proposals,
+                    external_leaf,
+                    identity_provider,
+                    cipher_suite_provider,
+                    psk_storage,
+                    &user_rules,
+                    None,
+                    CommitDirection::Receive,
+                )
+                .await
+        }
+
+        /// Builds the proposal set that `sender` would commit, given
+        /// `additional_proposals` (via `prepare_commit`).
+        ///
+        /// The set is applied with `CommitDirection::Send` to a throw-away
+        /// `GroupState` made from a clone of `context` and `public_tree`.
+        ///
+        /// # Errors
+        ///
+        /// Returns the `MlsError` produced by `apply_resolved`.
+        // Test shim that forwards every input the resolution path needs.
+        #[allow(clippy::too_many_arguments)]
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        pub async fn prepare_commit_default<C, F, P, CSP>(
+            &self,
+            sender: Sender,
+            additional_proposals: Vec<Proposal>,
+            context: &GroupContext,
+            identity_provider: &C,
+            cipher_suite_provider: &CSP,
+            public_tree: &TreeKemPublic,
+            external_leaf: Option<&LeafNode>,
+            psk_storage: &P,
+            user_rules: F,
+        ) -> Result<ProvisionalState, MlsError>
+        where
+            C: IdentityProvider,
+            F: MlsRules,
+            P: PreSharedKeyStorage,
+            CSP: CipherSuiteProvider,
+        {
+            let state = GroupState::new(
+                context.clone(),
+                public_tree.clone(),
+                Vec::new().into(),
+                ConfirmationTag::empty(cipher_suite_provider).await,
+            );
+
+            let proposals = self.prepare_commit(sender, additional_proposals);
+
+            state
+                .apply_resolved(
+                    sender,
+                    proposals,
+                    external_leaf,
+                    identity_provider,
+                    cipher_suite_provider,
+                    psk_storage,
+                    &user_rules,
+                    None,
+                    CommitDirection::Send,
+                )
+                .await
+        }
+    }
+}
+
+// TODO add tests for lite version of filtering
+#[cfg(all(feature = "by_ref_proposal", test))]
+mod tests {
+ use alloc::{boxed::Box, vec, vec::Vec};
+
+ use super::test_utils::{make_proposal_cache, pass_through_rules, CommitReceiver};
+ use super::{CachedProposal, ProposalCache};
+ use crate::client::MlsError;
+ use crate::group::message_processor::ProvisionalState;
+ use crate::group::mls_rules::{CommitDirection, CommitSource, EncryptionOptions};
+ use crate::group::proposal_filter::{ProposalBundle, ProposalInfo, ProposalSource};
+ use crate::group::proposal_ref::test_utils::auth_content_from_proposal;
+ use crate::group::proposal_ref::ProposalRef;
+ use crate::group::{
+ AddProposal, AuthenticatedContent, Content, ExternalInit, Proposal, ProposalOrRef,
+ ReInitProposal, RemoveProposal, Roster, Sender, UpdateProposal,
+ };
+ use crate::key_package::test_utils::test_key_package_with_signer;
+ use crate::signer::Signable;
+ use crate::tree_kem::leaf_node::LeafNode;
+ use crate::tree_kem::node::LeafIndex;
+ use crate::tree_kem::TreeKemPublic;
+ use crate::{
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ crypto::{self, test_utils::test_cipher_suite_provider},
+ extension::test_utils::TestExtension,
+ group::{
+ message_processor::path_update_required,
+ proposal_filter::proposer_can_propose,
+ test_utils::{get_test_group_context, random_bytes, test_group, TEST_GROUP},
+ },
+ identity::basic::BasicIdentityProvider,
+ identity::test_utils::{get_test_signing_identity, BasicWithCustomProvider},
+ key_package::{test_utils::test_key_package, KeyPackageGenerator},
+ mls_rules::{CommitOptions, DefaultMlsRules},
+ psk::AlwaysFoundPskStorage,
+ tree_kem::{
+ leaf_node::{
+ test_utils::{
+ default_properties, get_basic_test_node, get_basic_test_node_capabilities,
+ get_basic_test_node_sig_key, get_test_capabilities,
+ },
+ ConfigProperties, LeafNodeSigningContext, LeafNodeSource,
+ },
+ Lifetime,
+ },
+ };
+ use crate::{KeyPackage, MlsRules};
+
+ use crate::extension::RequiredCapabilitiesExt;
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::{
+ extension::ExternalSendersExt,
+ tree_kem::leaf_node_validator::test_utils::FailureIdentityProvider,
+ };
+
+ #[cfg(feature = "psk")]
+ use crate::{
+ group::proposal::PreSharedKeyProposal,
+ psk::{
+ ExternalPskId, JustPreSharedKeyID, PreSharedKeyID, PskGroupId, PskNonce,
+ ResumptionPSKUsage, ResumptionPsk,
+ },
+ };
+
+ #[cfg(feature = "custom_proposal")]
+ use crate::group::proposal::CustomProposal;
+
+ use assert_matches::assert_matches;
+ use core::convert::Infallible;
+ use itertools::Itertools;
+ use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};
+ use mls_rs_core::extension::ExtensionList;
+ use mls_rs_core::group::{Capabilities, ProposalType};
+ use mls_rs_core::identity::IdentityProvider;
+ use mls_rs_core::protocol_version::ProtocolVersion;
+ use mls_rs_core::psk::{PreSharedKey, PreSharedKeyStorage};
+ use mls_rs_core::{
+ extension::MlsExtension,
+ identity::{Credential, CredentialType, CustomCredential},
+ };
+
+    /// Raw leaf index of the default proposer used throughout these tests.
+    const DEFAULT_PROPOSER: u32 = 1;
+
+    /// Returns the default proposer's leaf index ([`DEFAULT_PROPOSER`]).
+    fn test_sender() -> u32 {
+        DEFAULT_PROPOSER
+    }
+
+    /// Builds a one-member public tree whose only leaf, created for `name`,
+    /// lists `proposal_types` as its supported proposals.
+    ///
+    /// Returns the index of that leaf together with the public tree.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn new_tree_custom_proposals(
+        name: &str,
+        proposal_types: Vec<ProposalType>,
+    ) -> (LeafIndex, TreeKemPublic) {
+        let capabilities = Capabilities {
+            proposals: proposal_types,
+            ..get_test_capabilities()
+        };
+
+        let (leaf_node, leaf_secret, _) =
+            get_basic_test_node_capabilities(TEST_CIPHER_SUITE, name, capabilities).await;
+
+        let (public_tree, private_tree) = TreeKemPublic::derive(
+            leaf_node,
+            leaf_secret,
+            &BasicIdentityProvider,
+            &Default::default(),
+        )
+        .await
+        .unwrap();
+
+        (private_tree.self_index, public_tree)
+    }
+
+    /// Builds a one-member public tree for `name` whose leaf's capabilities list
+    /// no proposal types.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn new_tree(name: &str) -> (LeafIndex, TreeKemPublic) {
+        let no_proposal_types = Vec::new();
+        new_tree_custom_proposals(name, no_proposal_types).await
+    }
+
+    /// Adds a basic test leaf for `name` to `tree` and returns the index it was placed at.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn add_member(tree: &mut TreeKemPublic, name: &str) -> LeafIndex {
+        let new_leaf = get_basic_test_node(TEST_CIPHER_SUITE, name).await;
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let added = tree
+            .add_leaves(vec![new_leaf], &BasicIdentityProvider, &provider)
+            .await
+            .unwrap();
+
+        added[0]
+    }
+
+    /// Creates a basic leaf for `name` and runs `LeafNode::update` on it for
+    /// `leaf_index` in `TEST_GROUP`, with default properties and the leaf's own
+    /// signing key.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn update_leaf_node(name: &str, leaf_index: u32) -> LeafNode {
+        let (mut node, _, signing_key) =
+            get_basic_test_node_sig_key(TEST_CIPHER_SUITE, name).await;
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        node.update(
+            &provider,
+            TEST_GROUP,
+            leaf_index,
+            default_properties(),
+            None,
+            &signing_key,
+        )
+        .await
+        .unwrap();
+
+        node
+    }
+
+    /// Fixture built by `test_proposals`: a tree, a committer, the proposals to
+    /// cache, and the provisional state that committing them should produce.
+    struct TestProposals {
+        /// Tree the proposals apply to, before any of them take effect.
+        tree: TreeKemPublic,
+        /// Leaf index of the member that commits the cached proposals.
+        test_sender: u32,
+        /// Authenticated proposal messages to load into the proposal cache.
+        test_proposals: Vec<AuthenticatedContent>,
+        /// State expected after all of `test_proposals` are applied to `tree`.
+        expected_effects: ProvisionalState,
+    }
+
+    /// Builds the shared proposal fixture.
+    ///
+    /// The tree starts with "alice" at leaf 0 and then gains "carol" and
+    /// "charlie". Three proposals are authored by alice: an Add of a key
+    /// package for "dave", a Remove of carol, and an empty
+    /// GroupContextExtensions. The expected effects come from applying those
+    /// proposals, attributed to charlie and marked as by-reference, to a copy
+    /// of the tree.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn test_proposals(
+        protocol_version: ProtocolVersion,
+        cipher_suite: CipherSuite,
+    ) -> TestProposals {
+        let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+        let (sender_leaf, sender_leaf_secret, _) =
+            get_basic_test_node_sig_key(cipher_suite, "alice").await;
+
+        let sender = LeafIndex(0);
+
+        let (mut tree, _) = TreeKemPublic::derive(
+            sender_leaf,
+            sender_leaf_secret,
+            &BasicIdentityProvider,
+            &Default::default(),
+        )
+        .await
+        .unwrap();
+
+        let add_package = test_key_package(protocol_version, cipher_suite, "dave").await;
+
+        // NOTE(review): `add_member` always builds its leaf with TEST_CIPHER_SUITE,
+        // regardless of `cipher_suite`.
+        let remove_leaf_index = add_member(&mut tree, "carol").await;
+
+        let proposals = vec![
+            Proposal::Add(Box::new(AddProposal {
+                key_package: add_package,
+            })),
+            Proposal::Remove(RemoveProposal {
+                to_remove: remove_leaf_index,
+            }),
+            Proposal::GroupContextExtensions(ExtensionList::new()),
+        ];
+
+        let test_node = get_basic_test_node(cipher_suite, "charlie").await;
+
+        let test_sender = *tree
+            .add_leaves(
+                vec![test_node],
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+            )
+            .await
+            .unwrap()[0];
+
+        let mut expected_tree = tree.clone();
+
+        let plaintext = proposals
+            .iter()
+            .cloned()
+            .map(|p| auth_content_from_proposal(p, sender))
+            .collect_vec();
+
+        let mut bundle = ProposalBundle::default();
+
+        // Each proposal's reference is derived from its authenticated content.
+        for (proposal, content) in proposals.iter().zip(&plaintext) {
+            let pref = ProposalRef::from_content(&cipher_suite_provider, content)
+                .await
+                .unwrap();
+
+            bundle.add(
+                proposal.clone(),
+                Sender::Member(test_sender),
+                ProposalSource::ByReference(pref),
+            );
+        }
+
+        expected_tree
+            .batch_edit(
+                &mut bundle,
+                &Default::default(),
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                true,
+            )
+            .await
+            .unwrap();
+
+        let expected_effects = ProvisionalState {
+            public_tree: expected_tree,
+            group_context: get_test_group_context(1, cipher_suite).await,
+            external_init_index: None,
+            // Presumably dave reuses leaf 1, which carol's removal frees.
+            indexes_of_added_kpkgs: vec![LeafIndex(1)],
+            #[cfg(feature = "state_update")]
+            unused_proposals: vec![],
+            applied_proposals: bundle,
+        };
+
+        TestProposals {
+            test_sender,
+            test_proposals: plaintext,
+            expected_effects,
+            tree,
+        }
+    }
+
+    /// Converts authenticated messages into proposal-cache entries keyed by
+    /// their proposal reference.
+    ///
+    /// Messages whose content is not a proposal are skipped. Each entry keeps
+    /// the message's sender.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn filter_proposals(
+        cipher_suite: CipherSuite,
+        proposals: Vec<AuthenticatedContent>,
+    ) -> Vec<(ProposalRef, CachedProposal)> {
+        // The provider depends only on `cipher_suite`, so build it once rather
+        // than once per message.
+        let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+        let mut contents = Vec::with_capacity(proposals.len());
+
+        for p in proposals {
+            if let Content::Proposal(proposal) = &p.content.content {
+                let proposal_ref = ProposalRef::from_content(&cipher_suite_provider, &p)
+                    .await
+                    .unwrap();
+
+                contents.push((
+                    proposal_ref,
+                    CachedProposal::new(proposal.as_ref().clone(), p.content.sender),
+                ));
+            }
+        }
+
+        contents
+    }
+
+    /// Computes the proposal reference of `p` as if it had been sent by `sender`,
+    /// using the default test cipher suite.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn make_proposal_ref<S>(p: &Proposal, sender: S) -> ProposalRef
+    where
+        S: Into<Sender>,
+    {
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let content = auth_content_from_proposal(p.clone(), sender);
+
+        ProposalRef::from_content(&provider, &content).await.unwrap()
+    }
+
+    /// Wraps a clone of `p` in a `ProposalInfo` attributed to `sender`.
+    ///
+    /// The source is marked as by-reference, using the reference from
+    /// `make_proposal_ref`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn make_proposal_info<S>(p: &Proposal, sender: S) -> ProposalInfo<Proposal>
+    where
+        S: Into<Sender> + Clone,
+    {
+        let proposal_sender: Sender = sender.clone().into();
+        let reference = make_proposal_ref(p, sender).await;
+
+        ProposalInfo {
+            proposal: p.clone(),
+            sender: proposal_sender,
+            source: ProposalSource::ByReference(reference),
+        }
+    }
+
+    /// Returns a fresh test proposal cache loaded with the proposals carried by
+    /// `proposals`, as produced by `filter_proposals`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn test_proposal_cache_setup(proposals: Vec<AuthenticatedContent>) -> ProposalCache {
+        let entries = filter_proposals(TEST_CIPHER_SUITE, proposals).await;
+
+        let mut cache = make_proposal_cache();
+        cache.extend(entries);
+        cache
+    }
+
+    /// Asserts that `state` matches `expected_state`, ignoring the order of the
+    /// applied proposals and the group context epoch.
+    ///
+    /// # Panics
+    ///
+    /// Panics if any of the following checks fails:
+    /// - the applied proposals have a different count, contain a duplicate, or
+    ///   miss an expected proposal;
+    /// - the external-init index differs;
+    /// - the group context differs (epoch excluded);
+    /// - the added key-package indexes differ;
+    /// - the public tree differs;
+    /// - with the `state_update` feature, the unused proposals differ.
+    fn assert_matches(mut expected_state: ProvisionalState, state: ProvisionalState) {
+        let expected_proposals = expected_state.applied_proposals.into_proposals_or_refs();
+        let proposals = state.applied_proposals.into_proposals_or_refs();
+
+        assert_eq!(proposals.len(), expected_proposals.len());
+
+        // No proposal may appear twice in the result.
+        let has_duplicates = proposals
+            .iter()
+            .enumerate()
+            .any(|(i, first)| proposals[i + 1..].contains(first));
+        assert!(!has_duplicates);
+
+        // Order is not significant, so only the count and the membership of
+        // each expected proposal are checked.
+        for expected in &expected_proposals {
+            assert!(proposals.contains(expected));
+        }
+
+        assert_eq!(
+            expected_state.external_init_index,
+            state.external_init_index
+        );
+
+        // The epoch is deliberately excluded from the comparison.
+        expected_state.group_context.epoch = state.group_context.epoch;
+        assert_eq!(expected_state.group_context, state.group_context);
+
+        assert_eq!(
+            expected_state.indexes_of_added_kpkgs,
+            state.indexes_of_added_kpkgs
+        );
+
+        assert_eq!(expected_state.public_tree, state.public_tree);
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(expected_state.unused_proposals, state.unused_proposals);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_cache_commit_all_cached() {
+        // Committing with no extra proposals applies exactly the cached ones.
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let TestProposals {
+            test_sender,
+            test_proposals,
+            expected_effects,
+            tree,
+            ..
+        } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let cache = test_proposal_cache_setup(test_proposals).await;
+        let context = get_test_group_context(0, TEST_CIPHER_SUITE).await;
+
+        let provisional_state = cache
+            .prepare_commit_default(
+                Sender::Member(test_sender),
+                Vec::new(),
+                &context,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert_matches(expected_effects, provisional_state);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_cache_commit_additional() {
+        // A by-value Add supplied at commit time is applied together with the
+        // cached proposals.
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let TestProposals {
+            test_sender,
+            test_proposals,
+            mut expected_effects,
+            tree,
+            ..
+        } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let frank_key_package =
+            test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "frank").await;
+        let frank_leaf = frank_key_package.leaf_node.clone();
+
+        let frank_add = AddProposal {
+            key_package: frank_key_package,
+        };
+
+        let cache = test_proposal_cache_setup(test_proposals).await;
+        let committer = Sender::Member(test_sender);
+
+        let provisional_state = cache
+            .prepare_commit_default(
+                committer,
+                vec![Proposal::Add(Box::new(frank_add.clone()))],
+                &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        expected_effects.applied_proposals.add(
+            Proposal::Add(Box::new(frank_add)),
+            committer,
+            ProposalSource::ByValue,
+        );
+
+        expected_effects
+            .public_tree
+            .add_leaves(vec![frank_leaf], &BasicIdentityProvider, &cipher_suite_provider)
+            .await
+            .unwrap();
+
+        // Frank's key package is expected to land at leaf 3.
+        expected_effects.indexes_of_added_kpkgs.push(LeafIndex(3));
+
+        assert_matches(expected_effects, provisional_state);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_cache_update_filter() {
+        // Committing an Update proposal by value is rejected with
+        // `InvalidProposalTypeForSender`.
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let TestProposals {
+            test_proposals,
+            tree,
+            ..
+        } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let update = Proposal::Update(make_update_proposal("foo").await);
+        let cache = test_proposal_cache_setup(test_proposals).await;
+
+        let result = cache
+            .prepare_commit_default(
+                Sender::Member(test_sender()),
+                vec![update],
+                &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(result, Err(MlsError::InvalidProposalTypeForSender));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_cache_removal_override_update() {
+        // A cached Update sent by leaf 1 is expected to be superseded by the
+        // commit's Remove of leaf 1.
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let TestProposals {
+            test_sender,
+            test_proposals,
+            tree,
+            ..
+        } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let update = Proposal::Update(make_update_proposal("foo").await);
+        let update_ref = make_proposal_ref(&update, LeafIndex(1)).await;
+
+        let mut cache = test_proposal_cache_setup(test_proposals).await;
+        cache.insert(update_ref.clone(), update, Sender::Member(1));
+
+        let applied = cache
+            .prepare_commit_default(
+                Sender::Member(test_sender),
+                Vec::new(),
+                &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap()
+            .applied_proposals;
+
+        let removes_leaf_one = applied
+            .removals
+            .iter()
+            .any(|p| *p.proposal.to_remove == 1);
+        assert!(removes_leaf_one);
+
+        let committed = applied.into_proposals_or_refs();
+        assert!(!committed.contains(&ProposalOrRef::Reference(update_ref)));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_cache_filter_duplicates_insert() {
+        // Loading the same proposals into the cache twice must not change the
+        // committed result.
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let TestProposals {
+            test_sender,
+            test_proposals,
+            expected_effects,
+            tree,
+            ..
+        } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let mut cache = test_proposal_cache_setup(test_proposals.clone()).await;
+        let repeated_entries = filter_proposals(TEST_CIPHER_SUITE, test_proposals).await;
+        cache.extend(repeated_entries);
+
+        let provisional_state = cache
+            .prepare_commit_default(
+                Sender::Member(test_sender),
+                Vec::new(),
+                &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert_matches(expected_effects, provisional_state);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_cache_filter_duplicates_additional() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let TestProposals {
+            test_proposals,
+            expected_effects,
+            tree,
+            ..
+        } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let mut cache = test_proposal_cache_setup(test_proposals.clone()).await;
+
+        // Update proposals from different senders are legitimately allowed, so
+        // only the non-Update proposals (e.g. Add / Remove) are re-added as
+        // duplicates.
+        let non_updates: Vec<_> = test_proposals
+            .iter()
+            .filter(|message| {
+                matches!(
+                    &message.content.content,
+                    Content::Proposal(p) if p.proposal_type() != ProposalType::UPDATE
+                )
+            })
+            .cloned()
+            .collect();
+
+        cache.extend(filter_proposals(TEST_CIPHER_SUITE, non_updates).await);
+
+        let provisional_state = cache
+            .prepare_commit_default(
+                Sender::Member(2),
+                Vec::new(),
+                &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert_matches(expected_effects, provisional_state);
+    }
+
+    #[cfg(feature = "private_message")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_cache_is_empty() {
+        // A fresh cache is empty; inserting one proposal makes it non-empty.
+        let mut cache = make_proposal_cache();
+        assert!(cache.is_empty());
+
+        let proposer = test_sender();
+
+        let removal = Proposal::Remove(RemoveProposal {
+            to_remove: LeafIndex(proposer),
+        });
+        let removal_ref = make_proposal_ref(&removal, LeafIndex(proposer)).await;
+
+        cache.insert(removal_ref, removal, Sender::Member(proposer));
+        assert!(!cache.is_empty());
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_cache_resolve() {
+        // Resolving the exact list a committer produced must yield the same
+        // provisional state that the committer computed.
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let TestProposals {
+            test_sender,
+            test_proposals,
+            tree,
+            ..
+        } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let cache = test_proposal_cache_setup(test_proposals).await;
+        let committer = Sender::Member(test_sender);
+
+        let frank_key_package =
+            test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "frank").await;
+        let frank_add = Proposal::Add(Box::new(AddProposal {
+            key_package: frank_key_package,
+        }));
+
+        let expected_effects = cache
+            .prepare_commit_default(
+                committer,
+                vec![frank_add],
+                &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        let committed = expected_effects
+            .applied_proposals
+            .clone()
+            .into_proposals_or_refs();
+
+        let resolution = cache
+            .resolve_for_commit_default(
+                committer,
+                committed,
+                None,
+                &ExtensionList::new(),
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert_matches(expected_effects, resolution);
+    }
+
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn proposal_cache_filters_duplicate_psk_ids() {
+        // The same external PSK proposal twice in one commit is rejected.
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let (alice, tree) = new_tree("alice").await;
+        let cache = make_proposal_cache();
+
+        let nonce = PskNonce::random(&cipher_suite_provider).unwrap();
+        let psk_proposal = Proposal::Psk(make_external_psk(b"ted", nonce));
+        let duplicated = vec![psk_proposal.clone(), psk_proposal];
+
+        let result = cache
+            .prepare_commit_default(
+                Sender::Member(*alice),
+                duplicated,
+                &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(result, Err(MlsError::DuplicatePskIds));
+    }
+
+    /// Builds a basic leaf named "foo" and runs `LeafNode::commit` on it for
+    /// leaf index 0 of `TEST_GROUP`, with default properties and the leaf's own
+    /// signing key.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn test_node() -> LeafNode {
+        let (mut node, _, signing_key) =
+            get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "foo").await;
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let leaf_index = 0;
+
+        node.commit(
+            &provider,
+            TEST_GROUP,
+            leaf_index,
+            default_properties(),
+            None,
+            &signing_key,
+        )
+        .await
+        .unwrap();
+
+        node
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_commit_must_have_new_leaf() {
+        // An external commit resolved without a new leaf node is rejected.
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let cache = make_proposal_cache();
+        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let external_init = Proposal::ExternalInit(ExternalInit {
+            kem_output: vec![0; cipher_suite_provider.kdf_extract_size()],
+        });
+
+        let result = cache
+            .resolve_for_commit_default(
+                Sender::NewMemberCommit,
+                vec![ProposalOrRef::Proposal(Box::new(external_init))],
+                None,
+                &group.group.context().extensions,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &group.group.state.public_tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(result, Err(MlsError::ExternalCommitMustHaveNewLeaf));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn proposal_cache_rejects_proposals_by_ref_for_new_member() {
+        // A new member's external commit may not reference cached proposals.
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let external_init = Proposal::ExternalInit(ExternalInit {
+            kem_output: vec![0; cipher_suite_provider.kdf_extract_size()],
+        });
+        let external_init_ref = make_proposal_ref(&external_init, test_sender()).await;
+
+        let mut cache = make_proposal_cache();
+        cache.insert(
+            external_init_ref.clone(),
+            external_init,
+            Sender::Member(test_sender()),
+        );
+
+        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let new_leaf = test_node().await;
+
+        let result = cache
+            .resolve_for_commit_default(
+                Sender::NewMemberCommit,
+                vec![ProposalOrRef::Reference(external_init_ref)],
+                Some(&new_leaf),
+                &group.group.context().extensions,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &group.group.state.public_tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(result, Err(MlsError::OnlyMembersCanCommitProposalsByRef));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn proposal_cache_rejects_multiple_external_init_proposals_in_commit() {
+        // An external commit carrying two ExternalInit proposals is rejected.
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let cache = make_proposal_cache();
+        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let external_init = Proposal::ExternalInit(ExternalInit {
+            kem_output: vec![0; cipher_suite_provider.kdf_extract_size()],
+        });
+
+        let two_inits: Vec<ProposalOrRef> = vec![external_init.clone(), external_init]
+            .into_iter()
+            .map(Box::new)
+            .map(ProposalOrRef::Proposal)
+            .collect();
+
+        let new_leaf = test_node().await;
+
+        let result = cache
+            .resolve_for_commit_default(
+                Sender::NewMemberCommit,
+                two_inits,
+                Some(&new_leaf),
+                &group.group.context().extensions,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &group.group.state.public_tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(
+            result,
+            Err(MlsError::ExternalCommitMustHaveExactlyOneExternalInit)
+        );
+    }
+
+    /// Resolves an external commit from a new member, made of an ExternalInit
+    /// followed by `proposal`, against a fresh test group.
+    ///
+    /// Returns the resolution result unchanged.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn new_member_commits_proposal(proposal: Proposal) -> Result<ProvisionalState, MlsError> {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let cache = make_proposal_cache();
+        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let external_init = Proposal::ExternalInit(ExternalInit {
+            kem_output: vec![0; cipher_suite_provider.kdf_extract_size()],
+        });
+
+        let commit_list: Vec<ProposalOrRef> = vec![external_init, proposal]
+            .into_iter()
+            .map(|p| ProposalOrRef::Proposal(Box::new(p)))
+            .collect();
+
+        let new_leaf = test_node().await;
+
+        cache
+            .resolve_for_commit_default(
+                Sender::NewMemberCommit,
+                commit_list,
+                Some(&new_leaf),
+                &group.group.context().extensions,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &group.group.state.public_tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_cannot_commit_add_proposal() {
+        // An Add proposal is not allowed in an external commit.
+        let frank_key_package =
+            test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "frank").await;
+
+        let add = Proposal::Add(Box::new(AddProposal {
+            key_package: frank_key_package,
+        }));
+
+        let result = new_member_commits_proposal(add).await;
+
+        assert_matches!(
+            result,
+            Err(MlsError::InvalidProposalTypeInExternalCommit(
+                ProposalType::ADD
+            ))
+        );
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_cannot_commit_more_than_one_remove_proposal() {
+        // An external commit may carry at most one Remove proposal.
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let cache = make_proposal_cache();
+        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        // Clone the extensions before the public tree is moved out of the group.
+        let group_extensions = group.group.context().extensions.clone();
+        let mut public_tree = group.group.state.public_tree;
+
+        let mut new_leaves = Vec::new();
+        for name in ["foo", "bar"] {
+            new_leaves.push(get_basic_test_node(TEST_CIPHER_SUITE, name).await);
+        }
+
+        let added = public_tree
+            .add_leaves(new_leaves, &BasicIdentityProvider, &cipher_suite_provider)
+            .await
+            .unwrap();
+
+        let external_init = Proposal::ExternalInit(ExternalInit {
+            kem_output: vec![0; cipher_suite_provider.kdf_extract_size()],
+        });
+
+        let commit_list: Vec<ProposalOrRef> = vec![
+            external_init,
+            Proposal::Remove(RemoveProposal {
+                to_remove: added[0],
+            }),
+            Proposal::Remove(RemoveProposal {
+                to_remove: added[1],
+            }),
+        ]
+        .into_iter()
+        .map(|p| ProposalOrRef::Proposal(Box::new(p)))
+        .collect();
+
+        let new_leaf = test_node().await;
+
+        let result = cache
+            .resolve_for_commit_default(
+                Sender::NewMemberCommit,
+                commit_list,
+                Some(&new_leaf),
+                &group_extensions,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &public_tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(result, Err(MlsError::ExternalCommitWithMoreThanOneRemove));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_remove_proposal_invalid_credential() {
+        // The new member ("foo", see `test_node`) may not remove the existing
+        // member "bar".
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let cache = make_proposal_cache();
+        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        // Clone the extensions before the public tree is moved out of the group.
+        let group_extensions = group.group.context().extensions.clone();
+        let mut public_tree = group.group.state.public_tree;
+
+        let bar = get_basic_test_node(TEST_CIPHER_SUITE, "bar").await;
+
+        let bar_index = public_tree
+            .add_leaves(vec![bar], &BasicIdentityProvider, &cipher_suite_provider)
+            .await
+            .unwrap()[0];
+
+        let external_init = Proposal::ExternalInit(ExternalInit {
+            kem_output: vec![0; cipher_suite_provider.kdf_extract_size()],
+        });
+        let remove_bar = Proposal::Remove(RemoveProposal {
+            to_remove: bar_index,
+        });
+
+        let commit_list: Vec<ProposalOrRef> = vec![external_init, remove_bar]
+            .into_iter()
+            .map(|p| ProposalOrRef::Proposal(Box::new(p)))
+            .collect();
+
+        let new_leaf = test_node().await;
+
+        let result = cache
+            .resolve_for_commit_default(
+                Sender::NewMemberCommit,
+                commit_list,
+                Some(&new_leaf),
+                &group_extensions,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &public_tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(result, Err(MlsError::ExternalCommitRemovesOtherIdentity));
+    }
+
+    // An external commit that removes a leaf carrying the joiner's identity
+    // ("foo") is accepted.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_remove_proposal_valid_credential() {
+        let cache = make_proposal_cache();
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let kem_output = vec![0; cs.kdf_extract_size()];
+
+        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let extensions = group.group.context().extensions.clone();
+        let mut tree = group.group.state.public_tree;
+
+        let own_old_leaf = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
+
+        let added = tree
+            .add_leaves(vec![own_old_leaf], &BasicIdentityProvider, &cs)
+            .await
+            .unwrap();
+
+        let raw = vec![
+            Proposal::ExternalInit(ExternalInit { kem_output }),
+            Proposal::Remove(RemoveProposal {
+                to_remove: added[0],
+            }),
+        ];
+
+        let by_value: Vec<_> = raw
+            .into_iter()
+            .map(|proposal| ProposalOrRef::Proposal(Box::new(proposal)))
+            .collect();
+
+        let joiner = test_node().await;
+
+        let outcome = cache
+            .resolve_for_commit_default(
+                Sender::NewMemberCommit,
+                by_value,
+                Some(&joiner),
+                &extensions,
+                &BasicIdentityProvider,
+                &cs,
+                &tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(outcome, Ok(_));
+    }
+
+    // Update proposals are not allowed in an external commit.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_cannot_commit_update_proposal() {
+        let leaf_node = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
+        let update = Proposal::Update(UpdateProposal { leaf_node });
+
+        let outcome = new_member_commits_proposal(update).await;
+
+        assert_matches!(
+            outcome,
+            Err(MlsError::InvalidProposalTypeInExternalCommit(ProposalType::UPDATE))
+        );
+    }
+
+    // GroupContextExtensions proposals are not allowed in an external commit.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_cannot_commit_group_extensions_proposal() {
+        let gce = Proposal::GroupContextExtensions(ExtensionList::new());
+
+        let outcome = new_member_commits_proposal(gce).await;
+
+        assert_matches!(
+            outcome,
+            Err(MlsError::InvalidProposalTypeInExternalCommit(
+                ProposalType::GROUP_CONTEXT_EXTENSIONS
+            ))
+        );
+    }
+
+    // ReInit proposals are not allowed in an external commit.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_cannot_commit_reinit_proposal() {
+        let reinit = ReInitProposal {
+            group_id: b"foo".to_vec(),
+            version: TEST_PROTOCOL_VERSION,
+            cipher_suite: TEST_CIPHER_SUITE,
+            extensions: ExtensionList::new(),
+        };
+
+        let outcome = new_member_commits_proposal(Proposal::ReInit(reinit)).await;
+
+        assert_matches!(
+            outcome,
+            Err(MlsError::InvalidProposalTypeInExternalCommit(ProposalType::RE_INIT))
+        );
+    }
+
+    // An external commit carrying no proposals lacks the mandatory ExternalInit.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_commit_must_contain_an_external_init_proposal() {
+        let cache = make_proposal_cache();
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let joiner = test_node().await;
+
+        let outcome = cache
+            .resolve_for_commit_default(
+                Sender::NewMemberCommit,
+                Vec::new(),
+                Some(&joiner),
+                &group.group.context().extensions,
+                &BasicIdentityProvider,
+                &cs,
+                &group.group.state.public_tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(
+            outcome,
+            Err(MlsError::ExternalCommitMustHaveExactlyOneExternalInit)
+        );
+    }
+
+    // A commit with no proposals at all still requires a path update.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_path_update_required_empty() {
+        let cache = make_proposal_cache();
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let mut tree = TreeKemPublic::new();
+
+        for name in ["alice", "bob"] {
+            add_member(&mut tree, name).await;
+        }
+
+        let context = get_test_group_context(1, TEST_CIPHER_SUITE).await;
+
+        let effects = cache
+            .prepare_commit_default(
+                Sender::Member(test_sender()),
+                Vec::new(),
+                &context,
+                &BasicIdentityProvider,
+                &cs,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert!(path_update_required(&effects.applied_proposals));
+    }
+
+    // Committing a cached Update proposal from another member (leaf 2) requires
+    // a path update.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_path_update_required_updates() {
+        let mut cache = make_proposal_cache();
+        let pending_update = Proposal::Update(make_update_proposal("bar").await);
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let update_ref = make_proposal_ref(&pending_update, LeafIndex(2)).await;
+        cache.insert(update_ref, pending_update, Sender::Member(2));
+
+        let mut tree = TreeKemPublic::new();
+
+        for name in ["alice", "bob", "carol"] {
+            add_member(&mut tree, name).await;
+        }
+
+        let context = get_test_group_context(1, TEST_CIPHER_SUITE).await;
+
+        let effects = cache
+            .prepare_commit_default(
+                Sender::Member(test_sender()),
+                Vec::new(),
+                &context,
+                &BasicIdentityProvider,
+                &cs,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert!(path_update_required(&effects.applied_proposals));
+    }
+
+    // Committing a Remove proposal requires a path update.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_path_update_required_removes() {
+        let cache = make_proposal_cache();
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let (alice_leaf, alice_secret, _) =
+            get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "alice").await;
+
+        // alice is the first leaf of the derived tree and acts as committer.
+        let committer = 0;
+
+        let (mut tree, _) = TreeKemPublic::derive(
+            alice_leaf,
+            alice_secret,
+            &BasicIdentityProvider,
+            &Default::default(),
+        )
+        .await
+        .unwrap();
+
+        let bob_leaf = get_basic_test_node(TEST_CIPHER_SUITE, "bob").await;
+
+        let added = tree
+            .add_leaves(vec![bob_leaf], &BasicIdentityProvider, &cs)
+            .await
+            .unwrap();
+
+        let remove_bob = Proposal::Remove(RemoveProposal {
+            to_remove: added[0],
+        });
+
+        let context = get_test_group_context(1, TEST_CIPHER_SUITE).await;
+
+        let effects = cache
+            .prepare_commit_default(
+                Sender::Member(committer),
+                vec![remove_bob],
+                &context,
+                &BasicIdentityProvider,
+                &cs,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert!(path_update_required(&effects.applied_proposals));
+    }
+
+    // A commit consisting only of PSK and Add proposals does not require a path
+    // update.
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_path_update_not_required() {
+        let (alice, tree) = new_tree("alice").await;
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let cache = make_proposal_cache();
+
+        // Reuse the provider above instead of constructing a second, identical one.
+        let psk = Proposal::Psk(PreSharedKeyProposal {
+            psk: PreSharedKeyID::new(
+                JustPreSharedKeyID::External(ExternalPskId::new(vec![])),
+                &cipher_suite_provider,
+            )
+            .unwrap(),
+        });
+
+        let add = Proposal::Add(Box::new(AddProposal {
+            key_package: test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await,
+        }));
+
+        let effects = cache
+            .prepare_commit_default(
+                Sender::Member(*alice),
+                vec![psk, add],
+                &get_test_group_context(1, TEST_CIPHER_SUITE).await,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert!(!path_update_required(&effects.applied_proposals));
+    }
+
+    // A commit containing only a ReInit proposal does not require a path update.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn path_update_is_not_required_for_re_init() {
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let (alice, tree) = new_tree("alice").await;
+        let cache = make_proposal_cache();
+
+        let reinit = ReInitProposal {
+            group_id: vec![],
+            version: TEST_PROTOCOL_VERSION,
+            cipher_suite: TEST_CIPHER_SUITE,
+            extensions: Default::default(),
+        };
+
+        let context = get_test_group_context(1, TEST_CIPHER_SUITE).await;
+
+        let effects = cache
+            .prepare_commit_default(
+                Sender::Member(*alice),
+                vec![Proposal::ReInit(reinit)],
+                &context,
+                &BasicIdentityProvider,
+                &cs,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert!(!path_update_required(&effects.applied_proposals));
+    }
+
+    /// Test harness for the committer side: `send` runs the cached proposals
+    /// plus `additional_proposals` through `prepare_commit_default` on behalf
+    /// of member `sender` of `tree`.
+    #[derive(Debug)]
+    struct CommitSender<'a, C, F, P, CSP> {
+        cipher_suite_provider: CSP,
+        // Ratchet tree the commit is prepared against.
+        tree: &'a TreeKemPublic,
+        // Leaf index of the committing member.
+        sender: LeafIndex,
+        // Proposals inserted by reference (see `cache`) before the commit is prepared.
+        cache: ProposalCache,
+        // Proposals the committer passes directly to `prepare_commit_default`.
+        additional_proposals: Vec<Proposal>,
+        identity_provider: C,
+        user_rules: F,
+        psk_storage: P,
+    }
+
+    impl<'a, CSP>
+        CommitSender<'a, BasicWithCustomProvider, DefaultMlsRules, AlwaysFoundPskStorage, CSP>
+    {
+        /// Creates a sender with an empty proposal cache and no additional
+        /// proposals. It uses a basic identity provider, pass-through MLS rules
+        /// and the `AlwaysFoundPskStorage` PSK store.
+        fn new(tree: &'a TreeKemPublic, sender: LeafIndex, cipher_suite_provider: CSP) -> Self {
+            let identity_provider = BasicWithCustomProvider::new(BasicIdentityProvider::new());
+            let cache = make_proposal_cache();
+
+            Self {
+                cipher_suite_provider,
+                tree,
+                sender,
+                cache,
+                additional_proposals: Vec::new(),
+                identity_provider,
+                user_rules: pass_through_rules(),
+                psk_storage: AlwaysFoundPskStorage,
+            }
+        }
+    }
+
+    impl<'a, C, F, P, CSP> CommitSender<'a, C, F, P, CSP>
+    where
+        C: IdentityProvider,
+        F: MlsRules,
+        P: PreSharedKeyStorage,
+        CSP: CipherSuiteProvider,
+    {
+        /// Replaces the identity provider and keeps every other setting.
+        #[cfg(feature = "by_ref_proposal")]
+        fn with_identity_provider<V>(self, identity_provider: V) -> CommitSender<'a, V, F, P, CSP>
+        where
+            V: IdentityProvider,
+        {
+            let Self {
+                cipher_suite_provider,
+                tree,
+                sender,
+                cache,
+                additional_proposals,
+                identity_provider: _,
+                user_rules,
+                psk_storage,
+            } = self;
+
+            CommitSender {
+                cipher_suite_provider,
+                tree,
+                sender,
+                cache,
+                additional_proposals,
+                identity_provider,
+                user_rules,
+                psk_storage,
+            }
+        }
+
+        /// Inserts proposal `p` under reference `r`, attributed to `proposer`.
+        fn cache<S>(mut self, r: ProposalRef, p: Proposal, proposer: S) -> Self
+        where
+            S: Into<Sender>,
+        {
+            let origin: Sender = proposer.into();
+            self.cache.insert(r, p, origin);
+            self
+        }
+
+        /// Appends proposals that the committer includes by value.
+        fn with_additional<I>(mut self, proposals: I) -> Self
+        where
+            I: IntoIterator<Item = Proposal>,
+        {
+            for proposal in proposals {
+                self.additional_proposals.push(proposal);
+            }
+
+            self
+        }
+
+        /// Replaces the MLS rules and keeps every other setting.
+        fn with_user_rules<G>(self, f: G) -> CommitSender<'a, C, G, P, CSP>
+        where
+            G: MlsRules,
+        {
+            let Self {
+                cipher_suite_provider,
+                tree,
+                sender,
+                cache,
+                additional_proposals,
+                identity_provider,
+                user_rules: _,
+                psk_storage,
+            } = self;
+
+            CommitSender {
+                cipher_suite_provider,
+                tree,
+                sender,
+                cache,
+                additional_proposals,
+                identity_provider,
+                user_rules: f,
+                psk_storage,
+            }
+        }
+
+        /// Replaces the PSK storage and keeps every other setting.
+        fn with_psk_storage<V>(self, v: V) -> CommitSender<'a, C, F, V, CSP>
+        where
+            V: PreSharedKeyStorage,
+        {
+            let Self {
+                cipher_suite_provider,
+                tree,
+                sender,
+                cache,
+                additional_proposals,
+                identity_provider,
+                user_rules,
+                psk_storage: _,
+            } = self;
+
+            CommitSender {
+                cipher_suite_provider,
+                tree,
+                sender,
+                cache,
+                additional_proposals,
+                identity_provider,
+                user_rules,
+                psk_storage: v,
+            }
+        }
+
+        /// Prepares a commit. Returns the proposals (by value or by reference)
+        /// that ended up in it, together with the full provisional state.
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        async fn send(&self) -> Result<(Vec<ProposalOrRef>, ProvisionalState), MlsError> {
+            let by_value = self.additional_proposals.clone();
+            let context = get_test_group_context(1, TEST_CIPHER_SUITE).await;
+
+            let provisional = self
+                .cache
+                .prepare_commit_default(
+                    Sender::Member(*self.sender),
+                    by_value,
+                    &context,
+                    &self.identity_provider,
+                    &self.cipher_suite_provider,
+                    self.tree,
+                    None,
+                    &self.psk_storage,
+                    &self.user_rules,
+                )
+                .await?;
+
+            let committed = provisional.applied_proposals.clone().into_proposals_or_refs();
+
+            Ok((committed, provisional))
+        }
+    }
+
+    /// Returns a test key package for "mallory" whose signature bytes have been
+    /// wiped, so signature verification on it fails.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn key_package_with_invalid_signature() -> KeyPackage {
+        let mut tampered =
+            test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "mallory").await;
+
+        tampered.signature.clear();
+
+        tampered
+    }
+
+    /// Returns a correctly re-signed test key package whose leaf node advertises
+    /// `key` as its HPKE public key.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn key_package_with_public_key(key: crypto::HpkePublicKey) -> KeyPackage {
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let (mut key_package, signer) =
+            test_key_package_with_signer(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "test").await;
+
+        let leaf_context = LeafNodeSigningContext {
+            group_id: None,
+            leaf_index: None,
+        };
+
+        key_package.leaf_node.public_key = key;
+
+        // Re-sign the modified leaf node first, then the key package that embeds it.
+        key_package
+            .leaf_node
+            .sign(&cs, &signer, &leaf_context)
+            .await
+            .unwrap();
+
+        key_package.sign(&cs, &signer, &()).await.unwrap();
+
+        key_package
+    }
+
+    // Receiving a commit with an Add whose key package signature is invalid fails.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_add_with_invalid_key_package_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let bad_add = Proposal::Add(Box::new(AddProposal {
+            key_package: key_package_with_invalid_signature().await,
+        }));
+
+        let outcome = CommitReceiver::new(&tree, alice, alice, cs)
+            .receive([bad_add])
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::InvalidSignature));
+    }
+
+    // An Add with an invalid key package passed by value makes the commit fail.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_add_with_invalid_key_package_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let bad_add = Proposal::Add(Box::new(AddProposal {
+            key_package: key_package_with_invalid_signature().await,
+        }));
+
+        let outcome = CommitSender::new(&tree, alice, cs)
+            .with_additional([bad_add])
+            .send()
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::InvalidSignature));
+    }
+
+    // A cached (by-reference) Add with an invalid key package is silently left
+    // out of the commit and reported as unused.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_add_with_invalid_key_package_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let proposal = Proposal::Add(Box::new(AddProposal {
+            key_package: key_package_with_invalid_signature().await,
+        }));
+
+        let proposal_info = make_proposal_info(&proposal, alice).await;
+
+        // `proposal` is not used after being cached, so it is moved rather than cloned.
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(proposal_info.proposal_ref().unwrap().clone(), proposal, alice)
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+    }
+
+    // Adding a key package that reuses an existing member's HPKE key by value
+    // makes the commit fail.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_add_with_hpke_key_of_another_member_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let alice_hpke_key = tree.get_leaf_node(alice).unwrap().public_key.clone();
+
+        let duplicate_add = Proposal::Add(Box::new(AddProposal {
+            key_package: key_package_with_public_key(alice_hpke_key).await,
+        }));
+
+        let outcome = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+            .with_additional([duplicate_add])
+            .send()
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::DuplicateLeafData(_)));
+    }
+
+    // A cached Add that reuses an existing member's HPKE key is left out of the
+    // commit and reported as unused.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_add_with_hpke_key_of_another_member_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let proposal = Proposal::Add(Box::new(AddProposal {
+            key_package: key_package_with_public_key(
+                tree.get_leaf_node(alice).unwrap().public_key.clone(),
+            )
+            .await,
+        }));
+
+        let proposal_info = make_proposal_info(&proposal, alice).await;
+
+        // `proposal` is not used after being cached, so it is moved rather than cloned.
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(proposal_info.proposal_ref().unwrap().clone(), proposal, alice)
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+    }
+
+    // Receiving an Update whose leaf node is a basic test node (not a valid
+    // Update leaf, per the expected InvalidLeafNodeSource error) fails.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_update_with_invalid_leaf_node_fails() {
+        let (alice, mut tree) = new_tree("alice").await;
+        let bob = add_member(&mut tree, "bob").await;
+
+        let leaf_node = get_basic_test_node(TEST_CIPHER_SUITE, "alice").await;
+        let bad_update = Proposal::Update(UpdateProposal { leaf_node });
+        let bad_update_ref = make_proposal_ref(&bad_update, bob).await;
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let outcome = CommitReceiver::new(&tree, alice, bob, cs)
+            .cache(bad_update_ref.clone(), bad_update, bob)
+            .receive([bad_update_ref])
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::InvalidLeafNodeSource));
+    }
+
+    // A cached Update with an invalid leaf node is left out of the commit and
+    // reported as unused.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_update_with_invalid_leaf_node_filters_it_out() {
+        let (alice, mut tree) = new_tree("alice").await;
+        let bob = add_member(&mut tree, "bob").await;
+
+        let leaf_node = get_basic_test_node(TEST_CIPHER_SUITE, "alice").await;
+        let bad_update = Proposal::Update(UpdateProposal { leaf_node });
+        let info = make_proposal_info(&bad_update, bob).await;
+        let update_ref = info.proposal_ref().unwrap().clone();
+
+        let (committed, _state) =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(update_ref, bad_update, bob)
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(committed, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(_state.unused_proposals, vec![info]);
+    }
+
+    // Receiving a Remove for a leaf outside the tree fails. Leaf 10 is reported
+    // as node index 20.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_remove_with_invalid_index_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let bad_remove = Proposal::Remove(RemoveProposal {
+            to_remove: LeafIndex(10),
+        });
+
+        let outcome = CommitReceiver::new(&tree, alice, alice, cs)
+            .receive([bad_remove])
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::InvalidNodeIndex(20)));
+    }
+
+    // A by-value Remove for a leaf outside the tree makes the commit fail.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_remove_with_invalid_index_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let bad_remove = Proposal::Remove(RemoveProposal {
+            to_remove: LeafIndex(10),
+        });
+
+        let outcome = CommitSender::new(&tree, alice, cs)
+            .with_additional([bad_remove])
+            .send()
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::InvalidNodeIndex(20)));
+    }
+
+    // A cached Remove for a leaf outside the tree is left out of the commit and
+    // reported as unused.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_remove_with_invalid_index_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let proposal = Proposal::Remove(RemoveProposal {
+            to_remove: LeafIndex(10),
+        });
+
+        let proposal_info = make_proposal_info(&proposal, alice).await;
+
+        // `proposal` is not used after being cached, so it is moved rather than cloned.
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(proposal_info.proposal_ref().unwrap().clone(), proposal, alice)
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+    }
+
+    /// Builds a PSK proposal for the external PSK `id` using the given `nonce`.
+    #[cfg(feature = "psk")]
+    fn make_external_psk(id: &[u8], nonce: PskNonce) -> PreSharedKeyProposal {
+        let key_id = JustPreSharedKeyID::External(ExternalPskId::new(id.to_vec()));
+
+        PreSharedKeyProposal {
+            psk: PreSharedKeyID {
+                key_id,
+                psk_nonce: nonce,
+            },
+        }
+    }
+
+    /// Builds a PSK proposal for the external PSK `id` with a freshly generated
+    /// random nonce from the test cipher suite.
+    #[cfg(feature = "psk")]
+    fn new_external_psk(id: &[u8]) -> PreSharedKeyProposal {
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let nonce = PskNonce::random(&cs).unwrap();
+
+        make_external_psk(id, nonce)
+    }
+
+    // Receiving a PSK proposal whose nonce has the wrong length (3 bytes) fails.
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_psk_with_invalid_nonce_fails() {
+        let invalid_nonce = PskNonce(vec![0, 1, 2]);
+        let (alice, tree) = new_tree("alice").await;
+
+        // The nonce is used only once, so it is moved instead of cloned.
+        let res = CommitReceiver::new(
+            &tree,
+            alice,
+            alice,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        )
+        .receive([Proposal::Psk(make_external_psk(b"foo", invalid_nonce))])
+        .await;
+
+        assert_matches!(res, Err(MlsError::InvalidPskNonceLength));
+    }
+
+    // A by-value PSK proposal with a wrong-length nonce makes the commit fail.
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_psk_with_invalid_nonce_fails() {
+        let invalid_nonce = PskNonce(vec![0, 1, 2]);
+        let (alice, tree) = new_tree("alice").await;
+
+        // The nonce is used only once, so it is moved instead of cloned.
+        let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+            .with_additional([Proposal::Psk(make_external_psk(b"foo", invalid_nonce))])
+            .send()
+            .await;
+
+        assert_matches!(res, Err(MlsError::InvalidPskNonceLength));
+    }
+
+    // A cached PSK proposal with a wrong-length nonce is left out of the commit
+    // and reported as unused.
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_psk_with_invalid_nonce_filters_it_out() {
+        let invalid_nonce = PskNonce(vec![0, 1, 2]);
+        let (alice, tree) = new_tree("alice").await;
+        let proposal = Proposal::Psk(make_external_psk(b"foo", invalid_nonce));
+
+        let proposal_info = make_proposal_info(&proposal, alice).await;
+
+        // `proposal` is not used after being cached, so it is moved rather than cloned.
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(proposal_info.proposal_ref().unwrap().clone(), proposal, alice)
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+    }
+
+    /// Builds a resumption PSK proposal for `TEST_GROUP`, epoch 1, with the
+    /// given `usage` and a random nonce.
+    #[cfg(feature = "psk")]
+    fn make_resumption_psk(usage: ResumptionPSKUsage) -> PreSharedKeyProposal {
+        let resumption = ResumptionPsk {
+            usage,
+            psk_group_id: PskGroupId(TEST_GROUP.to_vec()),
+            psk_epoch: 1,
+        };
+
+        let nonce = PskNonce::random(&test_cipher_suite_provider(TEST_CIPHER_SUITE)).unwrap();
+
+        PreSharedKeyProposal {
+            psk: PreSharedKeyID {
+                key_id: JustPreSharedKeyID::Resumption(resumption),
+                psk_nonce: nonce,
+            },
+        }
+    }
+
+    /// Asserts that receiving a resumption PSK proposal with `usage` is rejected.
+    #[cfg(feature = "psk")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn receiving_resumption_psk_with_bad_usage_fails(usage: ResumptionPSKUsage) {
+        let (alice, tree) = new_tree("alice").await;
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let psk = Proposal::Psk(make_resumption_psk(usage));
+
+        let outcome = CommitReceiver::new(&tree, alice, alice, cs)
+            .receive([psk])
+            .await;
+
+        assert_matches!(
+            outcome,
+            Err(MlsError::InvalidTypeOrUsageInPreSharedKeyProposal)
+        );
+    }
+
+    /// Asserts that a by-value resumption PSK proposal with `usage` makes the
+    /// commit fail.
+    #[cfg(feature = "psk")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn sending_additional_resumption_psk_with_bad_usage_fails(usage: ResumptionPSKUsage) {
+        let (alice, tree) = new_tree("alice").await;
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let psk = Proposal::Psk(make_resumption_psk(usage));
+
+        let outcome = CommitSender::new(&tree, alice, cs)
+            .with_additional([psk])
+            .send()
+            .await;
+
+        assert_matches!(
+            outcome,
+            Err(MlsError::InvalidTypeOrUsageInPreSharedKeyProposal)
+        );
+    }
+
+    /// Asserts that a cached resumption PSK proposal with `usage` is left out of
+    /// the commit and reported as unused.
+    #[cfg(feature = "psk")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn sending_resumption_psk_with_bad_usage_filters_it_out(usage: ResumptionPSKUsage) {
+        let (alice, tree) = new_tree("alice").await;
+        let proposal = Proposal::Psk(make_resumption_psk(usage));
+        let proposal_info = make_proposal_info(&proposal, alice).await;
+
+        // `proposal` is not used after being cached, so it is moved rather than cloned.
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(proposal_info.proposal_ref().unwrap().clone(), proposal, alice)
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+    }
+
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_resumption_psk_with_reinit_usage_fails() {
+        let usage = ResumptionPSKUsage::Reinit;
+        receiving_resumption_psk_with_bad_usage_fails(usage).await;
+    }
+
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_resumption_psk_with_reinit_usage_fails() {
+        let usage = ResumptionPSKUsage::Reinit;
+        sending_additional_resumption_psk_with_bad_usage_fails(usage).await;
+    }
+
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_resumption_psk_with_reinit_usage_filters_it_out() {
+        let usage = ResumptionPSKUsage::Reinit;
+        sending_resumption_psk_with_bad_usage_filters_it_out(usage).await;
+    }
+
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_resumption_psk_with_branch_usage_fails() {
+        let usage = ResumptionPSKUsage::Branch;
+        receiving_resumption_psk_with_bad_usage_fails(usage).await;
+    }
+
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_resumption_psk_with_branch_usage_fails() {
+        let usage = ResumptionPSKUsage::Branch;
+        sending_additional_resumption_psk_with_bad_usage_fails(usage).await;
+    }
+
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_resumption_psk_with_branch_usage_filters_it_out() {
+        let usage = ResumptionPSKUsage::Branch;
+        sending_resumption_psk_with_bad_usage_filters_it_out(usage).await;
+    }
+
+    /// Builds a ReInit proposal for `TEST_GROUP` and `TEST_CIPHER_SUITE` that
+    /// targets protocol `version`, with no extensions.
+    fn make_reinit(version: ProtocolVersion) -> ReInitProposal {
+        let group_id = TEST_GROUP.to_vec();
+        let extensions = ExtensionList::new();
+
+        ReInitProposal {
+            group_id,
+            version,
+            cipher_suite: TEST_CIPHER_SUITE,
+            extensions,
+        }
+    }
+
+    // Receiving a ReInit that targets protocol version 0 (a downgrade) fails.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_reinit_downgrading_version_fails() {
+        let downgraded = ProtocolVersion::from(0);
+        let (alice, tree) = new_tree("alice").await;
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let outcome = CommitReceiver::new(&tree, alice, alice, cs)
+            .receive([Proposal::ReInit(make_reinit(downgraded))])
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::InvalidProtocolVersionInReInit));
+    }
+
+    // A by-value ReInit targeting protocol version 0 makes the commit fail.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_reinit_downgrading_version_fails() {
+        let downgraded = ProtocolVersion::from(0);
+        let (alice, tree) = new_tree("alice").await;
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let outcome = CommitSender::new(&tree, alice, cs)
+            .with_additional([Proposal::ReInit(make_reinit(downgraded))])
+            .send()
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::InvalidProtocolVersionInReInit));
+    }
+
+    // A cached ReInit targeting protocol version 0 is left out of the commit and
+    // reported as unused.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_reinit_downgrading_version_filters_it_out() {
+        let smaller_protocol_version = ProtocolVersion::from(0);
+        let (alice, tree) = new_tree("alice").await;
+        let proposal = Proposal::ReInit(make_reinit(smaller_protocol_version));
+        let proposal_info = make_proposal_info(&proposal, alice).await;
+
+        // `proposal` is not used after being cached, so it is moved rather than cloned.
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(proposal_info.proposal_ref().unwrap().clone(), proposal, alice)
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+    }
+
+    // A commit may not include, by reference, an Update proposed by the committer.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_update_for_committer_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let self_update = Proposal::Update(make_update_proposal("alice").await);
+        let self_update_ref = make_proposal_ref(&self_update, alice).await;
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let outcome = CommitReceiver::new(&tree, alice, alice, cs)
+            .cache(self_update_ref.clone(), self_update, alice)
+            .receive([self_update_ref])
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::InvalidCommitSelfUpdate));
+    }
+
+    // A committer may not pass an Update proposal by value.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_update_for_committer_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let self_update = Proposal::Update(make_update_proposal("alice").await);
+
+        let outcome = CommitSender::new(&tree, alice, cs)
+            .with_additional([self_update])
+            .send()
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::InvalidProposalTypeForSender));
+    }
+
+    // A cached Update proposed by the committer itself is left out of the commit
+    // and reported as unused.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_update_for_committer_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+        let proposal = Proposal::Update(make_update_proposal("alice").await);
+        let proposal_info = make_proposal_info(&proposal, alice).await;
+
+        // `proposal` is not used after being cached, so it is moved rather than cloned.
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(proposal_info.proposal_ref().unwrap().clone(), proposal, alice)
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+    }
+
+    // Receiving a commit in which the committer removes itself fails.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_remove_for_committer_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let self_remove = Proposal::Remove(RemoveProposal { to_remove: alice });
+
+        let outcome = CommitReceiver::new(&tree, alice, alice, cs)
+            .receive([self_remove])
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::CommitterSelfRemoval));
+    }
+
+    // A committer may not pass a Remove of itself by value.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_remove_for_committer_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let self_remove = Proposal::Remove(RemoveProposal { to_remove: alice });
+
+        let outcome = CommitSender::new(&tree, alice, cs)
+            .with_additional([self_remove])
+            .send()
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::CommitterSelfRemoval));
+    }
+
+    // A cached Remove of the committer is left out of the commit and reported as
+    // unused.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_remove_for_committer_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+        let proposal = Proposal::Remove(RemoveProposal { to_remove: alice });
+        let proposal_info = make_proposal_info(&proposal, alice).await;
+
+        // `proposal` is not used after being cached, so it is moved rather than cloned.
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(proposal_info.proposal_ref().unwrap().clone(), proposal, alice)
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+    }
+
+    // Receiving a commit that both updates and removes bob fails. The update then
+    // refers to a member that no longer exists.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_update_and_remove_for_same_leaf_fails() {
+        let (alice, mut tree) = new_tree("alice").await;
+        let bob = add_member(&mut tree, "bob").await;
+
+        let bob_update = Proposal::Update(make_update_proposal("bob").await);
+        let bob_update_ref = make_proposal_ref(&bob_update, bob).await;
+
+        let bob_remove = Proposal::Remove(RemoveProposal { to_remove: bob });
+        let bob_remove_ref = make_proposal_ref(&bob_remove, bob).await;
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let outcome = CommitReceiver::new(&tree, alice, alice, cs)
+            .cache(bob_update_ref.clone(), bob_update, bob)
+            .cache(bob_remove_ref.clone(), bob_remove, bob)
+            .receive([bob_update_ref, bob_remove_ref])
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::UpdatingNonExistingMember));
+    }
+
+    // When bob is removed, bob's own pending Update must be dropped from the
+    // commit while the Remove is kept. The Update is attributed to bob rather
+    // than to the committer alice. A cached Update from the committer is already
+    // filtered out as a self-update (see
+    // `sending_update_for_committer_filters_it_out`), so attributing it to alice
+    // would not exercise the update/remove conflict this test is about.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_update_and_remove_for_same_leaf_filters_update_out() {
+        let (alice, mut tree) = new_tree("alice").await;
+        let bob = add_member(&mut tree, "bob").await;
+
+        let update = Proposal::Update(make_update_proposal("bob").await);
+        let update_info = make_proposal_info(&update, bob).await;
+
+        let remove = Proposal::Remove(RemoveProposal { to_remove: bob });
+        let remove_ref = make_proposal_ref(&remove, alice).await;
+
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(update_info.proposal_ref().unwrap().clone(), update, bob)
+                .cache(remove_ref.clone(), remove, alice)
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, vec![remove_ref.into()]);
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![update_info]);
+    }
+
+    /// Builds a boxed Add proposal around a freshly generated test key package.
+    ///
+    /// Every call uses the same client name ("frank"), which is what the duplicate-add
+    /// tests below rely on.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn make_add_proposal() -> Box<AddProposal> {
+        let key_package =
+            test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "frank").await;
+
+        Box::new(AddProposal { key_package })
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_add_proposals_for_same_client_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        // Two independent Adds for the same client identity.
+        let first_add = Proposal::Add(make_add_proposal().await);
+        let second_add = Proposal::Add(make_add_proposal().await);
+
+        let result = CommitReceiver::new(&tree, alice, alice, cs)
+            .receive([first_add, second_add])
+            .await;
+
+        assert_matches!(result, Err(MlsError::DuplicateLeafData(1)));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_add_proposals_for_same_client_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let duplicate_adds = [
+            Proposal::Add(make_add_proposal().await),
+            Proposal::Add(make_add_proposal().await),
+        ];
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let sender = CommitSender::new(&tree, alice, cs).with_additional(duplicate_adds);
+
+        let result = sender.send().await;
+
+        assert_matches!(result, Err(MlsError::DuplicateLeafData(1)));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_add_proposals_for_same_client_keeps_only_one() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let add_one = Proposal::Add(make_add_proposal().await);
+        let add_two = Proposal::Add(make_add_proposal().await);
+
+        let candidate_refs = [
+            make_proposal_ref(&add_one, alice).await,
+            make_proposal_ref(&add_two, alice).await,
+        ];
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let sender = CommitSender::new(&tree, alice, cs)
+            .cache(candidate_refs[0].clone(), add_one, alice)
+            .cache(candidate_refs[1].clone(), add_two, alice);
+
+        let outcome = sender.send().await.unwrap();
+
+        // Exactly one of the two by-reference Adds may make it into the commit.
+        let committed_add_ref = if let [ProposalOrRef::Reference(r)] = &*outcome.0 {
+            r
+        } else {
+            panic!("committed proposals list does not contain exactly one reference")
+        };
+
+        assert!(candidate_refs.contains(committed_add_ref));
+
+        #[cfg(feature = "state_update")]
+        assert_matches!(
+            &*outcome.1.unused_proposals,
+            [rejected_add_info] if committed_add_ref != rejected_add_info.proposal_ref().unwrap()
+                && candidate_refs.contains(rejected_add_info.proposal_ref().unwrap())
+        );
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_update_for_different_identity_fails() {
+        let (alice, mut tree) = new_tree("alice").await;
+        let bob = add_member(&mut tree, "bob").await;
+
+        // Bob sends an Update whose leaf was built for the name "carol".
+        let foreign_update = Proposal::Update(make_update_proposal_custom("carol", 1).await);
+        let foreign_update_ref = make_proposal_ref(&foreign_update, bob).await;
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let receiver = CommitReceiver::new(&tree, alice, alice, cs).cache(
+            foreign_update_ref.clone(),
+            foreign_update,
+            bob,
+        );
+
+        let result = receiver.receive([foreign_update_ref]).await;
+
+        assert_matches!(result, Err(MlsError::InvalidSuccessor));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_update_for_different_identity_filters_it_out() {
+        let (alice, mut tree) = new_tree("alice").await;
+        let bob = add_member(&mut tree, "bob").await;
+
+        // Bob proposes an Update whose new leaf was built for "carol" rather than "bob".
+        let update = Proposal::Update(make_update_proposal("carol").await);
+        let update_info = make_proposal_info(&update, bob).await;
+
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(update_info.proposal_ref().unwrap().clone(), update, bob)
+                .send()
+                .await
+                .unwrap();
+
+        // Alice leaves the invalid Update out of her commit instead of failing.
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        // The dropped proposal is still reported as unused even though Bob, not the
+        // committer Alice, proposed it: unused proposals are not limited to the committer's own.
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![update_info]);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_add_for_same_client_as_existing_member_fails() {
+        let (alice, initial_tree) = new_tree("alice").await;
+        let add = Proposal::Add(make_add_proposal().await);
+
+        // First commit: the Add is accepted and yields a tree containing the new member.
+        let first_commit = CommitReceiver::new(
+            &initial_tree,
+            alice,
+            alice,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        )
+        .receive([add.clone()])
+        .await
+        .unwrap();
+
+        let grown_tree = first_commit.public_tree;
+
+        // Second commit: the same Add now duplicates an existing member's leaf data.
+        let result = CommitReceiver::new(
+            &grown_tree,
+            alice,
+            alice,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        )
+        .receive([add])
+        .await;
+
+        assert_matches!(result, Err(MlsError::DuplicateLeafData(1)));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_add_for_same_client_as_existing_member_fails() {
+        let (alice, initial_tree) = new_tree("alice").await;
+        let add = Proposal::Add(make_add_proposal().await);
+
+        // Apply the Add once so the added client is already part of the tree.
+        let grown_tree = CommitReceiver::new(
+            &initial_tree,
+            alice,
+            alice,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        )
+        .receive([add.clone()])
+        .await
+        .unwrap()
+        .public_tree;
+
+        // Passing the same Add directly to a new commit must now be rejected.
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let sender = CommitSender::new(&grown_tree, alice, cs).with_additional([add]);
+
+        let result = sender.send().await;
+
+        assert_matches!(result, Err(MlsError::DuplicateLeafData(1)));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_add_for_same_client_as_existing_member_filters_it_out() {
+        let (alice, initial_tree) = new_tree("alice").await;
+        let add = Proposal::Add(make_add_proposal().await);
+
+        // Apply the Add once so the added client is already part of the tree.
+        let grown_tree = CommitReceiver::new(
+            &initial_tree,
+            alice,
+            alice,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        )
+        .receive([add.clone()])
+        .await
+        .unwrap()
+        .public_tree;
+
+        let add_info = make_proposal_info(&add, alice).await;
+        let add_ref = add_info.proposal_ref().unwrap().clone();
+
+        // Committing the same Add by reference silently drops it instead of failing.
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let outcome = CommitSender::new(&grown_tree, alice, cs)
+            .cache(add_ref, add, alice)
+            .send()
+            .await
+            .unwrap();
+
+        assert_eq!(outcome.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(outcome.1.unused_proposals, vec![add_info]);
+    }
+
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_psk_proposals_with_same_psk_id_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let psk = Proposal::Psk(new_external_psk(b"foo"));
+        let duplicated = [psk.clone(), psk];
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let result = CommitReceiver::new(&tree, alice, alice, cs)
+            .receive(duplicated)
+            .await;
+
+        assert_matches!(result, Err(MlsError::DuplicatePskIds));
+    }
+
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_psk_proposals_with_same_psk_id_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let psk = Proposal::Psk(new_external_psk(b"foo"));
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let sender = CommitSender::new(&tree, alice, cs).with_additional([psk.clone(), psk]);
+
+        let result = sender.send().await;
+
+        assert_matches!(result, Err(MlsError::DuplicatePskIds));
+    }
+
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_psk_proposals_with_same_psk_id_keeps_only_one() {
+        let (alice, mut tree) = new_tree("alice").await;
+        let bob = add_member(&mut tree, "bob").await;
+
+        // Alice and Bob each propose the very same external PSK.
+        let psk = Proposal::Psk(new_external_psk(b"foo"));
+
+        let alice_info = make_proposal_info(&psk, alice).await;
+        let bob_info = make_proposal_info(&psk, bob).await;
+
+        let alice_ref = alice_info.proposal_ref().unwrap().clone();
+        let bob_ref = bob_info.proposal_ref().unwrap().clone();
+        let proposal_info = [alice_info, bob_info];
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let processed_proposals = CommitSender::new(&tree, alice, cs)
+            .cache(alice_ref, psk.clone(), alice)
+            .cache(bob_ref, psk, bob)
+            .send()
+            .await
+            .unwrap();
+
+        let applied = processed_proposals
+            .1
+            .applied_proposals
+            .clone()
+            .into_proposals()
+            .collect_vec();
+
+        let committed_info = if let [only] = applied.as_slice() {
+            only.clone()
+        } else {
+            panic!("Expected single proposal reference in {processed_proposals:?}")
+        };
+
+        assert!(proposal_info.contains(&committed_info));
+
+        #[cfg(feature = "state_update")]
+        {
+            let unused = &processed_proposals.1.unused_proposals;
+
+            if let [rejected] = &**unused {
+                assert_ne!(*rejected, committed_info);
+                assert!(proposal_info.contains(rejected));
+            } else {
+                panic!("Expected one proposal reference in {:?}", unused);
+            }
+        }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_multiple_group_context_extensions_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        // Each call yields a fresh GroupContextExtensions proposal with an empty list.
+        let empty_gce = || Proposal::GroupContextExtensions(ExtensionList::new());
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let result = CommitReceiver::new(&tree, alice, alice, cs)
+            .receive([empty_gce(), empty_gce()])
+            .await;
+
+        assert_matches!(
+            result,
+            Err(MlsError::MoreThanOneGroupContextExtensionsProposal)
+        );
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_multiple_additional_group_context_extensions_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        // Each call yields a fresh GroupContextExtensions proposal with an empty list.
+        let empty_gce = || Proposal::GroupContextExtensions(ExtensionList::new());
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let sender = CommitSender::new(&tree, alice, cs).with_additional([empty_gce(), empty_gce()]);
+
+        let result = sender.send().await;
+
+        assert_matches!(
+            result,
+            Err(MlsError::MoreThanOneGroupContextExtensionsProposal)
+        );
+    }
+
+    /// Wraps a single `TestExtension { foo }` into an `ExtensionList`.
+    fn make_extension_list(foo: u8) -> ExtensionList {
+        let extension = TestExtension { foo }.into_extension().unwrap();
+        let extensions = vec![extension];
+        extensions.into()
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_multiple_group_context_extensions_keeps_only_one() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        // Single-member tree whose leaf advertises extension type 42 in its capabilities,
+        // so each GroupContextExtensions proposal below is acceptable on its own.
+        let (alice, tree) = {
+            let (signing_identity, signature_key) =
+                get_test_signing_identity(TEST_CIPHER_SUITE, b"alice").await;
+
+            let properties = ConfigProperties {
+                capabilities: Capabilities {
+                    extensions: vec![42.into()],
+                    ..Capabilities::default()
+                },
+                extensions: Default::default(),
+            };
+
+            let (leaf, secret) = LeafNode::generate(
+                &cipher_suite_provider,
+                properties,
+                signing_identity,
+                &signature_key,
+                Lifetime::years(1).unwrap(),
+            )
+            .await
+            .unwrap();
+
+            let (pub_tree, priv_tree) =
+                TreeKemPublic::derive(leaf, secret, &BasicIdentityProvider, &Default::default())
+                    .await
+                    .unwrap();
+
+            (priv_tree.self_index, pub_tree)
+        };
+
+        let proposals = [
+            Proposal::GroupContextExtensions(make_extension_list(0)),
+            Proposal::GroupContextExtensions(make_extension_list(1)),
+        ];
+
+        let gce_info = [
+            make_proposal_info(&proposals[0], alice).await,
+            make_proposal_info(&proposals[1], alice).await,
+        ];
+
+        // Reuse the provider built above rather than constructing a second identical one.
+        let processed_proposals = CommitSender::new(&tree, alice, cipher_suite_provider)
+            .cache(
+                gce_info[0].proposal_ref().unwrap().clone(),
+                proposals[0].clone(),
+                alice,
+            )
+            .cache(
+                gce_info[1].proposal_ref().unwrap().clone(),
+                proposals[1].clone(),
+                alice,
+            )
+            .send()
+            .await
+            .unwrap();
+
+        // Only one of the two GroupContextExtensions proposals may be applied.
+        let committed_gce_info = match processed_proposals
+            .1
+            .applied_proposals
+            .clone()
+            .into_proposals()
+            .collect_vec()
+            .as_slice()
+        {
+            [gce_info] => gce_info.clone(),
+            _ => panic!("committed proposals list does not contain exactly one reference"),
+        };
+
+        assert!(gce_info.contains(&committed_gce_info));
+
+        #[cfg(feature = "state_update")]
+        assert_matches!(
+            &*processed_proposals.1.unused_proposals,
+            [rejected_gce_info] if committed_gce_info != *rejected_gce_info && gce_info.contains(rejected_gce_info)
+        );
+    }
+
+    /// Builds an extension list holding one `ExternalSendersExt` that names the test
+    /// signing identity generated for "alice".
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn make_external_senders_extension() -> ExtensionList {
+        let (identity, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"alice").await;
+
+        let extension = ExternalSendersExt::new(vec![identity])
+            .into_extension()
+            .unwrap();
+
+        vec![extension].into()
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_invalid_external_senders_extension_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let gce = Proposal::GroupContextExtensions(make_external_senders_extension().await);
+
+        // The failing identity provider makes validation of the external sender error out.
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let receiver = CommitReceiver::new(&tree, alice, alice, cs)
+            .with_identity_provider(FailureIdentityProvider::new());
+
+        let result = receiver.receive([gce]).await;
+
+        assert_matches!(result, Err(MlsError::IdentityProviderError(_)));
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_invalid_external_senders_extension_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let gce = Proposal::GroupContextExtensions(make_external_senders_extension().await);
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let sender = CommitSender::new(&tree, alice, cs)
+            .with_identity_provider(FailureIdentityProvider::new())
+            .with_additional([gce]);
+
+        let result = sender.send().await;
+
+        assert_matches!(result, Err(MlsError::IdentityProviderError(_)));
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_invalid_external_senders_extension_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let gce = Proposal::GroupContextExtensions(make_external_senders_extension().await);
+        let gce_info = make_proposal_info(&gce, alice).await;
+        let gce_ref = gce_info.proposal_ref().unwrap().clone();
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let outcome = CommitSender::new(&tree, alice, cs)
+            .with_identity_provider(FailureIdentityProvider::new())
+            .cache(gce_ref, gce, alice)
+            .send()
+            .await
+            .unwrap();
+
+        assert_eq!(outcome.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(outcome.1.unused_proposals, vec![gce_info]);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_reinit_with_other_proposals_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let reinit = Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION));
+        let add = Proposal::Add(make_add_proposal().await);
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let result = CommitReceiver::new(&tree, alice, alice, cs)
+            .receive([reinit, add])
+            .await;
+
+        assert_matches!(result, Err(MlsError::OtherProposalWithReInit));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_reinit_with_other_proposals_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let reinit = Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION));
+        let add = Proposal::Add(make_add_proposal().await);
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let sender = CommitSender::new(&tree, alice, cs).with_additional([reinit, add]);
+
+        let result = sender.send().await;
+
+        assert_matches!(result, Err(MlsError::OtherProposalWithReInit));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_reinit_with_other_proposals_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let reinit = Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION));
+        let reinit_info = make_proposal_info(&reinit, alice).await;
+
+        let add = Proposal::Add(make_add_proposal().await);
+        let add_ref = make_proposal_ref(&add, alice).await;
+
+        let reinit_ref = reinit_info.proposal_ref().unwrap().clone();
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let sender = CommitSender::new(&tree, alice, cs)
+            .cache(reinit_ref, reinit, alice)
+            .cache(add_ref.clone(), add, alice);
+
+        let outcome = sender.send().await.unwrap();
+
+        // The Add is kept; the ReInit, which cannot be combined with other proposals, is dropped.
+        assert_eq!(outcome.0, vec![add_ref.into()]);
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(outcome.1.unused_proposals, vec![reinit_info]);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_multiple_reinits_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let reinit = || Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION));
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let result = CommitReceiver::new(&tree, alice, alice, cs)
+            .receive([reinit(), reinit()])
+            .await;
+
+        assert_matches!(result, Err(MlsError::OtherProposalWithReInit));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_multiple_reinits_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let reinit = || Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION));
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let sender = CommitSender::new(&tree, alice, cs).with_additional([reinit(), reinit()]);
+
+        let result = sender.send().await;
+
+        assert_matches!(result, Err(MlsError::OtherProposalWithReInit));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_multiple_reinits_keeps_only_one() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let reinit = Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION));
+        let reinit_ref = make_proposal_ref(&reinit, alice).await;
+
+        // A second ReInit that differs from the first only in its group id.
+        let other_reinit = Proposal::ReInit(ReInitProposal {
+            group_id: b"other_group".to_vec(),
+            ..make_reinit(TEST_PROTOCOL_VERSION)
+        });
+        let other_reinit_ref = make_proposal_ref(&other_reinit, alice).await;
+
+        let candidate_refs = [reinit_ref.clone(), other_reinit_ref.clone()];
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let outcome = CommitSender::new(&tree, alice, cs)
+            .cache(reinit_ref, reinit.clone(), alice)
+            .cache(other_reinit_ref, other_reinit.clone(), alice)
+            .send()
+            .await
+            .unwrap();
+
+        let processed_ref = match outcome.0.as_slice() {
+            [ProposalOrRef::Reference(r)] => r,
+            p => panic!("Expected single proposal reference but found {p:?}"),
+        };
+
+        assert!(candidate_refs.contains(processed_ref));
+
+        #[cfg(feature = "state_update")]
+        {
+            let (rejected_ref, unused_proposal) = match &*outcome.1.unused_proposals {
+                [r] => (r.proposal_ref().unwrap().clone(), r.proposal.clone()),
+                p => panic!("Expected single proposal but found {p:?}"),
+            };
+
+            assert_ne!(rejected_ref, *processed_ref);
+            assert!(candidate_refs.contains(&rejected_ref));
+            assert!(unused_proposal == reinit || unused_proposal == other_reinit);
+        }
+    }
+
+    /// Builds an `ExternalInit` whose `kem_output` is filler bytes (`33`) sized by the test
+    /// cipher suite's `kdf_extract_size()`.
+    fn make_external_init() -> ExternalInit {
+        let output_len = test_cipher_suite_provider(TEST_CIPHER_SUITE).kdf_extract_size();
+        let kem_output = vec![33; output_len];
+
+        ExternalInit { kem_output }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_external_init_from_member_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let external_init = Proposal::ExternalInit(make_external_init());
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let receiver = CommitReceiver::new(&tree, alice, alice, cs);
+        let result = receiver.receive([external_init]).await;
+
+        assert_matches!(result, Err(MlsError::InvalidProposalTypeForSender));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_external_init_from_member_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let external_init = Proposal::ExternalInit(make_external_init());
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let sender = CommitSender::new(&tree, alice, cs).with_additional([external_init]);
+        let result = sender.send().await;
+
+        assert_matches!(result, Err(MlsError::InvalidProposalTypeForSender));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_external_init_from_member_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let external_init = Proposal::ExternalInit(make_external_init());
+        let external_init_info = make_proposal_info(&external_init, alice).await;
+        let external_init_ref = external_init_info.proposal_ref().unwrap().clone();
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let outcome = CommitSender::new(&tree, alice, cs)
+            .cache(external_init_ref, external_init, alice)
+            .send()
+            .await
+            .unwrap();
+
+        assert_eq!(outcome.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(outcome.1.unused_proposals, vec![external_init_info]);
+    }
+
+    /// Builds a GroupContextExtensions proposal whose only extension is a
+    /// `RequiredCapabilitiesExt` requiring the given extension type.
+    fn required_capabilities_proposal(extension: u16) -> Proposal {
+        let requirement = RequiredCapabilitiesExt {
+            extensions: vec![extension.into()],
+            ..Default::default()
+        }
+        .into_extension()
+        .unwrap();
+
+        let extension_list: ExtensionList = vec![requirement].into();
+
+        Proposal::GroupContextExtensions(extension_list)
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_required_capabilities_not_supported_by_member_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        // Extension type that Alice's leaf is expected not to support.
+        let unsupported: u16 = 33;
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let result = CommitReceiver::new(&tree, alice, alice, cs)
+            .receive([required_capabilities_proposal(unsupported)])
+            .await;
+
+        assert_matches!(
+            result,
+            Err(MlsError::RequiredExtensionNotFound(v)) if v == unsupported.into()
+        );
+    }
+
+    /// Passing the unsupported required-capabilities proposal directly via `with_additional`
+    /// makes the commit fail (named `sending_additional_*_fails` like its sibling tests).
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_required_capabilities_not_supported_by_member_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+            .with_additional([required_capabilities_proposal(33)])
+            .send()
+            .await;
+
+        assert_matches!(
+            res,
+            Err(MlsError::RequiredExtensionNotFound(v)) if v == 33.into()
+        );
+    }
+
+    /// Committing the unsupported required-capabilities proposal by reference (via `cache`)
+    /// drops it from the commit instead of failing (named `sending_*_filters_it_out` like
+    /// its sibling tests, since no additional proposals are involved).
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_required_capabilities_not_supported_by_member_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let proposal = required_capabilities_proposal(33);
+        let proposal_info = make_proposal_info(&proposal, alice).await;
+
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(
+                    proposal_info.proposal_ref().unwrap().clone(),
+                    proposal.clone(),
+                    alice,
+                )
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn committing_update_from_pk1_to_pk2_and_update_from_pk2_to_pk3_works() {
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let (alice_leaf, alice_secret, alice_signer) =
+            get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "alice").await;
+
+        let (mut tree, alice_private) = TreeKemPublic::derive(
+            alice_leaf.clone(),
+            alice_secret,
+            &BasicIdentityProvider,
+            &Default::default(),
+        )
+        .await
+        .unwrap();
+
+        let alice = alice_private.self_index;
+        let bob = add_member(&mut tree, "bob").await;
+        let carol = add_member(&mut tree, "carol").await;
+
+        // Alice moves onto the public key currently held by Bob's leaf (pk1 -> pk2) ...
+        let bob_old_key = tree.get_leaf_node(bob).unwrap().public_key.clone();
+
+        let mut alice_new_leaf = LeafNode {
+            public_key: bob_old_key,
+            leaf_node_source: LeafNodeSource::Update,
+            ..alice_leaf
+        };
+
+        alice_new_leaf
+            .sign(&cs, &alice_signer, &(TEST_GROUP, 0).into())
+            .await
+            .unwrap();
+
+        // ... while Bob moves onto a newly generated leaf (pk2 -> pk3).
+        let bob_new_leaf = update_leaf_node("bob", 1).await;
+
+        let alice_update = Proposal::Update(UpdateProposal {
+            leaf_node: alice_new_leaf.clone(),
+        });
+        let alice_update_ref = make_proposal_ref(&alice_update, alice).await;
+
+        let bob_update = Proposal::Update(UpdateProposal {
+            leaf_node: bob_new_leaf.clone(),
+        });
+        let bob_update_ref = make_proposal_ref(&bob_update, bob).await;
+
+        let effects = CommitReceiver::new(&tree, carol, carol, cs)
+            .cache(alice_update_ref.clone(), alice_update, alice)
+            .cache(bob_update_ref.clone(), bob_update, bob)
+            .receive([alice_update_ref, bob_update_ref])
+            .await
+            .unwrap();
+
+        assert_eq!(effects.applied_proposals.update_senders, vec![alice, bob]);
+
+        let applied_leaves: Vec<_> = effects
+            .applied_proposals
+            .updates
+            .into_iter()
+            .map(|info| info.proposal.leaf_node)
+            .collect();
+
+        assert_eq!(applied_leaves, vec![alice_new_leaf, bob_new_leaf]);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn committing_update_from_pk1_to_pk2_and_removal_of_pk2_works() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let (alice_leaf, alice_secret, alice_signer) =
+            get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "alice").await;
+
+        let (mut tree, priv_tree) = TreeKemPublic::derive(
+            alice_leaf.clone(),
+            alice_secret,
+            &BasicIdentityProvider,
+            &Default::default(),
+        )
+        .await
+        .unwrap();
+
+        let alice = priv_tree.self_index;
+
+        let bob = add_member(&mut tree, "bob").await;
+        let carol = add_member(&mut tree, "carol").await;
+
+        let bob_current_leaf = tree.get_leaf_node(bob).unwrap();
+
+        // Alice's update takes over the public key currently held by Bob's leaf (pk1 -> pk2).
+        let mut alice_new_leaf = LeafNode {
+            public_key: bob_current_leaf.public_key.clone(),
+            leaf_node_source: LeafNodeSource::Update,
+            ..alice_leaf
+        };
+
+        alice_new_leaf
+            .sign(
+                &cipher_suite_provider,
+                &alice_signer,
+                &(TEST_GROUP, 0).into(),
+            )
+            .await
+            .unwrap();
+
+        let pk1_to_pk2 = Proposal::Update(UpdateProposal {
+            leaf_node: alice_new_leaf.clone(),
+        });
+
+        let pk1_to_pk2_ref = make_proposal_ref(&pk1_to_pk2, alice).await;
+
+        // In the same commit Bob, the current holder of pk2, is removed.
+        let remove_pk2 = Proposal::Remove(RemoveProposal { to_remove: bob });
+
+        let remove_pk2_ref = make_proposal_ref(&remove_pk2, bob).await;
+
+        // Reuse the provider built above rather than constructing a second identical one.
+        let effects = CommitReceiver::new(&tree, carol, carol, cipher_suite_provider)
+            .cache(pk1_to_pk2_ref.clone(), pk1_to_pk2, alice)
+            .cache(remove_pk2_ref.clone(), remove_pk2, bob)
+            .receive([pk1_to_pk2_ref, remove_pk2_ref])
+            .await
+            .unwrap();
+
+        assert_eq!(effects.applied_proposals.update_senders, vec![alice]);
+
+        assert_eq!(
+            effects
+                .applied_proposals
+                .updates
+                .into_iter()
+                .map(|p| p.proposal.leaf_node)
+                .collect_vec(),
+            vec![alice_new_leaf]
+        );
+
+        assert_eq!(
+            effects
+                .applied_proposals
+                .removals
+                .into_iter()
+                .map(|p| p.proposal.to_remove)
+                .collect_vec(),
+            vec![bob]
+        );
+    }
+
+    /// Generates a key package for `name` whose signing identity carries a custom
+    /// credential (`BasicWithCustomProvider::CUSTOM_CREDENTIAL_TYPE`) and whose capabilities
+    /// list only credential type 42.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn unsupported_credential_key_package(name: &str) -> KeyPackage {
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let identity_provider = BasicWithCustomProvider::new(BasicIdentityProvider::new());
+
+        let (mut signing_identity, secret_key) =
+            get_test_signing_identity(TEST_CIPHER_SUITE, name.as_bytes()).await;
+
+        // Replace the generated credential with a random custom one.
+        let custom_type = CredentialType::new(BasicWithCustomProvider::CUSTOM_CREDENTIAL_TYPE);
+        signing_identity.credential =
+            Credential::Custom(CustomCredential::new(custom_type, random_bytes(32)));
+
+        let capabilities = Capabilities {
+            credentials: vec![42.into()],
+            ..Default::default()
+        };
+
+        let generator = KeyPackageGenerator {
+            protocol_version: TEST_PROTOCOL_VERSION,
+            cipher_suite_provider: &cs,
+            signing_identity: &signing_identity,
+            signing_key: &secret_key,
+            identity_provider: &identity_provider,
+        };
+
+        let generation = generator
+            .generate(
+                Lifetime::years(1).unwrap(),
+                capabilities,
+                Default::default(),
+                Default::default(),
+            )
+            .await
+            .unwrap();
+
+        generation.key_package
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_add_with_leaf_not_supporting_credential_type_of_other_leaf_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let key_package = unsupported_credential_key_package("bob").await;
+        let add = Proposal::Add(Box::new(AddProposal { key_package }));
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let result = CommitReceiver::new(&tree, alice, alice, cs)
+            .receive([add])
+            .await;
+
+        assert_matches!(result, Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_add_with_leaf_not_supporting_credential_type_of_other_leaf_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let key_package = unsupported_credential_key_package("bob").await;
+        let add = Proposal::Add(Box::new(AddProposal { key_package }));
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let sender = CommitSender::new(&tree, alice, cs).with_additional([add]);
+        let result = sender.send().await;
+
+        assert_matches!(result, Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_add_with_leaf_not_supporting_credential_type_of_other_leaf_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let key_package = unsupported_credential_key_package("bob").await;
+        let add = Proposal::Add(Box::new(AddProposal { key_package }));
+
+        let add_info = make_proposal_info(&add, alice).await;
+        let add_ref = add_info.proposal_ref().unwrap().clone();
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let outcome = CommitSender::new(&tree, alice, cs)
+            .cache(add_ref, add, alice)
+            .send()
+            .await
+            .unwrap();
+
+        assert_eq!(outcome.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(outcome.1.unused_proposals, vec![add_info]);
+    }
+
+    /// Passing a custom proposal that the member does not support directly via
+    /// `with_additional` makes the commit fail (named `sending_additional_*_fails` like
+    /// its sibling tests).
+    #[cfg(feature = "custom_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_custom_proposal_with_member_not_supporting_proposal_type_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let custom_proposal = Proposal::Custom(CustomProposal::new(ProposalType::new(42), vec![]));
+
+        let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+            .with_additional([custom_proposal.clone()])
+            .send()
+            .await;
+
+        assert_matches!(
+            res,
+            Err(
+                MlsError::UnsupportedCustomProposal(c)
+            ) if c == custom_proposal.proposal_type()
+        );
+    }
+
+    #[cfg(feature = "custom_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_custom_proposal_with_member_not_supporting_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let custom = Proposal::Custom(CustomProposal::new(ProposalType::new(42), vec![]));
+        let custom_info = make_proposal_info(&custom, alice).await;
+        let custom_ref = custom_info.proposal_ref().unwrap().clone();
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let outcome = CommitSender::new(&tree, alice, cs)
+            .cache(custom_ref, custom, alice)
+            .send()
+            .await
+            .unwrap();
+
+        assert_eq!(outcome.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(outcome.1.unused_proposals, vec![custom_info]);
+    }
+
+    #[cfg(feature = "custom_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_custom_proposal_with_member_not_supporting_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let custom = Proposal::Custom(CustomProposal::new(ProposalType::new(42), vec![]));
+        let expected_type = custom.proposal_type();
+
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let result = CommitReceiver::new(&tree, alice, alice, cs)
+            .receive([custom])
+            .await;
+
+        assert_matches!(
+            result,
+            Err(MlsError::UnsupportedCustomProposal(c)) if c == expected_type
+        );
+    }
+
+    // Receiving a group-context-extensions proposal built by `make_extension_list(0)` fails
+    // with `UnsupportedGroupExtension`; the test expects the offending type to be 42.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_group_extension_unsupported_by_leaf_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let extensions = Proposal::GroupContextExtensions(make_extension_list(0));
+
+        let receiver = CommitReceiver::new(&tree, alice, alice, provider);
+        let result = receiver.receive([extensions]).await;
+
+        assert_matches!(
+            result,
+            Err(MlsError::UnsupportedGroupExtension(ext_type)) if ext_type == 42.into()
+        );
+    }
+
+    // Sending the same unsupported group-context-extensions proposal as an additional
+    // (by-value) proposal fails with `UnsupportedGroupExtension` for type 42.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_group_extension_unsupported_by_leaf_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let extensions = Proposal::GroupContextExtensions(make_extension_list(0));
+
+        let sender = CommitSender::new(&tree, alice, provider).with_additional([extensions]);
+        let result = sender.send().await;
+
+        assert_matches!(
+            result,
+            Err(MlsError::UnsupportedGroupExtension(ext_type)) if ext_type == 42.into()
+        );
+    }
+
+    // A cached (by-reference) unsupported group-context-extensions proposal is filtered out
+    // when sending: nothing is committed and, with `state_update`, it is reported as unused.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_group_extension_unsupported_by_leaf_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let extensions = Proposal::GroupContextExtensions(make_extension_list(0));
+        let extensions_info = make_proposal_info(&extensions, alice).await;
+        let cached_ref = extensions_info.proposal_ref().cloned().unwrap();
+
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let outcome = CommitSender::new(&tree, alice, provider)
+            .cache(cached_ref, extensions, alice)
+            .send()
+            .await
+            .unwrap();
+
+        assert_eq!(outcome.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(outcome.1.unused_proposals, vec![extensions_info]);
+    }
+
+ // Test PSK storage that never knows any external PSK: every lookup returns `Ok(None)`.
+ // Used to exercise the "missing external PSK" paths on send and receive.
+ #[cfg(feature = "psk")]
+ #[derive(Debug)]
+ struct AlwaysNotFoundPskStorage;
+
+ #[cfg(feature = "psk")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl PreSharedKeyStorage for AlwaysNotFoundPskStorage {
+ type Error = Infallible;
+
+ // Ignores the requested id and reports it as not found.
+ async fn get(&self, _: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error> {
+ Ok(None)
+ }
+ }
+
+    // Receiving a commit with an external PSK proposal whose id the storage cannot resolve
+    // fails with `MissingRequiredPsk`.
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_external_psk_with_unknown_id_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let receiver = CommitReceiver::new(&tree, alice, alice, provider)
+            .with_psk_storage(AlwaysNotFoundPskStorage);
+
+        let psk_proposal = Proposal::Psk(new_external_psk(b"abc"));
+        let result = receiver.receive([psk_proposal]).await;
+
+        assert_matches!(result, Err(MlsError::MissingRequiredPsk));
+    }
+
+    // Sending an additional (by-value) external PSK proposal whose id cannot be resolved
+    // fails with `MissingRequiredPsk`.
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_external_psk_with_unknown_id_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let sender = CommitSender::new(&tree, alice, provider)
+            .with_psk_storage(AlwaysNotFoundPskStorage);
+
+        let psk_proposal = Proposal::Psk(new_external_psk(b"abc"));
+        let result = sender.with_additional([psk_proposal]).send().await;
+
+        assert_matches!(result, Err(MlsError::MissingRequiredPsk));
+    }
+
+    // A cached (by-reference) external PSK proposal with an unresolvable id is filtered out
+    // when sending: nothing is committed and, with `state_update`, it is reported as unused.
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_external_psk_with_unknown_id_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let psk_proposal = Proposal::Psk(new_external_psk(b"abc"));
+        let psk_info = make_proposal_info(&psk_proposal, alice).await;
+        let cached_ref = psk_info.proposal_ref().cloned().unwrap();
+
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let outcome = CommitSender::new(&tree, alice, provider)
+            .with_psk_storage(AlwaysNotFoundPskStorage)
+            .cache(cached_ref, psk_proposal, alice)
+            .send()
+            .await
+            .unwrap();
+
+        assert_eq!(outcome.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(outcome.1.unused_proposals, vec![psk_info]);
+    }
+
+    // User-supplied `MlsRules` may drop proposals: a rules object that clears every
+    // group-context-extensions proposal leaves the commit empty.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn user_defined_filter_can_remove_proposals() {
+        // Rules that discard all group-context-extensions proposals and otherwise use defaults.
+        struct DropGroupContextExtensions;
+
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+        impl MlsRules for DropGroupContextExtensions {
+            type Error = Infallible;
+
+            async fn filter_proposals(
+                &self,
+                _: CommitDirection,
+                _: CommitSource,
+                _: &Roster,
+                _: &ExtensionList,
+                mut proposals: ProposalBundle,
+            ) -> Result<ProposalBundle, Self::Error> {
+                proposals.group_context_extensions.clear();
+                Ok(proposals)
+            }
+
+            #[cfg_attr(coverage_nightly, coverage(off))]
+            fn commit_options(
+                &self,
+                _: &Roster,
+                _: &ExtensionList,
+                _: &ProposalBundle,
+            ) -> Result<CommitOptions, Self::Error> {
+                Ok(Default::default())
+            }
+
+            #[cfg_attr(coverage_nightly, coverage(off))]
+            fn encryption_options(
+                &self,
+                _: &Roster,
+                _: &ExtensionList,
+            ) -> Result<EncryptionOptions, Self::Error> {
+                Ok(Default::default())
+            }
+        }
+
+        let (alice, tree) = new_tree("alice").await;
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let extensions = Proposal::GroupContextExtensions(Default::default());
+
+        let (committed, _) = CommitSender::new(&tree, alice, provider)
+            .with_additional([extensions])
+            .with_user_rules(DropGroupContextExtensions)
+            .send()
+            .await
+            .unwrap();
+
+        assert_eq!(committed, Vec::new());
+    }
+
+ // Test `MlsRules` whose `filter_proposals` always fails, so any commit it is consulted for
+ // (sending or receiving) is rejected. `commit_options` / `encryption_options` return defaults.
+ struct FailureMlsRules;
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl MlsRules for FailureMlsRules {
+ type Error = MlsError;
+
+ async fn filter_proposals(
+ &self,
+ _: CommitDirection,
+ _: CommitSource,
+ _: &Roster,
+ _: &ExtensionList,
+ _: ProposalBundle,
+ ) -> Result<ProposalBundle, Self::Error> {
+ // The concrete variant does not matter: the tests using these rules only check for
+ // `MlsError::MlsRulesError(_)`.
+ Err(MlsError::InvalidSignature)
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn commit_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ _: &ProposalBundle,
+ ) -> Result<CommitOptions, Self::Error> {
+ Ok(Default::default())
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn encryption_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ ) -> Result<EncryptionOptions, Self::Error> {
+ Ok(Default::default())
+ }
+ }
+
+ // Test `MlsRules` that appends one extra proposal to every bundle it filters.
+ struct InjectMlsRules {
+ // Proposal cloned into each filtered bundle.
+ to_inject: Proposal,
+ // Source tag given to the injected proposal (e.g. `ByValue` or `Local`).
+ source: ProposalSource,
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl MlsRules for InjectMlsRules {
+ type Error = MlsError;
+
+ async fn filter_proposals(
+ &self,
+ _: CommitDirection,
+ _: CommitSource,
+ _: &Roster,
+ _: &ExtensionList,
+ mut proposals: ProposalBundle,
+ ) -> Result<ProposalBundle, Self::Error> {
+ // The injected proposal is always attributed to member 0.
+ proposals.add(
+ self.to_inject.clone(),
+ Sender::Member(0),
+ self.source.clone(),
+ );
+ Ok(proposals)
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn commit_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ _: &ProposalBundle,
+ ) -> Result<CommitOptions, Self::Error> {
+ Ok(Default::default())
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn encryption_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ ) -> Result<EncryptionOptions, Self::Error> {
+ Ok(Default::default())
+ }
+ }
+
+    // A proposal injected by user rules with `ProposalSource::ByValue` ends up in the commit
+    // as an inline proposal.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn user_defined_filter_can_inject_proposals() {
+        let (alice, tree) = new_tree("alice").await;
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let injected = Proposal::GroupContextExtensions(Default::default());
+        let rules = InjectMlsRules {
+            to_inject: injected.clone(),
+            source: ProposalSource::ByValue,
+        };
+
+        let (committed, _) = CommitSender::new(&tree, alice, provider)
+            .with_user_rules(rules)
+            .send()
+            .await
+            .unwrap();
+
+        assert_eq!(committed, vec![ProposalOrRef::Proposal(injected.into())]);
+    }
+
+    // A proposal injected with `ProposalSource::Local` is accepted but never transmitted,
+    // so the resulting commit carries no proposals.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn user_defined_filter_can_inject_local_only_proposals() {
+        let (alice, tree) = new_tree("alice").await;
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let rules = InjectMlsRules {
+            to_inject: Proposal::GroupContextExtensions(Default::default()),
+            source: ProposalSource::Local,
+        };
+
+        let (committed, _) = CommitSender::new(&tree, alice, provider)
+            .with_user_rules(rules)
+            .send()
+            .await
+            .unwrap();
+
+        assert_eq!(committed, vec![]);
+    }
+
+    // User rules cannot bypass the built-in checks: an update injected by value from a member
+    // is still rejected with `InvalidProposalTypeForSender`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn user_defined_filter_cant_break_base_rules() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let leaf_node = get_basic_test_node(TEST_CIPHER_SUITE, "leaf").await;
+        let rules = InjectMlsRules {
+            to_inject: Proposal::Update(UpdateProposal { leaf_node }),
+            source: ProposalSource::ByValue,
+        };
+
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let result = CommitSender::new(&tree, alice, provider)
+            .with_user_rules(rules)
+            .send()
+            .await;
+
+        assert_matches!(result, Err(MlsError::InvalidProposalTypeForSender { .. }));
+    }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn user_defined_filter_can_refuse_to_send_commit() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::GroupContextExtensions(Default::default())])
+ .with_user_rules(FailureMlsRules)
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::MlsRulesError(_)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn user_defined_filter_can_reject_incoming_commit() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .with_user_rules(FailureMlsRules)
+ .receive([Proposal::GroupContextExtensions(Default::default())])
+ .await;
+
+ assert_matches!(res, Err(MlsError::MlsRulesError(_)));
+ }
+
+ // Checks proposer validation on receipt. Every (proposer, proposal) pair is received in a
+ // commit sent by alice, and the outcome must agree with `proposer_can_propose`:
+ // - pairs it rejects must fail with `InvalidProposalTypeForSender`;
+ // - pairs it accepts must succeed, except by-reference updates from a member, whose result
+ //   is not checked here.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn proposers_are_verified() {
+ let (alice, mut tree) = new_tree("alice").await;
+ let bob = add_member(&mut tree, "bob").await;
+
+ // Carol's identity is registered below as the group's only external sender; presumably
+ // this is what lets `Sender::External(0)` resolve.
+ #[cfg(feature = "by_ref_proposal")]
+ let identity = get_test_signing_identity(TEST_CIPHER_SUITE, b"carol")
+ .await
+ .0;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let external_senders = ExternalSendersExt::new(vec![identity]);
+
+ // One proposal of each standard type (PSK only with the `psk` feature).
+ let proposals: &[Proposal] = &[
+ Proposal::Add(make_add_proposal().await),
+ Proposal::Update(make_update_proposal("alice").await),
+ Proposal::Remove(RemoveProposal { to_remove: bob }),
+ #[cfg(feature = "psk")]
+ Proposal::Psk(make_external_psk(
+ b"ted",
+ PskNonce::random(&test_cipher_suite_provider(TEST_CIPHER_SUITE)).unwrap(),
+ )),
+ Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION)),
+ Proposal::ExternalInit(make_external_init()),
+ Proposal::GroupContextExtensions(Default::default()),
+ ];
+
+ let proposers = [
+ Sender::Member(*alice),
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::External(0),
+ Sender::NewMemberCommit,
+ Sender::NewMemberProposal,
+ ];
+
+ // NOTE(review): `by_ref` only ever takes the value `true`, so every proposal is cached
+ // and received by reference and the by-value `else` branch below is never executed.
+ // Confirm whether `[false, true]` was intended.
+ for ((proposer, proposal), by_ref) in proposers
+ .into_iter()
+ .cartesian_product(proposals)
+ .cartesian_product([true])
+ {
+ let committer = Sender::Member(*alice);
+
+ let receiver = CommitReceiver::new(
+ &tree,
+ committer,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ );
+
+ #[cfg(feature = "by_ref_proposal")]
+ let extensions: ExtensionList =
+ vec![external_senders.clone().into_extension().unwrap()].into();
+
+ #[cfg(feature = "by_ref_proposal")]
+ let receiver = receiver.with_extensions(extensions);
+
+ // By reference: cache the proposal under `proposer` and commit only its reference.
+ // By value: the committer itself is the effective proposer.
+ let (receiver, proposals, proposer) = if by_ref {
+ let proposal_ref = make_proposal_ref(proposal, proposer).await;
+ let receiver = receiver.cache(proposal_ref.clone(), proposal.clone(), proposer);
+ (receiver, vec![ProposalOrRef::from(proposal_ref)], proposer)
+ } else {
+ (receiver, vec![proposal.clone().into()], committer)
+ };
+
+ let res = receiver.receive(proposals).await;
+
+ if proposer_can_propose(proposer, proposal.proposal_type(), by_ref).is_err() {
+ assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender));
+ } else {
+ // By-reference updates from a member are skipped: their result is not asserted.
+ let is_self_update = proposal.proposal_type() == ProposalType::UPDATE
+ && by_ref
+ && matches!(proposer, Sender::Member(_));
+
+ if !is_self_update {
+ res.unwrap();
+ }
+ }
+ }
+ }
+
+    // Update proposal for `name`; shorthand for `make_update_proposal_custom(name, 1)`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn make_update_proposal(name: &str) -> UpdateProposal {
+        make_update_proposal_custom(name, 1).await
+    }
+
+    // Update proposal whose leaf node comes from `update_leaf_node(name, leaf_index)`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn make_update_proposal_custom(name: &str, leaf_index: u32) -> UpdateProposal {
+        let leaf_node = update_leaf_node(name, leaf_index).await;
+        UpdateProposal { leaf_node }
+    }
+
+    // A cached proposal that the received commit does not include is reported, by reference,
+    // as the single unused proposal.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn when_receiving_commit_unused_proposals_are_proposals_in_cache_but_not_in_commit() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let cached = Proposal::GroupContextExtensions(Default::default());
+        let cached_ref = make_proposal_ref(&cached, alice).await;
+
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let receiver =
+            CommitReceiver::new(&tree, alice, alice, provider).cache(cached_ref.clone(), cached, alice);
+
+        let bob_key_package =
+            test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+        let committed = Proposal::Add(Box::new(AddProposal {
+            key_package: bob_key_package,
+        }));
+
+        let state = receiver.receive([committed]).await.unwrap();
+
+        match state.unused_proposals.as_slice() {
+            [unused] => assert_eq!(unused.proposal_ref(), Some(&cached_ref)),
+            other => panic!("Expected single unused proposal but got {:?}", other),
+        }
+    }
+}
diff --git a/src/group/proposal_filter.rs b/src/group/proposal_filter.rs
new file mode 100644
index 0000000..5ef6b20
--- /dev/null
+++ b/src/group/proposal_filter.rs
@@ -0,0 +1,23 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+mod bundle;
+mod filtering_common;
+
+#[cfg(feature = "by_ref_proposal")]
+mod filtering;
+#[cfg(not(feature = "by_ref_proposal"))]
+pub mod filtering_lite;
+#[cfg(all(feature = "custom_proposal", not(feature = "by_ref_proposal")))]
+use filtering_lite as filtering;
+
+pub use bundle::{ProposalBundle, ProposalInfo, ProposalSource};
+
+#[cfg(feature = "by_ref_proposal")]
+pub(crate) use filtering::FilterStrategy;
+
+pub(crate) use filtering_common::ProposalApplier;
+
+#[cfg(all(feature = "by_ref_proposal", test))]
+pub(crate) use filtering::proposer_can_propose;
diff --git a/src/group/proposal_filter/bundle.rs b/src/group/proposal_filter/bundle.rs
new file mode 100644
index 0000000..f18a75b
--- /dev/null
+++ b/src/group/proposal_filter/bundle.rs
@@ -0,0 +1,633 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::boxed::Box;
+use alloc::vec::Vec;
+
+#[cfg(feature = "custom_proposal")]
+use itertools::Itertools;
+
+use crate::{
+ group::{
+ AddProposal, BorrowedProposal, Proposal, ProposalOrRef, ProposalType, ReInitProposal,
+ RemoveProposal, Sender,
+ },
+ ExtensionList,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::group::{proposal_cache::CachedProposal, LeafIndex, ProposalRef, UpdateProposal};
+
+#[cfg(feature = "psk")]
+use crate::group::PreSharedKeyProposal;
+
+#[cfg(feature = "custom_proposal")]
+use crate::group::proposal::CustomProposal;
+
+use crate::group::ExternalInit;
+
+use core::iter::empty;
+
+#[derive(Clone, Debug, Default)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A collection of proposals, stored as one list per proposal type.
+///
+/// Each entry is a [`ProposalInfo`] recording the proposal, its sender and its source.
+pub struct ProposalBundle {
+ /// Add proposals.
+ pub(crate) additions: Vec<ProposalInfo<AddProposal>>,
+ /// Update proposals.
+ #[cfg(feature = "by_ref_proposal")]
+ pub(crate) updates: Vec<ProposalInfo<UpdateProposal>>,
+ /// Leaf indices returned by `update_proposal_senders`. Not populated by
+ /// `ProposalBundle::add`; presumably filled by the filtering code.
+ #[cfg(feature = "by_ref_proposal")]
+ pub(crate) update_senders: Vec<LeafIndex>,
+ /// Remove proposals.
+ pub(crate) removals: Vec<ProposalInfo<RemoveProposal>>,
+ /// Pre-shared key proposals.
+ #[cfg(feature = "psk")]
+ pub(crate) psks: Vec<ProposalInfo<PreSharedKeyProposal>>,
+ /// ReInit proposals.
+ pub(crate) reinitializations: Vec<ProposalInfo<ReInitProposal>>,
+ /// External init proposals.
+ pub(crate) external_initializations: Vec<ProposalInfo<ExternalInit>>,
+ /// Group context extensions proposals; each carries a proposed extension list.
+ pub(crate) group_context_extensions: Vec<ProposalInfo<ExtensionList>>,
+ /// Custom (non-standard) proposals.
+ #[cfg(feature = "custom_proposal")]
+ pub(crate) custom_proposals: Vec<ProposalInfo<CustomProposal>>,
+}
+
+impl ProposalBundle {
+    /// Append `proposal` to the list for its proposal type, tagged with `sender` and `source`.
+    ///
+    /// Only the per-type proposal lists are touched; in particular `update_senders` is not
+    /// modified when an update proposal is added.
+    pub fn add(&mut self, proposal: Proposal, sender: Sender, source: ProposalSource) {
+        match proposal {
+            Proposal::Add(proposal) => self.additions.push(ProposalInfo {
+                proposal: *proposal,
+                sender,
+                source,
+            }),
+            #[cfg(feature = "by_ref_proposal")]
+            Proposal::Update(proposal) => self.updates.push(ProposalInfo {
+                proposal,
+                sender,
+                source,
+            }),
+            Proposal::Remove(proposal) => self.removals.push(ProposalInfo {
+                proposal,
+                sender,
+                source,
+            }),
+            #[cfg(feature = "psk")]
+            Proposal::Psk(proposal) => self.psks.push(ProposalInfo {
+                proposal,
+                sender,
+                source,
+            }),
+            Proposal::ReInit(proposal) => self.reinitializations.push(ProposalInfo {
+                proposal,
+                sender,
+                source,
+            }),
+            Proposal::ExternalInit(proposal) => self.external_initializations.push(ProposalInfo {
+                proposal,
+                sender,
+                source,
+            }),
+            Proposal::GroupContextExtensions(proposal) => {
+                self.group_context_extensions.push(ProposalInfo {
+                    proposal,
+                    sender,
+                    source,
+                })
+            }
+            #[cfg(feature = "custom_proposal")]
+            Proposal::Custom(proposal) => self.custom_proposals.push(ProposalInfo {
+                proposal,
+                sender,
+                source,
+            }),
+        }
+    }
+
+    /// Remove the proposal of type `T` at `index`
+    ///
+    /// Type `T` can be any of the standard MLS proposal types defined in the
+    /// [`proposal`](crate::group::proposal) module.
+    ///
+    /// `index` is consistent with the index returned by any of the proposal
+    /// type specific functions in this module. The `Proposable` implementations in this
+    /// module ignore an out-of-range `index` instead of panicking.
+    pub fn remove<T: Proposable>(&mut self, index: usize) {
+        T::remove(self, index);
+    }
+
+    /// Iterate over proposals, filtered by type.
+    ///
+    /// Type `T` can be any of the standard MLS proposal types defined in the
+    /// [`proposal`](crate::group::proposal) module.
+    pub fn by_type<'a, T: Proposable + 'a>(&'a self) -> impl Iterator<Item = &'a ProposalInfo<T>> {
+        T::filter(self).iter()
+    }
+
+    /// Retain proposals, filtered by type.
+    ///
+    /// Type `T` can be any of the standard MLS proposal types defined in the
+    /// [`proposal`](crate::group::proposal) module.
+    ///
+    /// # Errors
+    ///
+    /// If `f` returns `Err` for a proposal, that proposal is removed. The remaining
+    /// proposals of type `T` are still passed to `f` and kept or removed according to its
+    /// result. The first error encountered is then returned.
+    pub fn retain_by_type<T, F, E>(&mut self, mut f: F) -> Result<(), E>
+    where
+        T: Proposable,
+        F: FnMut(&ProposalInfo<T>) -> Result<bool, E>,
+    {
+        let mut res = Ok(());
+
+        T::retain(self, |p| match f(p) {
+            Ok(keep) => keep,
+            Err(e) => {
+                // Keep only the first error; every failing proposal is dropped.
+                if res.is_ok() {
+                    res = Err(e);
+                }
+                false
+            }
+        });
+
+        res
+    }
+
+    /// Retain custom proposals in the bundle.
+    ///
+    /// # Errors
+    ///
+    /// Same semantics as `retain_by_type`: failing proposals are removed, every proposal is
+    /// still visited, and the first error is returned.
+    #[cfg(feature = "custom_proposal")]
+    pub fn retain_custom<F, E>(&mut self, mut f: F) -> Result<(), E>
+    where
+        F: FnMut(&ProposalInfo<CustomProposal>) -> Result<bool, E>,
+    {
+        let mut res = Ok(());
+
+        self.custom_proposals.retain(|p| match f(p) {
+            Ok(keep) => keep,
+            Err(e) => {
+                if res.is_ok() {
+                    res = Err(e);
+                }
+                false
+            }
+        });
+
+        res
+    }
+
+    /// Retain MLS standard proposals in the bundle.
+    ///
+    /// Proposal types are visited in this order: add, update (with `by_ref_proposal`),
+    /// remove, PSK (with `psk`), reinit, external init, group context extensions.
+    /// Custom proposals are not visited; use `retain_custom` for those.
+    ///
+    /// # Errors
+    ///
+    /// Returns the first error produced by `f`. Processing stops after the proposal type in
+    /// which the error occurred, so proposal types later in the order above are left
+    /// untouched.
+    pub fn retain<F, E>(&mut self, mut f: F) -> Result<(), E>
+    where
+        F: FnMut(&ProposalInfo<BorrowedProposal<'_>>) -> Result<bool, E>,
+    {
+        self.retain_by_type::<AddProposal, _, _>(|proposal| {
+            f(&proposal.as_ref().map(BorrowedProposal::from))
+        })?;
+
+        #[cfg(feature = "by_ref_proposal")]
+        self.retain_by_type::<UpdateProposal, _, _>(|proposal| {
+            f(&proposal.as_ref().map(BorrowedProposal::from))
+        })?;
+
+        self.retain_by_type::<RemoveProposal, _, _>(|proposal| {
+            f(&proposal.as_ref().map(BorrowedProposal::from))
+        })?;
+
+        #[cfg(feature = "psk")]
+        self.retain_by_type::<PreSharedKeyProposal, _, _>(|proposal| {
+            f(&proposal.as_ref().map(BorrowedProposal::from))
+        })?;
+
+        self.retain_by_type::<ReInitProposal, _, _>(|proposal| {
+            f(&proposal.as_ref().map(BorrowedProposal::from))
+        })?;
+
+        self.retain_by_type::<ExternalInit, _, _>(|proposal| {
+            f(&proposal.as_ref().map(BorrowedProposal::from))
+        })?;
+
+        self.retain_by_type::<ExtensionList, _, _>(|proposal| {
+            f(&proposal.as_ref().map(BorrowedProposal::from))
+        })?;
+
+        Ok(())
+    }
+
+    /// The number of proposals in the bundle, across all proposal types (including custom
+    /// proposals when that feature is enabled).
+    pub fn length(&self) -> usize {
+        let len = 0;
+
+        #[cfg(feature = "psk")]
+        let len = len + self.psks.len();
+
+        let len = len + self.external_initializations.len();
+
+        #[cfg(feature = "custom_proposal")]
+        let len = len + self.custom_proposals.len();
+
+        #[cfg(feature = "by_ref_proposal")]
+        let len = len + self.updates.len();
+
+        len + self.additions.len()
+            + self.removals.len()
+            + self.reinitializations.len()
+            + self.group_context_extensions.len()
+    }
+
+    /// Iterate over all proposals inside the bundle.
+    ///
+    /// Proposals are yielded grouped by type in this order: add, remove, reinit, update, PSK,
+    /// external init, group context extensions, custom. This differs from the order used by
+    /// [`Self::into_proposals`].
+    pub fn iter_proposals(&self) -> impl Iterator<Item = ProposalInfo<BorrowedProposal<'_>>> {
+        let res = self
+            .additions
+            .iter()
+            .map(|p| p.as_ref().map(BorrowedProposal::Add))
+            .chain(
+                self.removals
+                    .iter()
+                    .map(|p| p.as_ref().map(BorrowedProposal::Remove)),
+            )
+            .chain(
+                self.reinitializations
+                    .iter()
+                    .map(|p| p.as_ref().map(BorrowedProposal::ReInit)),
+            );
+
+        #[cfg(feature = "by_ref_proposal")]
+        let res = res.chain(
+            self.updates
+                .iter()
+                .map(|p| p.as_ref().map(BorrowedProposal::Update)),
+        );
+
+        #[cfg(feature = "psk")]
+        let res = res.chain(
+            self.psks
+                .iter()
+                .map(|p| p.as_ref().map(BorrowedProposal::Psk)),
+        );
+
+        let res = res.chain(
+            self.external_initializations
+                .iter()
+                .map(|p| p.as_ref().map(BorrowedProposal::ExternalInit)),
+        );
+
+        let res = res.chain(
+            self.group_context_extensions
+                .iter()
+                .map(|p| p.as_ref().map(BorrowedProposal::GroupContextExtensions)),
+        );
+
+        #[cfg(feature = "custom_proposal")]
+        let res = res.chain(
+            self.custom_proposals
+                .iter()
+                .map(|p| p.as_ref().map(BorrowedProposal::Custom)),
+        );
+
+        res
+    }
+
+    /// Iterate over the proposals in the bundle, consuming the bundle.
+    ///
+    /// Proposals are yielded grouped by type in this order: custom, external init, PSK,
+    /// update, add, remove, reinit, group context extensions. `into_proposals_or_refs` uses
+    /// this order.
+    ///
+    /// NOTE(review): this order likely determines the order of proposals in an encoded
+    /// commit; do not change it without checking interoperability.
+    pub fn into_proposals(self) -> impl Iterator<Item = ProposalInfo<Proposal>> {
+        let res = empty();
+
+        #[cfg(feature = "custom_proposal")]
+        let res = res.chain(
+            self.custom_proposals
+                .into_iter()
+                .map(|p| p.map(Proposal::Custom)),
+        );
+
+        let res = res.chain(
+            self.external_initializations
+                .into_iter()
+                .map(|p| p.map(Proposal::ExternalInit)),
+        );
+
+        #[cfg(feature = "psk")]
+        let res = res.chain(self.psks.into_iter().map(|p| p.map(Proposal::Psk)));
+
+        #[cfg(feature = "by_ref_proposal")]
+        let res = res.chain(self.updates.into_iter().map(|p| p.map(Proposal::Update)));
+
+        res.chain(
+            self.additions
+                .into_iter()
+                .map(|p| p.map(|p| Proposal::Add(Box::new(p)))),
+        )
+        .chain(self.removals.into_iter().map(|p| p.map(Proposal::Remove)))
+        .chain(
+            self.reinitializations
+                .into_iter()
+                .map(|p| p.map(Proposal::ReInit)),
+        )
+        .chain(
+            self.group_context_extensions
+                .into_iter()
+                .map(|p| p.map(Proposal::GroupContextExtensions)),
+        )
+    }
+
+    /// Convert the bundle into the proposal list carried by a commit.
+    ///
+    /// - By-value proposals become `ProposalOrRef::Proposal`.
+    /// - By-reference proposals become `ProposalOrRef::Reference`.
+    /// - `ProposalSource::Local` proposals are dropped, because they are never transmitted.
+    ///
+    /// The order follows [`Self::into_proposals`].
+    pub(crate) fn into_proposals_or_refs(self) -> Vec<ProposalOrRef> {
+        self.into_proposals()
+            .filter_map(|p| match p.source {
+                ProposalSource::ByValue => Some(ProposalOrRef::Proposal(Box::new(p.proposal))),
+                #[cfg(feature = "by_ref_proposal")]
+                ProposalSource::ByReference(reference) => Some(ProposalOrRef::Reference(reference)),
+                // Listed explicitly (not `_`) so that adding a new source variant forces a
+                // decision here.
+                ProposalSource::Local => None,
+            })
+            .collect()
+    }
+
+    /// Add proposals in the bundle.
+    pub fn add_proposals(&self) -> &[ProposalInfo<AddProposal>] {
+        &self.additions
+    }
+
+    /// Update proposals in the bundle.
+    #[cfg(feature = "by_ref_proposal")]
+    pub fn update_proposals(&self) -> &[ProposalInfo<UpdateProposal>] {
+        &self.updates
+    }
+
+    /// Senders of update proposals in the bundle.
+    #[cfg(feature = "by_ref_proposal")]
+    pub fn update_proposal_senders(&self) -> &[LeafIndex] {
+        &self.update_senders
+    }
+
+    /// Remove proposals in the bundle.
+    pub fn remove_proposals(&self) -> &[ProposalInfo<RemoveProposal>] {
+        &self.removals
+    }
+
+    /// Pre-shared key proposals in the bundle.
+    #[cfg(feature = "psk")]
+    pub fn psk_proposals(&self) -> &[ProposalInfo<PreSharedKeyProposal>] {
+        &self.psks
+    }
+
+    /// Reinit proposals in the bundle.
+    pub fn reinit_proposals(&self) -> &[ProposalInfo<ReInitProposal>] {
+        &self.reinitializations
+    }
+
+    /// External init proposals in the bundle.
+    pub fn external_init_proposals(&self) -> &[ProposalInfo<ExternalInit>] {
+        &self.external_initializations
+    }
+
+    /// Group context extension proposals in the bundle.
+    pub fn group_context_ext_proposals(&self) -> &[ProposalInfo<ExtensionList>] {
+        &self.group_context_extensions
+    }
+
+    /// Custom proposals in the bundle.
+    #[cfg(feature = "custom_proposal")]
+    pub fn custom_proposals(&self) -> &[ProposalInfo<CustomProposal>] {
+        &self.custom_proposals
+    }
+
+    /// The first group context extensions proposal in the bundle, if any.
+    pub(crate) fn group_context_extensions_proposal(&self) -> Option<&ProposalInfo<ExtensionList>> {
+        self.group_context_extensions.first()
+    }
+
+    /// Custom proposal types that are in use within this bundle, each reported once.
+    ///
+    /// With `std` the types come out in first-seen order. Without `std` they are
+    /// de-duplicated through a `BTreeSet` and therefore come out in sorted order.
+    #[cfg(feature = "custom_proposal")]
+    pub fn custom_proposal_types(&self) -> impl Iterator<Item = ProposalType> + '_ {
+        #[cfg(feature = "std")]
+        let res = self
+            .custom_proposals
+            .iter()
+            .map(|v| v.proposal.proposal_type())
+            .unique();
+
+        #[cfg(not(feature = "std"))]
+        let res = self
+            .custom_proposals
+            .iter()
+            .map(|v| v.proposal.proposal_type())
+            .collect::<alloc::collections::BTreeSet<_>>()
+            .into_iter();
+
+        res
+    }
+
+    /// Proposal types that are in use within this bundle.
+    ///
+    /// Each standard type with at least one proposal is reported once. With the
+    /// `custom_proposal` feature, the distinct custom types (see `custom_proposal_types`) are
+    /// appended.
+    pub fn proposal_types(&self) -> impl Iterator<Item = ProposalType> + '_ {
+        let res = (!self.additions.is_empty())
+            .then_some(ProposalType::ADD)
+            .into_iter()
+            .chain((!self.removals.is_empty()).then_some(ProposalType::REMOVE))
+            .chain((!self.reinitializations.is_empty()).then_some(ProposalType::RE_INIT));
+
+        #[cfg(feature = "by_ref_proposal")]
+        let res = res.chain((!self.updates.is_empty()).then_some(ProposalType::UPDATE));
+
+        #[cfg(feature = "psk")]
+        let res = res.chain((!self.psks.is_empty()).then_some(ProposalType::PSK));
+
+        let res = res.chain(
+            (!self.external_initializations.is_empty()).then_some(ProposalType::EXTERNAL_INIT),
+        );
+
+        #[cfg(not(feature = "custom_proposal"))]
+        return res.chain(
+            (!self.group_context_extensions.is_empty())
+                .then_some(ProposalType::GROUP_CONTEXT_EXTENSIONS),
+        );
+
+        #[cfg(feature = "custom_proposal")]
+        return res
+            .chain(
+                (!self.group_context_extensions.is_empty())
+                    .then_some(ProposalType::GROUP_CONTEXT_EXTENSIONS),
+            )
+            .chain(self.custom_proposal_types());
+    }
+}
+
+impl FromIterator<(Proposal, Sender, ProposalSource)> for ProposalBundle {
+    /// Build a bundle by calling [`ProposalBundle::add`] on each item in iteration order.
+    fn from_iter<I>(iter: I) -> Self
+    where
+        I: IntoIterator<Item = (Proposal, Sender, ProposalSource)>,
+    {
+        iter.into_iter()
+            .fold(Self::default(), |mut bundle, (proposal, sender, source)| {
+                bundle.add(proposal, sender, source);
+                bundle
+            })
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl<'a> FromIterator<(&'a ProposalRef, &'a CachedProposal)> for ProposalBundle {
+    /// Build a bundle from cached proposals.
+    ///
+    /// Each proposal and its sender are copied out of the cache, and the proposal is tagged
+    /// `ProposalSource::ByReference` with its cache key.
+    fn from_iter<I>(iter: I) -> Self
+    where
+        I: IntoIterator<Item = (&'a ProposalRef, &'a CachedProposal)>,
+    {
+        let mut bundle = ProposalBundle::default();
+
+        for (reference, cached) in iter {
+            let source = ProposalSource::ByReference(reference.clone());
+            bundle.add(cached.proposal.clone(), cached.sender, source);
+        }
+
+        bundle
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl<'a> FromIterator<&'a (ProposalRef, CachedProposal)> for ProposalBundle {
+    /// Build a bundle from borrowed `(reference, cached proposal)` pairs.
+    ///
+    /// Each pair is split into its two borrowed halves and handed to the
+    /// `(&ProposalRef, &CachedProposal)` implementation.
+    fn from_iter<I>(iter: I) -> Self
+    where
+        I: IntoIterator<Item = &'a (ProposalRef, CachedProposal)>,
+    {
+        iter.into_iter()
+            .map(|(reference, cached)| (reference, cached))
+            .collect()
+    }
+}
+
+/// Where a proposal in a [`ProposalBundle`] came from. This determines whether and how it
+/// is carried in the commit built from the bundle.
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[derive(Clone, Debug, PartialEq)]
+pub enum ProposalSource {
+ /// Carried inline in the commit (`ProposalOrRef::Proposal`).
+ ByValue,
+ /// Carried in the commit as this reference (`ProposalOrRef::Reference`).
+ #[cfg(feature = "by_ref_proposal")]
+ ByReference(ProposalRef),
+ /// Injected locally (see [`ProposalInfo::new`] with `can_transmit == false`); left out
+ /// of the commit's proposal list.
+ Local,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))]
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[non_exhaustive]
+/// Proposal description used as input to a
+/// [`MlsRules`](crate::MlsRules).
+///
+/// The struct is `#[non_exhaustive]`, so code outside this crate constructs it with
+/// [`ProposalInfo::new`].
+pub struct ProposalInfo<T> {
+ /// The underlying proposal value.
+ pub proposal: T,
+ /// The sender of this proposal.
+ pub sender: Sender,
+ /// The source of the proposal.
+ pub source: ProposalSource,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl<T> ProposalInfo<T> {
+ /// Create a new ProposalInfo.
+ ///
+ /// The resulting value will be either transmitted with a commit or
+ /// locally injected into a commit resolution depending on the
+ /// `can_transmit` flag.
+ ///
+ /// This function is useful when implementing custom
+ /// [`MlsRules`](crate::MlsRules).
+ ///
+ /// `can_transmit == true` yields `ProposalSource::ByValue`; `false` yields
+ /// `ProposalSource::Local`.
+ #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn new(proposal: T, sender: Sender, can_transmit: bool) -> Self {
+ let source = if can_transmit {
+ ProposalSource::ByValue
+ } else {
+ ProposalSource::Local
+ };
+
+ ProposalInfo {
+ proposal,
+ sender,
+ source,
+ }
+ }
+
+ /// The sender of this proposal (FFI builds only).
+ #[cfg(all(feature = "ffi", not(test)))]
+ pub fn sender(&self) -> &Sender {
+ &self.sender
+ }
+
+ /// The source of this proposal (FFI builds only).
+ #[cfg(all(feature = "ffi", not(test)))]
+ pub fn source(&self) -> &ProposalSource {
+ &self.source
+ }
+
+ /// Transform the proposal with `f`, keeping the same sender and source.
+ #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn map<U, F>(self, f: F) -> ProposalInfo<U>
+ where
+ F: FnOnce(T) -> U,
+ {
+ ProposalInfo {
+ proposal: f(self.proposal),
+ sender: self.sender,
+ source: self.source,
+ }
+ }
+
+ /// Borrow the proposal, copying the sender and cloning the source.
+ #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn as_ref(&self) -> ProposalInfo<&T> {
+ ProposalInfo {
+ proposal: &self.proposal,
+ sender: self.sender,
+ source: self.source.clone(),
+ }
+ }
+
+ /// Returns `true` if the source is `ProposalSource::ByValue`.
+ #[inline(always)]
+ pub fn is_by_value(&self) -> bool {
+ self.source == ProposalSource::ByValue
+ }
+
+ /// Returns `true` for every source other than `ProposalSource::ByValue`.
+ ///
+ /// NOTE(review): this also returns `true` for `ProposalSource::Local`, which is not
+ /// actually a by-reference proposal. Confirm whether callers (e.g. proposer validation)
+ /// rely on this before narrowing it to `ByReference` only.
+ #[inline(always)]
+ pub fn is_by_reference(&self) -> bool {
+ !self.is_by_value()
+ }
+
+ /// The [`ProposalRef`] of this proposal if its source is [`ProposalSource::ByReference`]
+ #[cfg(feature = "by_ref_proposal")]
+ pub fn proposal_ref(&self) -> Option<&ProposalRef> {
+ match self.source {
+ ProposalSource::ByReference(ref reference) => Some(reference),
+ _ => None,
+ }
+ }
+}
+
+// FFI builds only: presumably generates a concrete `ProposalInfoFfi` type for
+// `ProposalInfo<Proposal>`, since the generic type cannot cross the FFI boundary directly.
+#[cfg(all(feature = "ffi", not(test)))]
+safer_ffi_gen::specialize!(ProposalInfoFfi = ProposalInfo<Proposal>);
+
+/// A standard proposal type that has its own list inside a [`ProposalBundle`].
+///
+/// This trait is what lets `ProposalBundle::by_type`, `ProposalBundle::remove` and
+/// `ProposalBundle::retain_by_type` select the list for a given proposal type.
+pub trait Proposable: Sized {
+ /// The proposal type identifier for `Self`.
+ const TYPE: ProposalType;
+
+ /// The proposals of type `Self` held by `bundle`.
+ fn filter(bundle: &ProposalBundle) -> &[ProposalInfo<Self>];
+ /// Remove the proposal of type `Self` at `index` from `bundle`.
+ fn remove(bundle: &mut ProposalBundle, index: usize);
+ /// Keep only the proposals of type `Self` for which `keep` returns `true`.
+ fn retain<F>(bundle: &mut ProposalBundle, keep: F)
+ where
+ F: FnMut(&ProposalInfo<Self>) -> bool;
+}
+
+// Implements `Proposable` for a proposal type:
+// - `$ty` is the proposal type;
+// - `$proposal_type` is the name of the matching `ProposalType` associated constant;
+// - `$field` is the `ProposalBundle` field that stores proposals of that type.
+macro_rules! impl_proposable {
+ ($ty:ty, $proposal_type:ident, $field:ident) => {
+ impl Proposable for $ty {
+ const TYPE: ProposalType = ProposalType::$proposal_type;
+
+ fn filter(bundle: &ProposalBundle) -> &[ProposalInfo<Self>] {
+ &bundle.$field
+ }
+
+ fn remove(bundle: &mut ProposalBundle, index: usize) {
+ // Out-of-range indices are ignored rather than panicking.
+ if index < bundle.$field.len() {
+ bundle.$field.remove(index);
+ }
+ }
+
+ fn retain<F>(bundle: &mut ProposalBundle, keep: F)
+ where
+ F: FnMut(&ProposalInfo<Self>) -> bool,
+ {
+ bundle.$field.retain(keep);
+ }
+ }
+ };
+}
+
+// `Proposable` impls for every standard proposal type stored in `ProposalBundle`.
+// `CustomProposal` has no impl here; custom proposals are accessed through
+// `ProposalBundle::custom_proposals` / `retain_custom` instead.
+impl_proposable!(AddProposal, ADD, additions);
+#[cfg(feature = "by_ref_proposal")]
+impl_proposable!(UpdateProposal, UPDATE, updates);
+impl_proposable!(RemoveProposal, REMOVE, removals);
+#[cfg(feature = "psk")]
+impl_proposable!(PreSharedKeyProposal, PSK, psks);
+impl_proposable!(ReInitProposal, RE_INIT, reinitializations);
+impl_proposable!(ExternalInit, EXTERNAL_INIT, external_initializations);
+impl_proposable!(
+ ExtensionList,
+ GROUP_CONTEXT_EXTENSIONS,
+ group_context_extensions
+);
diff --git a/src/group/proposal_filter/filtering.rs b/src/group/proposal_filter/filtering.rs
new file mode 100644
index 0000000..8e67ff5
--- /dev/null
+++ b/src/group/proposal_filter/filtering.rs
@@ -0,0 +1,580 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::MlsError,
+ group::{
+ proposal::ReInitProposal,
+ proposal_filter::{ProposalBundle, ProposalInfo},
+ AddProposal, ProposalType, RemoveProposal, Sender, UpdateProposal,
+ },
+ iter::wrap_iter,
+ protocol_version::ProtocolVersion,
+ time::MlsTime,
+ tree_kem::{
+ leaf_node_validator::{LeafNodeValidator, ValidationContext},
+ node::LeafIndex,
+ TreeKemPublic,
+ },
+ CipherSuiteProvider, ExtensionList,
+};
+
+use super::filtering_common::{filter_out_invalid_psks, ApplyProposalsOutput, ProposalApplier};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::extension::ExternalSendersExt;
+
+use alloc::vec::Vec;
+use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider, psk::PreSharedKeyStorage};
+
+#[cfg(any(
+ feature = "custom_proposal",
+ not(any(mls_build_async, feature = "rayon"))
+))]
+use itertools::Itertools;
+
+use crate::group::ExternalInit;
+
+#[cfg(feature = "psk")]
+use crate::group::proposal::PreSharedKeyProposal;
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
+
+#[cfg(mls_build_async)]
+use futures::{StreamExt, TryStreamExt};
+
+impl<'a, C, P, CSP> ProposalApplier<'a, C, P, CSP>
+where
+ C: IdentityProvider,
+ P: PreSharedKeyStorage,
+ CSP: CipherSuiteProvider,
+{
+ /// Validates, filters and applies `proposals` for a commit sent by the existing member
+ /// `commit_sender`.
+ ///
+ /// Each filter step either drops an invalid proposal or returns an error, as decided by
+ /// `strategy` through `apply_strategy`. Steps run in this order: invalid proposers, Updates
+ /// sent by the committer, Removes of the committer, invalid PSKs, invalid group context
+ /// extensions (`by_ref_proposal` only), extra group context extensions proposals, ReInits
+ /// with an older protocol version, ReInits mixed with other proposals, ExternalInits. The
+ /// surviving proposals are then applied via `apply_proposal_changes`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn apply_proposals_from_member(
+ &self,
+ strategy: FilterStrategy,
+ commit_sender: LeafIndex,
+ proposals: ProposalBundle,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ let proposals = filter_out_invalid_proposers(strategy, proposals)?;
+
+ let mut proposals: ProposalBundle =
+ filter_out_update_for_committer(strategy, commit_sender, proposals)?;
+
+ // Rebuild `update_senders` so that entry `i` is the leaf index of the sender of `updates[i]`.
+ // We ignore the strategy here because the check above ensures all updates are from members
+ proposals.update_senders = proposals
+ .updates
+ .iter()
+ .map(leaf_index_of_update_sender)
+ .collect::<Result<_, _>>()?;
+
+ let mut proposals = filter_out_removal_of_committer(strategy, commit_sender, proposals)?;
+
+ filter_out_invalid_psks(
+ strategy,
+ self.cipher_suite_provider,
+ &mut proposals,
+ self.psk_storage,
+ )
+ .await?;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let proposals = filter_out_invalid_group_extensions(
+ strategy,
+ proposals,
+ self.identity_provider,
+ commit_time,
+ )
+ .await?;
+
+ let proposals = filter_out_extra_group_context_extensions(strategy, proposals)?;
+ let proposals = filter_out_invalid_reinit(strategy, proposals, self.protocol_version)?;
+ let proposals = filter_out_reinit_if_other_proposals(strategy.is_ignore(), proposals)?;
+
+ let proposals = filter_out_external_init(strategy, proposals)?;
+
+ self.apply_proposal_changes(strategy, proposals, commit_time)
+ .await
+ }
+
+ /// Applies the (already filtered) proposals to the tree.
+ ///
+ /// If the bundle contains a group context extensions proposal, the changes are applied in
+ /// the context of those new extensions via `apply_proposals_with_new_capabilities`;
+ /// otherwise the group's current extensions (`original_group_extensions`) are used.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn apply_proposal_changes(
+ &self,
+ strategy: FilterStrategy,
+ proposals: ProposalBundle,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ match proposals.group_context_extensions_proposal().cloned() {
+ Some(p) => {
+ self.apply_proposals_with_new_capabilities(strategy, proposals, p, commit_time)
+ .await
+ }
+ None => {
+ self.apply_tree_changes(
+ strategy,
+ proposals,
+ self.original_group_extensions,
+ commit_time,
+ )
+ .await
+ }
+ }
+ }
+
+ /// Validates new leaf nodes against `group_extensions_in_use`, applies the remaining
+ /// proposals to a clone of the original tree with `batch_edit`, and returns the new tree,
+ /// the applied proposals, the leaf indexes returned for added key packages and the new
+ /// group context extensions (if a group context extensions proposal was applied).
+ ///
+ /// `external_init_index` is always `None` here; the new-member commit path fills it in.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn apply_tree_changes(
+ &self,
+ strategy: FilterStrategy,
+ proposals: ProposalBundle,
+ group_extensions_in_use: &ExtensionList,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ let mut applied_proposals = self
+ .validate_new_nodes(strategy, proposals, group_extensions_in_use, commit_time)
+ .await?;
+
+ // The original tree is left untouched; edits are applied to a copy.
+ let mut new_tree = self.original_tree.clone();
+
+ let added = new_tree
+ .batch_edit(
+ &mut applied_proposals,
+ group_extensions_in_use,
+ self.identity_provider,
+ self.cipher_suite_provider,
+ strategy.is_ignore(),
+ )
+ .await?;
+
+ let new_context_extensions = applied_proposals
+ .group_context_extensions_proposal()
+ .map(|gce| gce.proposal.clone());
+
+ Ok(ApplyProposalsOutput {
+ applied_proposals,
+ new_tree,
+ indexes_of_added_kpkgs: added,
+ external_init_index: None,
+ new_context_extensions,
+ })
+ }
+
+ /// Drops (or rejects, depending on `strategy`) Update and Add proposals whose new leaf
+ /// nodes fail validation against `group_extensions_in_use`.
+ ///
+ /// An Update is invalid if its leaf node fails `check_if_valid` in the Update context, or if
+ /// the identity provider does not accept it as a valid successor of the sender's current
+ /// leaf. An Add is invalid if `validate_new_node` fails for its key package.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn validate_new_nodes(
+ &self,
+ strategy: FilterStrategy,
+ mut proposals: ProposalBundle,
+ group_extensions_in_use: &ExtensionList,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ProposalBundle, MlsError> {
+ let leaf_node_validator = &LeafNodeValidator::new(
+ self.cipher_suite_provider,
+ self.identity_provider,
+ Some(group_extensions_in_use),
+ );
+
+ // Indices of Update proposals to drop. `updates` and `update_senders` are zipped, so they
+ // must be index-aligned at this point.
+ let bad_indices: Vec<_> = wrap_iter(proposals.update_proposals())
+ .zip(wrap_iter(proposals.update_proposal_senders()))
+ .enumerate()
+ .filter_map(|(i, (p, &sender_index))| async move {
+ let res = {
+ let leaf = &p.proposal.leaf_node;
+
+ let res = leaf_node_validator
+ .check_if_valid(
+ leaf,
+ ValidationContext::Update((self.group_id, *sender_index, commit_time)),
+ )
+ .await;
+
+ let old_leaf = match self.original_tree.get_leaf_node(sender_index) {
+ Ok(leaf) => leaf,
+ Err(e) => return Some(Err(e)),
+ };
+
+ let valid_successor = self
+ .identity_provider
+ .valid_successor(
+ &old_leaf.signing_identity,
+ &leaf.signing_identity,
+ group_extensions_in_use,
+ )
+ .await
+ .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))
+ .and_then(|valid| valid.then_some(()).ok_or(MlsError::InvalidSuccessor));
+
+ // The leaf-node error, if any, takes precedence over the successor error.
+ res.and(valid_successor)
+ };
+
+ // `Ok(false)` (drop) yields the index, `Ok(true)` (keep) yields nothing and an
+ // error is propagated to `try_collect`.
+ apply_strategy(strategy, p.is_by_reference(), res)
+ .map(|b| (!b).then_some(i))
+ .transpose()
+ })
+ .try_collect()
+ .await?;
+
+ // Remove in reverse so earlier indices stay valid; keep `update_senders` aligned.
+ bad_indices.into_iter().rev().for_each(|i| {
+ proposals.remove::<UpdateProposal>(i);
+ proposals.update_senders.remove(i);
+ });
+
+ // Same procedure for Add proposals, validating each new key package.
+ let bad_indices: Vec<_> = wrap_iter(proposals.add_proposals())
+ .enumerate()
+ .filter_map(|(i, p)| async move {
+ let res = self
+ .validate_new_node(leaf_node_validator, &p.proposal.key_package, commit_time)
+ .await;
+
+ apply_strategy(strategy, p.is_by_reference(), res)
+ .map(|b| (!b).then_some(i))
+ .transpose()
+ })
+ .try_collect()
+ .await?;
+
+ bad_indices
+ .into_iter()
+ .rev()
+ .for_each(|i| proposals.remove::<AddProposal>(i));
+
+ Ok(proposals)
+ }
+}
+
+/// How invalid proposals are handled while filtering a proposal bundle.
+#[derive(Clone, Copy, Debug)]
+pub enum FilterStrategy {
+    /// Invalid proposals received by reference are dropped; invalid by-value proposals still
+    /// produce an error.
+    IgnoreByRef,
+    /// No proposal is dropped; every violation produces an error.
+    IgnoreNone,
+}
+
+impl FilterStrategy {
+    /// Returns `true` when an invalid proposal with the given source (`by_ref`) may be dropped
+    /// instead of failing the whole operation.
+    pub(super) fn ignore(self, by_ref: bool) -> bool {
+        by_ref && self.is_ignore()
+    }
+
+    /// Returns `true` when this strategy allows dropping any invalid proposal at all.
+    fn is_ignore(self) -> bool {
+        matches!(self, FilterStrategy::IgnoreByRef)
+    }
+}
+
+/// Decides whether a proposal whose validation produced `r` is kept.
+///
+/// Returns `Ok(true)` if `r` is `Ok`, `Ok(false)` if `r` is an error that `strategy` allows
+/// ignoring for a proposal with the given source (`by_ref`), and the error itself otherwise.
+pub(crate) fn apply_strategy(
+    strategy: FilterStrategy,
+    by_ref: bool,
+    r: Result<(), MlsError>,
+) -> Result<bool, MlsError> {
+    match r {
+        Ok(()) => Ok(true),
+        Err(_) if strategy.ignore(by_ref) => Ok(false),
+        Err(error) => Err(error),
+    }
+}
+
+/// Drops (or rejects, per `strategy`) Update proposals sent by the committer itself.
+fn filter_out_update_for_committer(
+    strategy: FilterStrategy,
+    commit_sender: LeafIndex,
+    mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+    let committer = Sender::Member(*commit_sender);
+
+    proposals.retain_by_type::<UpdateProposal, _, _>(|info| {
+        let outcome = if info.sender == committer {
+            Err(MlsError::InvalidCommitSelfUpdate)
+        } else {
+            Ok(())
+        };
+
+        apply_strategy(strategy, info.is_by_reference(), outcome)
+    })?;
+
+    Ok(proposals)
+}
+
+/// Drops (or rejects, per `strategy`) Remove proposals that target the committer's own leaf.
+fn filter_out_removal_of_committer(
+    strategy: FilterStrategy,
+    commit_sender: LeafIndex,
+    mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+    proposals.retain_by_type::<RemoveProposal, _, _>(|info| {
+        let outcome = if info.proposal.to_remove == commit_sender {
+            Err(MlsError::CommitterSelfRemoval)
+        } else {
+            Ok(())
+        };
+
+        apply_strategy(strategy, info.is_by_reference(), outcome)
+    })?;
+
+    Ok(proposals)
+}
+
+/// Drops (or rejects, per `strategy`) group context extensions proposals whose external
+/// senders extension cannot be decoded or fails `verify_all` against `identity_provider`.
+///
+/// Proposals without an external senders extension are always kept.
+#[cfg(feature = "by_ref_proposal")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn filter_out_invalid_group_extensions<C>(
+    strategy: FilterStrategy,
+    mut proposals: ProposalBundle,
+    identity_provider: &C,
+    commit_time: Option<MlsTime>,
+) -> Result<ProposalBundle, MlsError>
+where
+    C: IdentityProvider,
+{
+    let mut rejected = Vec::new();
+
+    for (index, info) in proposals.by_type::<ExtensionList>().enumerate() {
+        let verdict = match info.proposal.get_as::<ExternalSendersExt>() {
+            Err(e) => Err(MlsError::from(e)),
+            Ok(None) => Ok(()),
+            Ok(Some(senders)) => senders
+                .verify_all(identity_provider, commit_time, &info.proposal)
+                .await
+                .map_err(|e| MlsError::IdentityProviderError(e.into_any_error())),
+        };
+
+        let keep = apply_strategy(strategy, info.is_by_reference(), verdict)?;
+
+        if !keep {
+            rejected.push(index);
+        }
+    }
+
+    // Remove back-to-front so the remaining recorded indices stay valid.
+    for index in rejected.into_iter().rev() {
+        proposals.remove::<ExtensionList>(index);
+    }
+
+    Ok(proposals)
+}
+
+/// Keeps only the first group context extensions proposal; every later one is dropped or
+/// rejected per `strategy`.
+fn filter_out_extra_group_context_extensions(
+    strategy: FilterStrategy,
+    mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+    let mut seen_one = false;
+
+    proposals.retain_by_type::<ExtensionList, _, _>(|info| {
+        let outcome = if seen_one {
+            Err(MlsError::MoreThanOneGroupContextExtensionsProposal)
+        } else {
+            seen_one = true;
+            Ok(())
+        };
+
+        apply_strategy(strategy, info.is_by_reference(), outcome)
+    })?;
+
+    Ok(proposals)
+}
+
+/// Drops (or rejects, per `strategy`) ReInit proposals whose target version is lower than the
+/// group's current `protocol_version`.
+fn filter_out_invalid_reinit(
+    strategy: FilterStrategy,
+    mut proposals: ProposalBundle,
+    protocol_version: ProtocolVersion,
+) -> Result<ProposalBundle, MlsError> {
+    proposals.retain_by_type::<ReInitProposal, _, _>(|info| {
+        let outcome = if info.proposal.version >= protocol_version {
+            Ok(())
+        } else {
+            Err(MlsError::InvalidProtocolVersionInReInit)
+        };
+
+        apply_strategy(strategy, info.is_by_reference(), outcome)
+    })?;
+
+    Ok(proposals)
+}
+
+/// Enforces that ReInit proposals are not committed together with other proposals.
+///
+/// When the bundle holds a ReInit plus at least one other proposal:
+/// * an error is returned if `filter` is `false` or any ReInit was supplied by value;
+/// * otherwise all ReInits are dropped if other proposal types are present, or all but the
+///   first ReInit are dropped if the bundle consists only of ReInits.
+fn filter_out_reinit_if_other_proposals(
+    filter: bool,
+    mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+    let total = proposals.length();
+    let reinit_count = proposals.reinit_proposals().len();
+
+    // Nothing to enforce: no ReInit at all, or a single proposal in the bundle.
+    if reinit_count == 0 || total == 1 {
+        return Ok(proposals);
+    }
+
+    let by_value_reinit = proposals.reinit_proposals().iter().any(|p| p.is_by_value());
+
+    if !filter || by_value_reinit {
+        return Err(MlsError::OtherProposalWithReInit);
+    }
+
+    if total > reinit_count {
+        proposals.reinitializations = Vec::new();
+    } else {
+        proposals.reinitializations.truncate(1);
+    }
+
+    Ok(proposals)
+}
+
+/// Treats every ExternalInit proposal as invalid: each is dropped or rejected per `strategy`.
+fn filter_out_external_init(
+    strategy: FilterStrategy,
+    mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+    proposals.retain_by_type::<ExternalInit, _, _>(|info| {
+        let rejection = Err(MlsError::InvalidProposalTypeForSender);
+        apply_strategy(strategy, info.is_by_reference(), rejection)
+    })?;
+
+    Ok(proposals)
+}
+
+/// Checks whether `proposer` may send a proposal of `proposal_type`, supplied by reference
+/// (`by_ref == true`) or by value.
+///
+/// Returns `MlsError::InvalidProposalTypeForSender` when the combination is not allowed.
+pub(crate) fn proposer_can_propose(
+    proposer: Sender,
+    proposal_type: ProposalType,
+    by_ref: bool,
+) -> Result<(), MlsError> {
+    let allowed = match proposer {
+        // Members may send Update proposals only by reference.
+        Sender::Member(_) if by_ref => matches!(
+            proposal_type,
+            ProposalType::ADD
+                | ProposalType::UPDATE
+                | ProposalType::REMOVE
+                | ProposalType::PSK
+                | ProposalType::RE_INIT
+                | ProposalType::GROUP_CONTEXT_EXTENSIONS
+        ),
+        Sender::Member(_) => matches!(
+            proposal_type,
+            ProposalType::ADD
+                | ProposalType::REMOVE
+                | ProposalType::PSK
+                | ProposalType::RE_INIT
+                | ProposalType::GROUP_CONTEXT_EXTENSIONS
+        ),
+        // External senders only contribute proposals by reference.
+        #[cfg(feature = "by_ref_proposal")]
+        Sender::External(_) => {
+            by_ref
+                && matches!(
+                    proposal_type,
+                    ProposalType::ADD
+                        | ProposalType::REMOVE
+                        | ProposalType::RE_INIT
+                        | ProposalType::PSK
+                        | ProposalType::GROUP_CONTEXT_EXTENSIONS
+                )
+        }
+        // A new member's commit only carries proposals by value.
+        Sender::NewMemberCommit => {
+            !by_ref
+                && matches!(
+                    proposal_type,
+                    ProposalType::REMOVE | ProposalType::PSK | ProposalType::EXTERNAL_INIT
+                )
+        }
+        // A new-member proposal can only be an Add sent by reference.
+        Sender::NewMemberProposal => by_ref && matches!(proposal_type, ProposalType::ADD),
+    };
+
+    if allowed {
+        Ok(())
+    } else {
+        Err(MlsError::InvalidProposalTypeForSender)
+    }
+}
+
+/// Drops (or rejects, per `strategy`) every proposal whose sender may not send that proposal
+/// type in the way it was supplied (by value or by reference), as decided by
+/// `proposer_can_propose`.
+///
+/// Each proposal list is walked in reverse so that removing an entry never shifts an index
+/// that is still to be visited.
+pub(crate) fn filter_out_invalid_proposers(
+    strategy: FilterStrategy,
+    mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+    for i in (0..proposals.add_proposals().len()).rev() {
+        let p = &proposals.add_proposals()[i];
+        let res = proposer_can_propose(p.sender, ProposalType::ADD, p.is_by_reference());
+
+        if !apply_strategy(strategy, p.is_by_reference(), res)? {
+            proposals.remove::<AddProposal>(i);
+        }
+    }
+
+    for i in (0..proposals.update_proposals().len()).rev() {
+        let p = &proposals.update_proposals()[i];
+        let res = proposer_can_propose(p.sender, ProposalType::UPDATE, p.is_by_reference());
+
+        if !apply_strategy(strategy, p.is_by_reference(), res)? {
+            proposals.remove::<UpdateProposal>(i);
+
+            // `update_senders` is not necessarily populated when this filter runs
+            // (`apply_proposals_from_member` only rebuilds it afterwards), so an unchecked
+            // `Vec::remove` could panic on an out-of-range index. Keep it aligned with
+            // `updates` only when the entry actually exists.
+            if i < proposals.update_senders.len() {
+                proposals.update_senders.remove(i);
+            }
+        }
+    }
+
+    for i in (0..proposals.remove_proposals().len()).rev() {
+        let p = &proposals.remove_proposals()[i];
+        let res = proposer_can_propose(p.sender, ProposalType::REMOVE, p.is_by_reference());
+
+        if !apply_strategy(strategy, p.is_by_reference(), res)? {
+            proposals.remove::<RemoveProposal>(i);
+        }
+    }
+
+    #[cfg(feature = "psk")]
+    for i in (0..proposals.psk_proposals().len()).rev() {
+        let p = &proposals.psk_proposals()[i];
+        let res = proposer_can_propose(p.sender, ProposalType::PSK, p.is_by_reference());
+
+        if !apply_strategy(strategy, p.is_by_reference(), res)? {
+            proposals.remove::<PreSharedKeyProposal>(i);
+        }
+    }
+
+    for i in (0..proposals.reinit_proposals().len()).rev() {
+        let p = &proposals.reinit_proposals()[i];
+        let res = proposer_can_propose(p.sender, ProposalType::RE_INIT, p.is_by_reference());
+
+        if !apply_strategy(strategy, p.is_by_reference(), res)? {
+            proposals.remove::<ReInitProposal>(i);
+        }
+    }
+
+    for i in (0..proposals.external_init_proposals().len()).rev() {
+        let p = &proposals.external_init_proposals()[i];
+        let res = proposer_can_propose(p.sender, ProposalType::EXTERNAL_INIT, p.is_by_reference());
+
+        if !apply_strategy(strategy, p.is_by_reference(), res)? {
+            proposals.remove::<ExternalInit>(i);
+        }
+    }
+
+    for i in (0..proposals.group_context_ext_proposals().len()).rev() {
+        let p = &proposals.group_context_ext_proposals()[i];
+        let gce_type = ProposalType::GROUP_CONTEXT_EXTENSIONS;
+        let res = proposer_can_propose(p.sender, gce_type, p.is_by_reference());
+
+        if !apply_strategy(strategy, p.is_by_reference(), res)? {
+            proposals.remove::<ExtensionList>(i);
+        }
+    }
+
+    Ok(proposals)
+}
+
+/// Returns the leaf index of the member that sent the Update proposal `p`; any non-member
+/// sender is an `InvalidProposalTypeForSender` error.
+fn leaf_index_of_update_sender(p: &ProposalInfo<UpdateProposal>) -> Result<LeafIndex, MlsError> {
+    if let Sender::Member(index) = p.sender {
+        Ok(LeafIndex(index))
+    } else {
+        Err(MlsError::InvalidProposalTypeForSender)
+    }
+}
+
+/// Drops (or rejects, per `strategy`) custom proposals whose type is not accepted by
+/// `tree.can_support_proposal`.
+#[cfg(feature = "custom_proposal")]
+pub(super) fn filter_out_unsupported_custom_proposals(
+    proposals: &mut ProposalBundle,
+    tree: &TreeKemPublic,
+    strategy: FilterStrategy,
+) -> Result<(), MlsError> {
+    // Computed up front so the bundle is no longer borrowed while it is being filtered.
+    let supported = proposals
+        .custom_proposal_types()
+        .filter(|ty| tree.can_support_proposal(*ty))
+        .collect_vec();
+
+    proposals.retain_custom(|info| {
+        let ty = info.proposal.proposal_type();
+
+        let outcome = if supported.contains(&ty) {
+            Ok(())
+        } else {
+            Err(MlsError::UnsupportedCustomProposal(ty))
+        };
+
+        apply_strategy(strategy, info.is_by_reference(), outcome)
+    })
+}
diff --git a/src/group/proposal_filter/filtering_common.rs b/src/group/proposal_filter/filtering_common.rs
new file mode 100644
index 0000000..278c0de
--- /dev/null
+++ b/src/group/proposal_filter/filtering_common.rs
@@ -0,0 +1,579 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::MlsError,
+ group::{proposal_filter::ProposalBundle, Sender},
+ key_package::{validate_key_package_properties, KeyPackage},
+ protocol_version::ProtocolVersion,
+ time::MlsTime,
+ tree_kem::{
+ leaf_node_validator::{LeafNodeValidator, ValidationContext},
+ node::LeafIndex,
+ TreeKemPublic,
+ },
+ CipherSuiteProvider, ExtensionList,
+};
+
+use crate::tree_kem::leaf_node::LeafNode;
+
+use super::ProposalInfo;
+
+use crate::extension::{MlsExtension, RequiredCapabilitiesExt};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::extension::ExternalSendersExt;
+
+use mls_rs_core::error::IntoAnyError;
+
+use alloc::vec::Vec;
+use mls_rs_core::{identity::IdentityProvider, psk::PreSharedKeyStorage};
+
+use crate::group::{ExternalInit, ProposalType, RemoveProposal};
+
+#[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+use crate::group::proposal::PreSharedKeyProposal;
+
+#[cfg(feature = "psk")]
+use crate::group::{JustPreSharedKeyID, ResumptionPSKUsage, ResumptionPsk};
+
+#[cfg(all(feature = "std", feature = "psk"))]
+use std::collections::HashSet;
+
+#[cfg(feature = "by_ref_proposal")]
+use super::filtering::{apply_strategy, filter_out_invalid_proposers, FilterStrategy};
+
+#[cfg(feature = "custom_proposal")]
+use super::filtering::filter_out_unsupported_custom_proposals;
+
+/// Validates a proposal bundle against the group's current state and applies it to a copy of
+/// the ratchet tree.
+#[derive(Debug)]
+pub(crate) struct ProposalApplier<'a, C, P, CSP> {
+ /// Ratchet tree before the commit; it is cloned, never modified.
+ pub original_tree: &'a TreeKemPublic,
+ /// Group protocol version (used for key package checks and ReInit validation).
+ pub protocol_version: ProtocolVersion,
+ pub cipher_suite_provider: &'a CSP,
+ /// Group context extensions in effect before the commit.
+ pub original_group_extensions: &'a ExtensionList,
+ /// Leaf of the joiner for new-member (external) commits; such commits fail without it.
+ pub external_leaf: Option<&'a LeafNode>,
+ pub identity_provider: &'a C,
+ /// Used to check that external PSKs referenced by PSK proposals are available.
+ pub psk_storage: &'a P,
+ /// Group ID passed into the Update leaf-node validation context.
+ #[cfg(feature = "by_ref_proposal")]
+ pub group_id: &'a [u8],
+}
+
+/// Result of applying a proposal bundle.
+#[derive(Debug)]
+pub(crate) struct ApplyProposalsOutput {
+ /// Ratchet tree after all applied proposals (and, for external commits, the new leaf).
+ pub(crate) new_tree: TreeKemPublic,
+ /// Leaf indexes returned by the tree's `batch_edit` for the added key packages.
+ pub(crate) indexes_of_added_kpkgs: Vec<LeafIndex>,
+ /// Leaf index of the joiner's leaf in a new-member commit; `None` otherwise.
+ pub(crate) external_init_index: Option<LeafIndex>,
+ /// Proposals that survived filtering and were applied.
+ #[cfg(feature = "by_ref_proposal")]
+ pub(crate) applied_proposals: ProposalBundle,
+ /// Extensions from the applied group context extensions proposal, if there was one.
+ pub(crate) new_context_extensions: Option<ExtensionList>,
+}
+
+impl<'a, C, P, CSP> ProposalApplier<'a, C, P, CSP>
+where
+ C: IdentityProvider,
+ P: PreSharedKeyStorage,
+ CSP: CipherSuiteProvider,
+{
+ /// Creates an applier over the given pre-commit group state and providers.
+ #[allow(clippy::too_many_arguments)]
+ pub(crate) fn new(
+ original_tree: &'a TreeKemPublic,
+ protocol_version: ProtocolVersion,
+ cipher_suite_provider: &'a CSP,
+ original_group_extensions: &'a ExtensionList,
+ external_leaf: Option<&'a LeafNode>,
+ identity_provider: &'a C,
+ psk_storage: &'a P,
+ #[cfg(feature = "by_ref_proposal")] group_id: &'a [u8],
+ ) -> Self {
+ Self {
+ original_tree,
+ protocol_version,
+ cipher_suite_provider,
+ original_group_extensions,
+ external_leaf,
+ identity_provider,
+ psk_storage,
+ #[cfg(feature = "by_ref_proposal")]
+ group_id,
+ }
+ }
+
+ /// Validates and applies `proposals` for a commit sent by `commit_sender`.
+ ///
+ /// Member commits go through `apply_proposals_from_member` and new-member commits through
+ /// `apply_proposals_from_new_member`; external senders and new-member proposals cannot
+ /// commit. With `custom_proposal`, custom proposals not supported by the resulting tree
+ /// are filtered afterwards.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn apply_proposals(
+ &self,
+ #[cfg(feature = "by_ref_proposal")] strategy: FilterStrategy,
+ commit_sender: &Sender,
+ #[cfg(not(feature = "by_ref_proposal"))] proposals: &ProposalBundle,
+ #[cfg(feature = "by_ref_proposal")] proposals: ProposalBundle,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ let output = match commit_sender {
+ Sender::Member(sender) => {
+ self.apply_proposals_from_member(
+ #[cfg(feature = "by_ref_proposal")]
+ strategy,
+ LeafIndex(*sender),
+ proposals,
+ commit_time,
+ )
+ .await
+ }
+ Sender::NewMemberCommit => {
+ self.apply_proposals_from_new_member(proposals, commit_time)
+ .await
+ }
+ // Senders that are not group members (and have no new leaf) can never commit.
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::External(_) => Err(MlsError::ExternalSenderCannotCommit),
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::NewMemberProposal => Err(MlsError::ExternalSenderCannotCommit),
+ }?;
+
+ #[cfg(all(feature = "by_ref_proposal", feature = "custom_proposal"))]
+ let mut output = output;
+
+ #[cfg(all(feature = "by_ref_proposal", feature = "custom_proposal"))]
+ filter_out_unsupported_custom_proposals(
+ &mut output.applied_proposals,
+ &output.new_tree,
+ strategy,
+ )?;
+
+ #[cfg(all(not(feature = "by_ref_proposal"), feature = "custom_proposal"))]
+ filter_out_unsupported_custom_proposals(proposals, &output.new_tree)?;
+
+ Ok(output)
+ }
+
+ /// Validates and applies the proposals of a new-member (external) commit.
+ ///
+ /// Requires `external_leaf`, exactly one ExternalInit, at most one Remove (which must
+ /// remove an identity the new leaf validly succeeds), only allowed default proposal types,
+ /// and no default proposals by reference. All checks use `FilterStrategy::IgnoreNone`, so
+ /// any violation is an error. After applying, the joiner's leaf is inserted and its index is
+ /// stored in `external_init_index`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ // The lint below is triggered by the `proposals` parameter which may or may not be a borrow.
+ #[allow(clippy::needless_borrow)]
+ async fn apply_proposals_from_new_member(
+ &self,
+ #[cfg(not(feature = "by_ref_proposal"))] proposals: &ProposalBundle,
+ #[cfg(feature = "by_ref_proposal")] proposals: ProposalBundle,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ let external_leaf = self
+ .external_leaf
+ .ok_or(MlsError::ExternalCommitMustHaveNewLeaf)?;
+
+ ensure_exactly_one_external_init(&proposals)?;
+
+ ensure_at_most_one_removal_for_self(
+ &proposals,
+ external_leaf,
+ self.original_tree,
+ self.identity_provider,
+ self.original_group_extensions,
+ )
+ .await?;
+
+ ensure_proposals_in_external_commit_are_allowed(&proposals)?;
+ ensure_no_proposal_by_ref(&proposals)?;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let mut proposals = filter_out_invalid_proposers(FilterStrategy::IgnoreNone, proposals)?;
+
+ filter_out_invalid_psks(
+ #[cfg(feature = "by_ref_proposal")]
+ FilterStrategy::IgnoreNone,
+ self.cipher_suite_provider,
+ #[cfg(feature = "by_ref_proposal")]
+ &mut proposals,
+ #[cfg(not(feature = "by_ref_proposal"))]
+ proposals,
+ self.psk_storage,
+ )
+ .await?;
+
+ let mut output = self
+ .apply_proposal_changes(
+ #[cfg(feature = "by_ref_proposal")]
+ FilterStrategy::IgnoreNone,
+ proposals,
+ commit_time,
+ )
+ .await?;
+
+ // The joiner's own leaf is added after all proposals have been applied.
+ output.external_init_index = Some(
+ insert_external_leaf(
+ &mut output.new_tree,
+ external_leaf.clone(),
+ self.identity_provider,
+ self.original_group_extensions,
+ )
+ .await?,
+ );
+
+ Ok(output)
+ }
+
+ /// Applies proposals in the context of the extensions carried by
+ /// `group_context_extensions_proposal`, then checks that every leaf of the resulting tree
+ /// supports them (required capabilities / external senders credentials when those
+ /// extensions are present, and every non-default extension type in the proposal).
+ ///
+ /// With `by_ref_proposal`, if the check fails and `strategy` allows ignoring the GCE
+ /// proposal, it is discarded and the proposals are re-applied under the original extensions.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn apply_proposals_with_new_capabilities(
+ &self,
+ #[cfg(feature = "by_ref_proposal")] strategy: FilterStrategy,
+ #[cfg(not(feature = "by_ref_proposal"))] proposals: &ProposalBundle,
+ #[cfg(feature = "by_ref_proposal")] proposals: ProposalBundle,
+ group_context_extensions_proposal: ProposalInfo<ExtensionList>,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError>
+ where
+ C: IdentityProvider,
+ {
+ // Kept so the proposals can be re-applied without the GCE proposal if it is rejected below.
+ #[cfg(feature = "by_ref_proposal")]
+ let mut proposals_clone = proposals.clone();
+
+ // Apply adds, updates etc. in the context of new extensions
+ let output = self
+ .apply_tree_changes(
+ #[cfg(feature = "by_ref_proposal")]
+ strategy,
+ proposals,
+ &group_context_extensions_proposal.proposal,
+ commit_time,
+ )
+ .await?;
+
+ // Verify that capabilities and extensions are supported after modifications.
+ // TODO: The newly inserted nodes have already been validated by `apply_tree_changes`
+ // above. We should investigate if there is an easy way to avoid the double check.
+ let must_check = group_context_extensions_proposal
+ .proposal
+ .has_extension(RequiredCapabilitiesExt::extension_type());
+
+ #[cfg(feature = "by_ref_proposal")]
+ let must_check = must_check
+ || group_context_extensions_proposal
+ .proposal
+ .has_extension(ExternalSendersExt::extension_type());
+
+ let new_capabilities_supported = if must_check {
+ let leaf_validator = LeafNodeValidator::new(
+ self.cipher_suite_provider,
+ self.identity_provider,
+ Some(&group_context_extensions_proposal.proposal),
+ );
+
+ output
+ .new_tree
+ .non_empty_leaves()
+ .try_for_each(|(_, leaf)| {
+ leaf_validator.validate_required_capabilities(leaf)?;
+
+ #[cfg(feature = "by_ref_proposal")]
+ leaf_validator.validate_external_senders_ext_credentials(leaf)?;
+
+ Ok(())
+ })
+ } else {
+ Ok(())
+ };
+
+ // The first non-default extension type that some leaf does not list in its
+ // capabilities, if any, becomes the error.
+ let new_extensions_supported = group_context_extensions_proposal
+ .proposal
+ .iter()
+ .map(|extension| extension.extension_type)
+ .filter(|&ext_type| !ext_type.is_default())
+ .find(|ext_type| {
+ !output
+ .new_tree
+ .non_empty_leaves()
+ .all(|(_, leaf)| leaf.capabilities.extensions.contains(ext_type))
+ })
+ .map_or(Ok(()), |ext| Err(MlsError::UnsupportedGroupExtension(ext)));
+
+ #[cfg(not(feature = "by_ref_proposal"))]
+ {
+ new_capabilities_supported.and(new_extensions_supported)?;
+ Ok(output)
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ // If extensions are good, return `Ok`. If not and the strategy is to filter, remove the group
+ // context extensions proposal and try applying all proposals again in the context of the old
+ // extensions. Else, return an error.
+ match new_capabilities_supported.and(new_extensions_supported) {
+ Ok(()) => Ok(output),
+ Err(e) => {
+ if strategy.ignore(group_context_extensions_proposal.is_by_reference()) {
+ proposals_clone.group_context_extensions.clear();
+
+ self.apply_tree_changes(
+ strategy,
+ proposals_clone,
+ self.original_group_extensions,
+ commit_time,
+ )
+ .await
+ } else {
+ Err(e)
+ }
+ }
+ }
+ }
+
+ /// Validates the key package of an Add proposal: its leaf node in the Add context, then its
+ /// key package properties (sequentially; used in async builds or without `rayon`).
+ #[cfg(any(mls_build_async, not(feature = "rayon")))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn validate_new_node<Ip: IdentityProvider, Cp: CipherSuiteProvider>(
+ &self,
+ leaf_node_validator: &LeafNodeValidator<'_, Ip, Cp>,
+ key_package: &KeyPackage,
+ commit_time: Option<MlsTime>,
+ ) -> Result<(), MlsError> {
+ leaf_node_validator
+ .check_if_valid(&key_package.leaf_node, ValidationContext::Add(commit_time))
+ .await?;
+
+ validate_key_package_properties(
+ key_package,
+ self.protocol_version,
+ self.cipher_suite_provider,
+ )
+ .await
+ }
+
+ /// Same checks as the sequential variant, run concurrently with `rayon::join` in sync
+ /// builds with `rayon`; a leaf-node error takes precedence over a key package error.
+ #[cfg(all(not(mls_build_async), feature = "rayon"))]
+ pub fn validate_new_node<Ip: IdentityProvider, Cp: CipherSuiteProvider>(
+ &self,
+ leaf_node_validator: &LeafNodeValidator<'_, Ip, Cp>,
+ key_package: &KeyPackage,
+ commit_time: Option<MlsTime>,
+ ) -> Result<(), MlsError> {
+ let (a, b) = rayon::join(
+ || {
+ leaf_node_validator
+ .check_if_valid(&key_package.leaf_node, ValidationContext::Add(commit_time))
+ },
+ || {
+ validate_key_package_properties(
+ key_package,
+ self.protocol_version,
+ self.cipher_suite_provider,
+ )
+ },
+ );
+ a?;
+ b
+ }
+}
+
+/// Validates the PSK proposals in `proposals`.
+///
+/// A PSK proposal is rejected when (checked in this order) its key ID is neither an external
+/// PSK nor an application-usage resumption PSK, its nonce length differs from the cipher
+/// suite's `kdf_extract_size`, its PSK ID already appeared earlier in the bundle, or it names
+/// an external PSK that `psk_storage` does not contain.
+///
+/// Without `by_ref_proposal`, the first rejection is returned as an error. With
+/// `by_ref_proposal`, `strategy` decides (via `apply_strategy`) whether a rejected proposal is
+/// removed from the bundle or its error is returned.
+#[cfg(feature = "psk")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn filter_out_invalid_psks<P, CP>(
+ #[cfg(feature = "by_ref_proposal")] strategy: FilterStrategy,
+ cipher_suite_provider: &CP,
+ #[cfg(not(feature = "by_ref_proposal"))] proposals: &ProposalBundle,
+ #[cfg(feature = "by_ref_proposal")] proposals: &mut ProposalBundle,
+ psk_storage: &P,
+) -> Result<(), MlsError>
+where
+ P: PreSharedKeyStorage,
+ CP: CipherSuiteProvider,
+{
+ let kdf_extract_size = cipher_suite_provider.kdf_extract_size();
+
+ #[cfg(feature = "std")]
+ let mut ids_seen = HashSet::new();
+
+ #[cfg(not(feature = "std"))]
+ let mut ids_seen = Vec::new();
+
+ #[cfg(feature = "by_ref_proposal")]
+ let mut bad_indices = Vec::new();
+
+ for i in 0..proposals.psk_proposals().len() {
+ let p = &proposals.psks[i];
+
+ // Only external PSKs and resumption PSKs with application usage are accepted.
+ let valid = matches!(
+ p.proposal.psk.key_id,
+ JustPreSharedKeyID::External(_)
+ | JustPreSharedKeyID::Resumption(ResumptionPsk {
+ usage: ResumptionPSKUsage::Application,
+ ..
+ })
+ );
+
+ let nonce_length = p.proposal.psk.psk_nonce.0.len();
+ let nonce_valid = nonce_length == kdf_extract_size;
+
+ // NOTE(review): every visited PSK ID is recorded as seen (std: `insert` here; no-std:
+ // `push` at the end of the loop), including IDs of by-ref proposals that end up in
+ // `bad_indices` and are dropped. A later proposal reusing such an ID is then reported as
+ // `DuplicatePskIds` even though the dropped one is not committed — confirm intended.
+ #[cfg(feature = "std")]
+ let is_new_id = ids_seen.insert(p.proposal.psk.clone());
+
+ #[cfg(not(feature = "std"))]
+ let is_new_id = !ids_seen.contains(&p.proposal.psk);
+
+ // The storage lookup runs even if an earlier check failed; its result only matters
+ // when all earlier checks pass.
+ let external_id_is_valid = match &p.proposal.psk.key_id {
+ JustPreSharedKeyID::External(id) => psk_storage
+ .contains(id)
+ .await
+ .map_err(|e| MlsError::PskStoreError(e.into_any_error()))
+ .and_then(|found| {
+ if found {
+ Ok(())
+ } else {
+ Err(MlsError::MissingRequiredPsk)
+ }
+ }),
+ JustPreSharedKeyID::Resumption(_) => Ok(()),
+ };
+
+ #[cfg(not(feature = "by_ref_proposal"))]
+ if !valid {
+ return Err(MlsError::InvalidTypeOrUsageInPreSharedKeyProposal);
+ } else if !nonce_valid {
+ return Err(MlsError::InvalidPskNonceLength);
+ } else if !is_new_id {
+ return Err(MlsError::DuplicatePskIds);
+ } else if external_id_is_valid.is_err() {
+ return external_id_is_valid;
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ {
+ let res = if !valid {
+ Err(MlsError::InvalidTypeOrUsageInPreSharedKeyProposal)
+ } else if !nonce_valid {
+ Err(MlsError::InvalidPskNonceLength)
+ } else if !is_new_id {
+ Err(MlsError::DuplicatePskIds)
+ } else {
+ external_id_is_valid
+ };
+
+ if !apply_strategy(strategy, p.is_by_reference(), res)? {
+ bad_indices.push(i)
+ }
+ }
+
+ #[cfg(not(feature = "std"))]
+ ids_seen.push(p.proposal.psk.clone());
+ }
+
+ // Remove in reverse so earlier indices stay valid.
+ #[cfg(feature = "by_ref_proposal")]
+ bad_indices
+ .into_iter()
+ .rev()
+ .for_each(|i| proposals.remove::<PreSharedKeyProposal>(i));
+
+ Ok(())
+}
+
+/// Variant used when the `psk` feature is disabled: PSK proposals are not validated and the
+/// function always succeeds without touching its arguments.
+#[cfg(not(feature = "psk"))]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn filter_out_invalid_psks<P, CP>(
+    #[cfg(feature = "by_ref_proposal")] _strategy: FilterStrategy,
+    _cipher_suite_provider: &CP,
+    #[cfg(not(feature = "by_ref_proposal"))] _proposals: &ProposalBundle,
+    #[cfg(feature = "by_ref_proposal")] _proposals: &mut ProposalBundle,
+    _psk_storage: &P,
+) -> Result<(), MlsError>
+where
+    P: PreSharedKeyStorage,
+    CP: CipherSuiteProvider,
+{
+    Ok(())
+}
+
+/// Fails unless `proposals` contains exactly one ExternalInit proposal.
+fn ensure_exactly_one_external_init(proposals: &ProposalBundle) -> Result<(), MlsError> {
+    let external_inits = proposals.by_type::<ExternalInit>().count();
+
+    if external_inits == 1 {
+        Ok(())
+    } else {
+        Err(MlsError::ExternalCommitMustHaveExactlyOneExternalInit)
+    }
+}
+
+/// Rejects an external commit containing a default proposal type other than ExternalInit,
+/// Remove or PreSharedKey.
+///
+/// Non-default (custom) proposal types are not rejected here; custom `MlsRules` may still
+/// disallow specific custom proposals in external commits.
+fn ensure_proposals_in_external_commit_are_allowed(
+    proposals: &ProposalBundle,
+) -> Result<(), MlsError> {
+    let allowed_default_types = [
+        ProposalType::EXTERNAL_INIT,
+        ProposalType::REMOVE,
+        ProposalType::PSK,
+    ];
+
+    let forbidden = proposals
+        .proposal_types()
+        .find(|ty| ProposalType::DEFAULT.contains(ty) && !allowed_default_types.contains(ty));
+
+    if let Some(kind) = forbidden {
+        return Err(MlsError::InvalidProposalTypeInExternalCommit(kind));
+    }
+
+    Ok(())
+}
+
+/// Allows at most one Remove proposal in an external commit; a single Remove must target a
+/// leaf whose identity the joining `external_leaf` validly succeeds.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn ensure_at_most_one_removal_for_self<C>(
+    proposals: &ProposalBundle,
+    external_leaf: &LeafNode,
+    tree: &TreeKemPublic,
+    identity_provider: &C,
+    extensions: &ExtensionList,
+) -> Result<(), MlsError>
+where
+    C: IdentityProvider,
+{
+    let mut removals = proposals.by_type::<RemoveProposal>();
+    let first = removals.next();
+    let second = removals.next();
+
+    match (first, second) {
+        (None, _) => Ok(()),
+        (Some(_), Some(_)) => Err(MlsError::ExternalCommitWithMoreThanOneRemove),
+        (Some(only), None) => {
+            ensure_removal_is_for_self(
+                &only.proposal,
+                external_leaf,
+                tree,
+                identity_provider,
+                extensions,
+            )
+            .await
+        }
+    }
+}
+
+/// Checks that `removal` targets a leaf whose signing identity the identity provider accepts
+/// as validly succeeded by `external_leaf`'s signing identity.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn ensure_removal_is_for_self<C>(
+    removal: &RemoveProposal,
+    external_leaf: &LeafNode,
+    tree: &TreeKemPublic,
+    identity_provider: &C,
+    extensions: &ExtensionList,
+) -> Result<(), MlsError>
+where
+    C: IdentityProvider,
+{
+    let removed_leaf = tree.get_leaf_node(removal.to_remove)?;
+
+    let is_successor = identity_provider
+        .valid_successor(
+            &removed_leaf.signing_identity,
+            &external_leaf.signing_identity,
+            extensions,
+        )
+        .await
+        .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+
+    if is_successor {
+        Ok(())
+    } else {
+        Err(MlsError::ExternalCommitRemovesOtherIdentity)
+    }
+}
+
+/// Fails if any proposal of a default type was supplied by reference.
+///
+/// Non-default (custom) proposal types may be by reference here; custom `MlsRules` may still
+/// disallow specific custom by-ref proposals.
+fn ensure_no_proposal_by_ref(proposals: &ProposalBundle) -> Result<(), MlsError> {
+    let has_default_by_ref = proposals.iter_proposals().any(|info| {
+        ProposalType::DEFAULT.contains(&info.proposal.proposal_type()) && !info.is_by_value()
+    });
+
+    if has_default_by_ref {
+        Err(MlsError::OnlyMembersCanCommitProposalsByRef)
+    } else {
+        Ok(())
+    }
+}
+
+/// Adds the joiner's `leaf_node` to `tree` and returns the leaf index it was placed at.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn insert_external_leaf<I: IdentityProvider>(
+    tree: &mut TreeKemPublic,
+    leaf_node: LeafNode,
+    identity_provider: &I,
+    extensions: &ExtensionList,
+) -> Result<LeafIndex, MlsError> {
+    let new_index = tree
+        .add_leaf(leaf_node, identity_provider, extensions, None)
+        .await?;
+
+    Ok(new_index)
+}
diff --git a/src/group/proposal_filter/filtering_lite.rs b/src/group/proposal_filter/filtering_lite.rs
new file mode 100644
index 0000000..09ca389
--- /dev/null
+++ b/src/group/proposal_filter/filtering_lite.rs
@@ -0,0 +1,225 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::MlsError,
+ group::proposal_filter::ProposalBundle,
+ iter::wrap_iter,
+ protocol_version::ProtocolVersion,
+ time::MlsTime,
+ tree_kem::{leaf_node_validator::LeafNodeValidator, node::LeafIndex},
+ CipherSuiteProvider, ExtensionList,
+};
+
+use super::filtering_common::{filter_out_invalid_psks, ApplyProposalsOutput, ProposalApplier};
+
+#[cfg(feature = "by_ref_proposal")]
+use {crate::extension::ExternalSendersExt, mls_rs_core::error::IntoAnyError};
+
+use mls_rs_core::{identity::IdentityProvider, psk::PreSharedKeyStorage};
+
+#[cfg(feature = "custom_proposal")]
+use itertools::Itertools;
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+use rayon::prelude::*;
+
+#[cfg(mls_build_async)]
+use futures::{StreamExt, TryStreamExt};
+
+#[cfg(feature = "custom_proposal")]
+use crate::tree_kem::TreeKemPublic;
+
+#[cfg(feature = "psk")]
+use crate::group::{
+ proposal::PreSharedKeyProposal, JustPreSharedKeyID, ResumptionPSKUsage, ResumptionPsk,
+};
+
+#[cfg(all(feature = "std", feature = "psk"))]
+use std::collections::HashSet;
+
+// Validation and application of committed proposal bundles. Every check in this
+// impl only reads the bundle and fails the whole commit with an `MlsError`;
+// nothing is removed from the bundle.
+impl<'a, C, P, CSP> ProposalApplier<'a, C, P, CSP>
+where
+    C: IdentityProvider,
+    P: PreSharedKeyStorage,
+    CSP: CipherSuiteProvider,
+{
+    /// Validates a bundle committed by the member at `commit_sender`, then
+    /// applies it via `apply_proposal_changes`.
+    ///
+    /// Checks, in order: the committer does not remove itself, PSKs are valid,
+    /// (with `by_ref_proposal`) external senders in a new GroupContextExtensions
+    /// proposal verify, at most one GroupContextExtensions proposal, ReInit
+    /// version is acceptable, and ReInit is not combined with other proposals.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(super) async fn apply_proposals_from_member(
+        &self,
+        commit_sender: LeafIndex,
+        proposals: &ProposalBundle,
+        commit_time: Option<MlsTime>,
+    ) -> Result<ApplyProposalsOutput, MlsError> {
+        filter_out_removal_of_committer(commit_sender, proposals)?;
+        filter_out_invalid_psks(self.cipher_suite_provider, proposals, self.psk_storage).await?;
+
+        #[cfg(feature = "by_ref_proposal")]
+        filter_out_invalid_group_extensions(proposals, self.identity_provider, commit_time).await?;
+
+        filter_out_extra_group_context_extensions(proposals)?;
+        filter_out_invalid_reinit(proposals, self.protocol_version)?;
+        filter_out_reinit_if_other_proposals(proposals)?;
+
+        self.apply_proposal_changes(proposals, commit_time).await
+    }
+
+    /// Applies an already-validated bundle.
+    ///
+    /// If the bundle has a GroupContextExtensions proposal, work is delegated to
+    /// `apply_proposals_with_new_capabilities` (defined outside this file);
+    /// otherwise tree changes are validated against the group's current
+    /// extensions (`original_group_extensions`).
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(super) async fn apply_proposal_changes(
+        &self,
+        proposals: &ProposalBundle,
+        commit_time: Option<MlsTime>,
+    ) -> Result<ApplyProposalsOutput, MlsError> {
+        match proposals.group_context_extensions_proposal().cloned() {
+            Some(p) => {
+                self.apply_proposals_with_new_capabilities(proposals, p, commit_time)
+                    .await
+            }
+            None => {
+                self.apply_tree_changes(proposals, self.original_group_extensions, commit_time)
+                    .await
+            }
+        }
+    }
+
+    /// Validates every new leaf node, then applies the bundle's tree edits to a
+    /// clone of the original tree with `batch_edit_lite`.
+    ///
+    /// Returns the new tree, the indexes reported for added key packages, and
+    /// the extensions of the first GroupContextExtensions proposal (if any).
+    /// `external_init_index` is always `None` here.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(super) async fn apply_tree_changes(
+        &self,
+        proposals: &ProposalBundle,
+        group_extensions_in_use: &ExtensionList,
+        commit_time: Option<MlsTime>,
+    ) -> Result<ApplyProposalsOutput, MlsError> {
+        self.validate_new_nodes(proposals, group_extensions_in_use, commit_time)
+            .await?;
+
+        // The original tree is left untouched; edits go to a copy.
+        let mut new_tree = self.original_tree.clone();
+
+        let added = new_tree
+            .batch_edit_lite(
+                proposals,
+                group_extensions_in_use,
+                self.identity_provider,
+                self.cipher_suite_provider,
+            )
+            .await?;
+
+        let new_context_extensions = proposals
+            .group_context_extensions
+            .first()
+            .map(|gce| gce.proposal.clone());
+
+        Ok(ApplyProposalsOutput {
+            new_tree,
+            indexes_of_added_kpkgs: added,
+            external_init_index: None,
+            new_context_extensions,
+        })
+    }
+
+    /// Validates the key package of every Add proposal with a
+    /// `LeafNodeValidator` configured for `group_extensions_in_use` and
+    /// `commit_time`; returns an error if any of them is invalid.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn validate_new_nodes(
+        &self,
+        proposals: &ProposalBundle,
+        group_extensions_in_use: &ExtensionList,
+        commit_time: Option<MlsTime>,
+    ) -> Result<(), MlsError> {
+        let leaf_node_validator = &LeafNodeValidator::new(
+            self.cipher_suite_provider,
+            self.identity_provider,
+            Some(group_extensions_in_use),
+        );
+
+        // `wrap_iter` yields a sequential iterator, a rayon parallel iterator or a
+        // stream depending on build features (see the cfg'd imports above).
+        let adds = wrap_iter(proposals.add_proposals());
+
+        // Async builds: turn items into `Ok(..)` so the stream's `try_for_each` applies.
+        #[cfg(mls_build_async)]
+        let adds = adds.map(Ok);
+
+        // NOTE(review): the `{ adds }` block expression looks like it exists only to
+        // use whichever `adds` binding is active under the current cfg in one
+        // expression; confirm before simplifying.
+        { adds }
+            .try_for_each(|p| {
+                self.validate_new_node(leaf_node_validator, &p.proposal.key_package, commit_time)
+            })
+            .await
+    }
+}
+
+/// Fails with `MlsError::CommitterSelfRemoval` if any Remove proposal in the
+/// bundle targets the committer's own leaf (`commit_sender`).
+fn filter_out_removal_of_committer(
+    commit_sender: LeafIndex,
+    proposals: &ProposalBundle,
+) -> Result<(), MlsError> {
+    let removes_committer = proposals
+        .removals
+        .iter()
+        .any(|p| p.proposal.to_remove == commit_sender);
+
+    if removes_committer {
+        Err(MlsError::CommitterSelfRemoval)
+    } else {
+        Ok(())
+    }
+}
+
+/// If the bundle's first GroupContextExtensions proposal carries an
+/// `ExternalSendersExt`, verifies every external sender in it with
+/// `identity_provider` at `commit_time`.
+///
+/// Identity-provider failures are reported as `MlsError::IdentityProviderError`.
+#[cfg(feature = "by_ref_proposal")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn filter_out_invalid_group_extensions<C>(
+    proposals: &ProposalBundle,
+    identity_provider: &C,
+    commit_time: Option<MlsTime>,
+) -> Result<(), MlsError>
+where
+    C: IdentityProvider,
+{
+    let gce_proposal = match proposals.group_context_extensions.first() {
+        Some(p) => p,
+        None => return Ok(()),
+    };
+
+    if let Some(external_senders) = gce_proposal.proposal.get_as::<ExternalSendersExt>()? {
+        external_senders
+            .verify_all(identity_provider, commit_time, gce_proposal.proposal())
+            .await
+            .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+    }
+
+    Ok(())
+}
+
+/// Fails with `MlsError::MoreThanOneGroupContextExtensionsProposal` when the
+/// bundle contains two or more GroupContextExtensions proposals.
+fn filter_out_extra_group_context_extensions(proposals: &ProposalBundle) -> Result<(), MlsError> {
+    if proposals.group_context_extensions.len() > 1 {
+        Err(MlsError::MoreThanOneGroupContextExtensionsProposal)
+    } else {
+        Ok(())
+    }
+}
+
+/// Fails with `MlsError::InvalidProtocolVersionInReInit` when the first ReInit
+/// proposal asks for a protocol version older than `protocol_version`.
+fn filter_out_invalid_reinit(
+    proposals: &ProposalBundle,
+    protocol_version: ProtocolVersion,
+) -> Result<(), MlsError> {
+    match proposals.reinitializations.first() {
+        Some(p) if p.proposal.version < protocol_version => {
+            Err(MlsError::InvalidProtocolVersionInReInit)
+        }
+        _ => Ok(()),
+    }
+}
+
+/// Fails with `MlsError::OtherProposalWithReInit` when a ReInit proposal is
+/// committed together with any other proposal.
+fn filter_out_reinit_if_other_proposals(proposals: &ProposalBundle) -> Result<(), MlsError> {
+    let has_reinit = !proposals.reinitializations.is_empty();
+
+    if has_reinit && proposals.length() != 1 {
+        Err(MlsError::OtherProposalWithReInit)
+    } else {
+        Ok(())
+    }
+}
+
+/// Fails with `MlsError::UnsupportedCustomProposal` for the first custom
+/// proposal whose type `tree.can_support_proposal` rejects.
+#[cfg(feature = "custom_proposal")]
+pub(super) fn filter_out_unsupported_custom_proposals(
+    proposals: &ProposalBundle,
+    tree: &TreeKemPublic,
+) -> Result<(), MlsError> {
+    let supported_types = proposals
+        .custom_proposal_types()
+        .filter(|t| tree.can_support_proposal(*t))
+        .collect_vec();
+
+    let first_unsupported = proposals
+        .custom_proposals
+        .iter()
+        .map(|p| p.proposal.proposal_type())
+        .find(|proposal_type| !supported_types.contains(proposal_type));
+
+    match first_unsupported {
+        Some(proposal_type) => Err(MlsError::UnsupportedCustomProposal(proposal_type)),
+        None => Ok(()),
+    }
+}
diff --git a/src/group/proposal_ref.rs b/src/group/proposal_ref.rs
new file mode 100644
index 0000000..c97c9a1
--- /dev/null
+++ b/src/group/proposal_ref.rs
@@ -0,0 +1,226 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use core::ops::Deref;
+
+use super::*;
+use crate::hash_reference::HashReference;
+
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// Unique identifier for a proposal message.
+///
+/// Wraps the `HashReference` that `ProposalRef::from_content` computes over the
+/// encoded `AuthenticatedContent` of the proposal, using the label
+/// `"MLS 1.0 Proposal Reference"`. Dereferences to the raw reference bytes.
+pub struct ProposalRef(HashReference);
+
+/// Gives direct slice access to the raw bytes of the proposal reference.
+impl Deref for ProposalRef {
+    type Target = [u8];
+
+    fn deref(&self) -> &[u8] {
+        self.as_slice()
+    }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl ProposalRef {
+    /// Computes the reference of a proposal: the `HashReference` of the
+    /// MLS-encoded `content` under the label `"MLS 1.0 Proposal Reference"`.
+    ///
+    /// Fails if encoding `content` or hashing with `cipher_suite_provider` fails.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn from_content<CS: CipherSuiteProvider>(
+        cipher_suite_provider: &CS,
+        content: &AuthenticatedContent,
+    ) -> Result<Self, MlsError> {
+        let encoded_content = content.mls_encode_to_vec()?;
+
+        let hash_ref = HashReference::compute(
+            &encoded_content,
+            b"MLS 1.0 Proposal Reference",
+            cipher_suite_provider,
+        )
+        .await?;
+
+        Ok(Self(hash_ref))
+    }
+
+    /// Returns the raw bytes of this proposal reference.
+    pub fn as_slice(&self) -> &[u8] {
+        &self.0
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use super::*;
+    use crate::group::test_utils::{random_bytes, TEST_GROUP};
+    use alloc::boxed::Box;
+
+    impl ProposalRef {
+        /// Builds a `ProposalRef` directly from `bytes`, without hashing any
+        /// content. Test-only.
+        pub fn new_fake(bytes: Vec<u8>) -> Self {
+            Self(bytes.into())
+        }
+    }
+
+    /// Wraps `proposal` in a `PublicMessage` `AuthenticatedContent` for
+    /// `TEST_GROUP` at epoch 0, sent by `sender`, with empty authenticated data,
+    /// a random placeholder signature (`random_bytes(128)`, not a real
+    /// signature) and no confirmation tag.
+    pub fn auth_content_from_proposal<S>(proposal: Proposal, sender: S) -> AuthenticatedContent
+    where
+        S: Into<Sender>,
+    {
+        AuthenticatedContent {
+            wire_format: WireFormat::PublicMessage,
+            content: FramedContent {
+                group_id: TEST_GROUP.to_vec(),
+                epoch: 0,
+                sender: sender.into(),
+                authenticated_data: vec![],
+                content: Content::Proposal(Box::new(proposal)),
+            },
+            auth: FramedContentAuthData {
+                signature: MessageSignature::from(random_bytes(128)),
+                confirmation_tag: None,
+            },
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::test_utils::auth_content_from_proposal;
+    use super::*;
+    use crate::{
+        crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
+        key_package::test_utils::test_key_package,
+        tree_kem::leaf_node::test_utils::get_basic_test_node,
+    };
+    use alloc::boxed::Box;
+
+    use crate::extension::RequiredCapabilitiesExt;
+
+    /// Extension list containing only a `RequiredCapabilitiesExt` that requires
+    /// extension type 42; payload of the GroupContextExtensions test proposal.
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn get_test_extension_list() -> ExtensionList {
+        let test_extension = RequiredCapabilitiesExt {
+            extensions: vec![42.into()],
+            proposals: Default::default(),
+            credentials: vec![],
+        };
+
+        let mut extension_list = ExtensionList::new();
+        extension_list.set_from(test_extension).unwrap();
+
+        extension_list
+    }
+
+    /// One known-answer vector: an encoded `AuthenticatedContent` (`input`) and
+    /// the proposal reference expected for it (`output`) under `cipher_suite`.
+    #[derive(serde::Serialize, serde::Deserialize)]
+    struct TestCase {
+        cipher_suite: u16,
+        #[serde(with = "hex::serde")]
+        input: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        output: Vec<u8>,
+    }
+
+    /// For every (protocol version, cipher suite) pair, builds an Add, Update,
+    /// Remove and GroupContextExtensions proposal and records the reference
+    /// computed for each, in that order.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    async fn generate_proposal_test_cases() -> Vec<TestCase> {
+        let mut test_cases = Vec::new();
+
+        for (protocol_version, cipher_suite) in
+            ProtocolVersion::all().flat_map(|p| CipherSuite::all().map(move |cs| (p, cs)))
+        {
+            let sender = LeafIndex(0);
+
+            let add = auth_content_from_proposal(
+                Proposal::Add(Box::new(AddProposal {
+                    key_package: test_key_package(protocol_version, cipher_suite, "alice").await,
+                })),
+                sender,
+            );
+
+            let update = auth_content_from_proposal(
+                Proposal::Update(UpdateProposal {
+                    leaf_node: get_basic_test_node(cipher_suite, "foo").await,
+                }),
+                sender,
+            );
+
+            let remove = auth_content_from_proposal(
+                Proposal::Remove(RemoveProposal {
+                    to_remove: LeafIndex(1),
+                }),
+                sender,
+            );
+
+            let group_context_ext = auth_content_from_proposal(
+                Proposal::GroupContextExtensions(get_test_extension_list()),
+                sender,
+            );
+
+            let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+            // The array order fixes the order of the generated vectors and must
+            // stay in sync with any stored `proposal_ref` test data.
+            for content in [&add, &update, &remove, &group_context_ext] {
+                test_cases.push(TestCase {
+                    cipher_suite: cipher_suite.into(),
+                    input: content.mls_encode_to_vec().unwrap(),
+                    output: ProposalRef::from_content(&cipher_suite_provider, content)
+                        .await
+                        .unwrap()
+                        .to_vec(),
+                });
+            }
+        }
+
+        test_cases
+    }
+
+    #[cfg(mls_build_async)]
+    async fn load_test_cases() -> Vec<TestCase> {
+        load_test_case_json!(proposal_ref, generate_proposal_test_cases().await)
+    }
+
+    #[cfg(not(mls_build_async))]
+    fn load_test_cases() -> Vec<TestCase> {
+        load_test_case_json!(proposal_ref, generate_proposal_test_cases())
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_ref() {
+        let test_cases = load_test_cases().await;
+
+        for one_case in test_cases {
+            // Skip vectors for cipher suites this build has no provider for.
+            let Some(cs_provider) = try_test_cipher_suite_provider(one_case.cipher_suite) else {
+                continue;
+            };
+
+            let proposal_content =
+                AuthenticatedContent::mls_decode(&mut one_case.input.as_slice()).unwrap();
+
+            let proposal_ref = ProposalRef::from_content(&cs_provider, &proposal_content)
+                .await
+                .unwrap();
+
+            let expected_out = ProposalRef(HashReference::from(one_case.output));
+
+            assert_eq!(expected_out, proposal_ref);
+        }
+    }
+}
diff --git a/src/group/resumption.rs b/src/group/resumption.rs
new file mode 100644
index 0000000..3478ef3
--- /dev/null
+++ b/src/group/resumption.rs
@@ -0,0 +1,299 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+
+use mls_rs_core::{
+ crypto::{CipherSuite, SignatureSecretKey},
+ extension::ExtensionList,
+ identity::SigningIdentity,
+ protocol_version::ProtocolVersion,
+};
+
+use crate::{client::MlsError, Client, Group, MlsMessage};
+
+use super::{
+ proposal::ReInitProposal, ClientConfig, ExportedTree, JustPreSharedKeyID, MessageProcessor,
+ NewMemberInfo, PreSharedKeyID, PskGroupId, PskSecretInput, ResumptionPSKUsage, ResumptionPsk,
+};
+
+/// Parameters of a group created or joined through resumption (branch or
+/// reinit). Used both to build the new group and to check a joined one.
+struct ResumptionGroupParameters<'a> {
+    group_id: &'a [u8],
+    cipher_suite: CipherSuite,
+    version: ProtocolVersion,
+    extensions: &'a ExtensionList,
+}
+
+/// A [`Client`] that can be used to create or join a new group
+/// that is based on properties defined by a [`ReInitProposal`]
+/// committed in a previously accepted commit.
+///
+/// Obtained from [`Group::get_reinit_client`].
+pub struct ReinitClient<C: ClientConfig + Clone> {
+    // Client set up with the new signer, signing identity, cipher suite and version.
+    client: Client<C>,
+    // The committed ReInit proposal describing the new group.
+    reinit: ReInitProposal,
+    // Reinit resumption PSK taken from the old group's current epoch.
+    psk_input: PskSecretInput,
+}
+
+impl<C> Group<C>
+where
+    C: ClientConfig + Clone,
+{
+    /// Create a sub-group from a subset of the current group members.
+    ///
+    /// Membership within the resulting sub-group is indicated by providing a
+    /// key package that produces the same
+    /// [identity](crate::IdentityProvider::identity) value
+    /// as an existing group member. The identity value of each key package
+    /// is determined using the
+    /// [`IdentityProvider`](crate::IdentityProvider)
+    /// that is currently in use by this group instance.
+    ///
+    /// The sub-group reuses this group's cipher suite, protocol version and
+    /// group context extensions. Its first commit injects a branch resumption
+    /// PSK from the current epoch. Returns the new group and the welcome
+    /// messages for the added members.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn branch(
+        &self,
+        sub_group_id: Vec<u8>,
+        new_key_packages: Vec<MlsMessage>,
+    ) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
+        let new_group_params = ResumptionGroupParameters {
+            group_id: &sub_group_id,
+            cipher_suite: self.cipher_suite(),
+            version: self.protocol_version(),
+            extensions: &self.group_state().context.extensions,
+        };
+
+        // The PSK input is passed unconditionally (as in `join_subgroup`):
+        // `resumption_create_group` always takes it, so gating the argument
+        // behind a feature flag could only break the call's arity.
+        resumption_create_group(
+            self.config.clone(),
+            new_key_packages,
+            &new_group_params,
+            // TODO investigate if it's worth updating your own signing identity here
+            self.current_member_signing_identity()?.clone(),
+            self.signer.clone(),
+            self.resumption_psk_input(ResumptionPSKUsage::Branch)?,
+        )
+        .await
+    }
+
+    /// Join a subgroup that was created by [`Group::branch`].
+    ///
+    /// The welcome is processed with a branch resumption PSK from this group's
+    /// current epoch. The joined group must have the same protocol version,
+    /// cipher suite and group context extensions as this group; its group id
+    /// is not checked.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn join_subgroup(
+        &self,
+        welcome: &MlsMessage,
+        tree_data: Option<ExportedTree<'_>>,
+    ) -> Result<(Group<C>, NewMemberInfo), MlsError> {
+        // `group_id` is a placeholder: group id verification is disabled below.
+        let expected_new_group_params = ResumptionGroupParameters {
+            group_id: &[],
+            cipher_suite: self.cipher_suite(),
+            version: self.protocol_version(),
+            extensions: &self.group_state().context.extensions,
+        };
+
+        resumption_join_group(
+            self.config.clone(),
+            self.signer.clone(),
+            welcome,
+            tree_data,
+            expected_new_group_params,
+            false,
+            self.resumption_psk_input(ResumptionPSKUsage::Branch)?,
+        )
+        .await
+    }
+
+    /// Generate a [`ReinitClient`] that can be used to create or join a new group
+    /// that is based on properties defined by a [`ReInitProposal`]
+    /// committed in a previously accepted commit. This is the only action available
+    /// after accepting such a commit. The old group can no longer be used according to the RFC.
+    ///
+    /// If the [`ReInitProposal`] changes the ciphersuite, then `new_signer`
+    /// and `new_signing_identity` must be set and match the new ciphersuite, as indicated by
+    /// [`pending_reinit_ciphersuite`](crate::group::StateUpdate::pending_reinit_ciphersuite)
+    /// of the [`StateUpdate`](crate::group::StateUpdate) outputted after processing the
+    /// commit to the reinit proposal. The value of [identity](crate::IdentityProvider::identity)
+    /// must be the same for `new_signing_identity` and the current identity in use by this
+    /// group instance.
+    ///
+    /// When `new_signer` / `new_signing_identity` is `None`, this group's current
+    /// signer / signing identity is reused.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MlsError::PendingReInitNotFound`] if no reinit is pending, or
+    /// an error if the resumption PSK or the current signing identity cannot be
+    /// obtained.
+    pub fn get_reinit_client(
+        self,
+        new_signer: Option<SignatureSecretKey>,
+        new_signing_identity: Option<SigningIdentity>,
+    ) -> Result<ReinitClient<C>, MlsError> {
+        let psk_input = self.resumption_psk_input(ResumptionPSKUsage::Reinit)?;
+
+        let new_signing_identity = match new_signing_identity {
+            Some(identity) => identity,
+            None => self.current_member_signing_identity()?.clone(),
+        };
+
+        let reinit = self
+            .state
+            .pending_reinit
+            .ok_or(MlsError::PendingReInitNotFound)?;
+
+        let new_signer = match new_signer {
+            Some(signer) => signer,
+            None => self.signer,
+        };
+
+        // `ReinitClient::commit` / `join` rely on both of these being `Some`.
+        let client = Client::new(
+            self.config,
+            Some(new_signer),
+            Some((new_signing_identity, reinit.new_cipher_suite())),
+            reinit.new_version(),
+        );
+
+        Ok(ReinitClient {
+            client,
+            reinit,
+            psk_input,
+        })
+    }
+
+    /// Builds the resumption PSK input for `usage`. The PSK id names this group
+    /// and its current epoch; the secret is the current epoch's resumption
+    /// secret.
+    fn resumption_psk_input(&self, usage: ResumptionPSKUsage) -> Result<PskSecretInput, MlsError> {
+        let psk = self.epoch_secrets.resumption_secret.clone();
+
+        let id = JustPreSharedKeyID::Resumption(ResumptionPsk {
+            usage,
+            psk_group_id: PskGroupId(self.group_id().to_vec()),
+            psk_epoch: self.current_epoch(),
+        });
+
+        let id = PreSharedKeyID::new(id, self.cipher_suite_provider())?;
+        Ok(PskSecretInput { id, psk })
+    }
+}
+
+/// A [`Client`] that can be used to create or join a new group
+/// that is based on properties defined by a [`ReInitProposal`]
+/// committed in a previously accepted commit.
+impl<C: ClientConfig + Clone> ReinitClient<C> {
+    /// Generate a key package for the new group. The key package can
+    /// be used in [`ReinitClient::commit`].
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn generate_key_package(&self) -> Result<MlsMessage, MlsError> {
+        self.client.generate_key_package_message().await
+    }
+
+    /// Create the new group using new key packages of all group members, possibly
+    /// generated by [`ReinitClient::generate_key_package`].
+    ///
+    /// The new group takes its group id, cipher suite, protocol version and group
+    /// context extensions from the [`ReInitProposal`]. Its first commit injects
+    /// the reinit resumption PSK of the old group.
+    ///
+    /// # Warning
+    ///
+    /// This function will fail if the number of members in the reinitialized
+    /// group is not the same as the prior group roster.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn commit(
+        self,
+        new_key_packages: Vec<MlsMessage>,
+    ) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
+        let new_group_params = ResumptionGroupParameters {
+            group_id: self.reinit.group_id(),
+            cipher_suite: self.reinit.new_cipher_suite(),
+            version: self.reinit.new_version(),
+            extensions: self.reinit.new_group_context_extensions(),
+        };
+
+        // `Group::get_reinit_client` builds the inner client with `Some(..)` for
+        // both the signing identity and the signer. The PSK input is passed
+        // unconditionally because `resumption_create_group` always takes it.
+        resumption_create_group(
+            self.client.config.clone(),
+            new_key_packages,
+            &new_group_params,
+            self.client
+                .signing_identity
+                .expect("ReinitClient is always built with a signing identity")
+                .0,
+            self.client
+                .signer
+                .expect("ReinitClient is always built with a signer"),
+            self.psk_input,
+        )
+        .await
+    }
+
+    /// Join a reinitialized group that was created by [`ReinitClient::commit`].
+    ///
+    /// The joined group must match the [`ReInitProposal`]'s group id, cipher
+    /// suite, protocol version and group context extensions.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn join(
+        self,
+        welcome: &MlsMessage,
+        tree_data: Option<ExportedTree<'_>>,
+    ) -> Result<(Group<C>, NewMemberInfo), MlsError> {
+        let reinit = self.reinit;
+
+        let expected_group_params = ResumptionGroupParameters {
+            group_id: reinit.group_id(),
+            cipher_suite: reinit.new_cipher_suite(),
+            version: reinit.new_version(),
+            extensions: reinit.new_group_context_extensions(),
+        };
+
+        resumption_join_group(
+            self.client.config,
+            // `Group::get_reinit_client` builds the inner client with `Some(signer)`.
+            self.client
+                .signer
+                .expect("ReinitClient is always built with a signer"),
+            welcome,
+            tree_data,
+            expected_group_params,
+            true,
+            self.psk_input,
+        )
+        .await
+    }
+}
+
+/// Creates a new group described by `new_group_params` and adds every key
+/// package in `new_key_packages` with one commit.
+///
+/// The commit uses the resumption PSK `psk_input`. Returns the new group
+/// together with the commit's welcome messages.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn resumption_create_group<C: ClientConfig + Clone>(
+    config: C,
+    new_key_packages: Vec<MlsMessage>,
+    new_group_params: &ResumptionGroupParameters<'_>,
+    signing_identity: SigningIdentity,
+    signer: SignatureSecretKey,
+    psk_input: PskSecretInput,
+) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
+    let ResumptionGroupParameters {
+        group_id,
+        cipher_suite,
+        version,
+        extensions,
+    } = *new_group_params;
+
+    let mut group = Group::new(
+        config,
+        Some(group_id.to_vec()),
+        cipher_suite,
+        version,
+        signing_identity,
+        extensions.clone(),
+        signer,
+    )
+    .await?;
+
+    // The resumption PSK is installed only for the duration of the first commit.
+    group.previous_psk = Some(psk_input);
+
+    let commit = new_key_packages
+        .into_iter()
+        .try_fold(group.commit_builder(), |builder, kp| builder.add_member(kp))?
+        .build()
+        .await?;
+
+    group.apply_pending_commit().await?;
+
+    // Only reached on success; on failure the new group is discarded anyway.
+    group.previous_psk = None;
+
+    Ok((group, commit.welcome_messages))
+}
+
+/// Joins a group from `welcome` using the resumption PSK `psk_input`, then checks
+/// that the joined group matches `expected_new_group_params`.
+///
+/// The checks cover protocol version, cipher suite, the group id (only when
+/// `verify_group_id` is set) and group context extensions, in that order.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn resumption_join_group<C: ClientConfig + Clone>(
+    config: C,
+    signer: SignatureSecretKey,
+    welcome: &MlsMessage,
+    tree_data: Option<ExportedTree<'_>>,
+    expected_new_group_params: ResumptionGroupParameters<'_>,
+    verify_group_id: bool,
+    psk_input: PskSecretInput,
+) -> Result<(Group<C>, NewMemberInfo), MlsError> {
+    let (group, new_member_info) =
+        Group::<C>::from_welcome_message(welcome, tree_data, config, signer, Some(psk_input))
+            .await?;
+
+    let expected = &expected_new_group_params;
+
+    if group.protocol_version() != expected.version {
+        return Err(MlsError::ProtocolVersionMismatch);
+    }
+
+    if group.cipher_suite() != expected.cipher_suite {
+        return Err(MlsError::CipherSuiteMismatch);
+    }
+
+    if verify_group_id && group.group_id() != expected.group_id {
+        return Err(MlsError::GroupIdMismatch);
+    }
+
+    if &group.group_state().context.extensions != expected.extensions {
+        return Err(MlsError::ReInitExtensionsMismatch);
+    }
+
+    Ok((group, new_member_info))
+}
diff --git a/src/group/roster.rs b/src/group/roster.rs
new file mode 100644
index 0000000..dd0a9f0
--- /dev/null
+++ b/src/group/roster.rs
@@ -0,0 +1,91 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::*;
+
+pub use mls_rs_core::group::Member;
+
+/// Builds the public [`Member`] view of `key_package`'s leaf node placed at `index`.
+#[cfg(feature = "state_update")]
+pub(crate) fn member_from_key_package(key_package: &KeyPackage, index: LeafIndex) -> Member {
+    let leaf_node = &key_package.leaf_node;
+    member_from_leaf_node(leaf_node, index)
+}
+
+/// Builds the public [`Member`] view of `leaf_node` at `leaf_index`.
+///
+/// Capabilities and extensions are reported without GREASE values.
+pub(crate) fn member_from_leaf_node(leaf_node: &LeafNode, leaf_index: LeafIndex) -> Member {
+    let signing_identity = leaf_node.signing_identity.clone();
+    let capabilities = leaf_node.ungreased_capabilities();
+    let extensions = leaf_node.ungreased_extensions();
+
+    Member::new(*leaf_index, signing_identity, capabilities, extensions)
+}
+
+/// Read-only view of a group's membership, backed by a shared borrow of the
+/// group's public ratchet tree (`TreeKemPublic`).
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Debug)]
+pub struct Roster<'a> {
+    pub(crate) public_tree: &'a TreeKemPublic,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl<'a> Roster<'a> {
+    /// Iterator over the current roster that lazily copies data out of the
+    /// internal group state.
+    ///
+    /// # Warning
+    ///
+    /// The indexes within this iterator do not correlate with indexes of users
+    /// within [`ReceivedMessage`] content descriptions due to the layout of
+    /// member information within an MLS group state.
+    #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn members_iter(&self) -> impl Iterator<Item = Member> + 'a {
+        self.public_tree
+            .non_empty_leaves()
+            .map(|(index, node)| member_from_leaf_node(node, index))
+    }
+
+    /// The current set of group members. This function makes a clone of
+    /// member information from the internal group state.
+    ///
+    /// # Warning
+    ///
+    /// The indexes within this roster do not correlate with indexes of users
+    /// within [`ReceivedMessage`] content descriptions due to the layout of
+    /// member information within an MLS group state.
+    pub fn members(&self) -> Vec<Member> {
+        self.members_iter().collect()
+    }
+
+    /// Retrieve the member with given `index` within the group in time `O(1)`.
+    /// This index does correlate with indexes of users within [`ReceivedMessage`]
+    /// content descriptions.
+    ///
+    /// Returns an error if there is no leaf node at `index`.
+    pub fn member_with_index(&self, index: u32) -> Result<Member, MlsError> {
+        let index = LeafIndex(index);
+
+        self.public_tree
+            .get_leaf_node(index)
+            .map(|l| member_from_leaf_node(l, index))
+    }
+
+    /// Iterator over the members' signing identities.
+    ///
+    /// Like [`Roster::members_iter`], the returned iterator and references
+    /// borrow the underlying group state for `'a` rather than this `Roster`
+    /// value, so they may outlive the `Roster` itself.
+    ///
+    /// # Warning
+    ///
+    /// The indexes within this iterator do not correlate with indexes of users
+    /// within [`ReceivedMessage`] content descriptions due to the layout of
+    /// member information within an MLS group state.
+    #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn member_identities_iter(&self) -> impl Iterator<Item = &'a SigningIdentity> + 'a {
+        self.public_tree
+            .non_empty_leaves()
+            .map(|(_, node)| &node.signing_identity)
+    }
+}
+
+impl TreeKemPublic {
+    /// Returns a [`Roster`] view over the members stored in this tree.
+    ///
+    /// The `'_` makes explicit that the roster borrows from `self`.
+    pub(crate) fn roster(&self) -> Roster<'_> {
+        Roster { public_tree: self }
+    }
+}
diff --git a/src/group/secret_tree.rs b/src/group/secret_tree.rs
new file mode 100644
index 0000000..df0c30f
--- /dev/null
+++ b/src/group/secret_tree.rs
@@ -0,0 +1,1115 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::{Deref, DerefMut},
+};
+
+use zeroize::Zeroizing;
+
+use crate::{client::MlsError, tree_kem::math::TreeIndex, CipherSuiteProvider};
+
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+
+#[cfg(feature = "std")]
+use std::collections::HashMap;
+
+#[cfg(not(feature = "std"))]
+use alloc::collections::BTreeMap;
+
+use super::key_schedule::kdf_expand_with_label;
+
+/// NOTE(review): presumably the maximum number of past generations a secret key
+/// ratchet can go back to; its uses are outside this view — confirm.
+pub(crate) const MAX_RATCHET_BACK_HISTORY: u32 = 1024;
+
+/// A stored node of the secret tree: either a tree secret that has not been
+/// expanded yet or, for a leaf that has been used, that leaf's ratchets.
+///
+/// NOTE(review): with the MLS codec derives, the explicit discriminants
+/// presumably act as encoding tags, so they should not be changed — confirm.
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+enum SecretTreeNode {
+    Secret(TreeSecret) = 0u8,
+    Ratchet(SecretRatchets) = 1u8,
+}
+
+impl SecretTreeNode {
+    /// Consumes the node and returns its tree secret, or `None` if the node
+    /// holds ratchets instead.
+    fn into_secret(self) -> Option<TreeSecret> {
+        match self {
+            SecretTreeNode::Secret(secret) => Some(secret),
+            SecretTreeNode::Ratchet(_) => None,
+        }
+    }
+}
+
+/// Raw secret bytes of a secret-tree node, held in `Zeroizing` so the buffer is
+/// wiped when dropped. Encoded as a byte vector by the MLS codec.
+#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+struct TreeSecret(
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+    Zeroizing<Vec<u8>>,
+);
+
+/// Formats the secret through `mls_rs_core::debug::pretty_bytes`, labelled
+/// "TreeSecret".
+impl Debug for TreeSecret {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let pretty = mls_rs_core::debug::pretty_bytes(&self.0).named("TreeSecret");
+        pretty.fmt(f)
+    }
+}
+
+/// Read access to the underlying secret buffer.
+impl Deref for TreeSecret {
+    type Target = Vec<u8>;
+
+    fn deref(&self) -> &Vec<u8> {
+        self.0.deref()
+    }
+}
+
+/// Write access to the underlying secret buffer (still zeroized on drop).
+impl DerefMut for TreeSecret {
+    fn deref_mut(&mut self) -> &mut Vec<u8> {
+        self.0.deref_mut()
+    }
+}
+
+impl AsRef<[u8]> for TreeSecret {
+    fn as_ref(&self) -> &[u8] {
+        self.0.as_slice()
+    }
+}
+
+/// Takes ownership of `vec`, wrapping it so it is zeroized on drop.
+impl From<Vec<u8>> for TreeSecret {
+    fn from(vec: Vec<u8>) -> Self {
+        Zeroizing::new(vec).into()
+    }
+}
+
+impl From<Zeroizing<Vec<u8>>> for TreeSecret {
+    fn from(vec: Zeroizing<Vec<u8>>) -> Self {
+        Self(vec)
+    }
+}
+
+/// Sparse storage of the secret-tree nodes currently known, keyed by node index.
+///
+/// With the `std` feature this is a `HashMap`; otherwise a `Vec` of
+/// `(index, node)` pairs searched linearly.
+#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+struct TreeSecretsVec<T: TreeIndex> {
+    #[cfg(feature = "std")]
+    inner: HashMap<T, SecretTreeNode>,
+    #[cfg(not(feature = "std"))]
+    inner: Vec<(T, SecretTreeNode)>,
+}
+
+#[cfg(feature = "std")]
+impl<T: TreeIndex> TreeSecretsVec<T> {
+    /// Stores `value` at `index`, replacing any node already stored there.
+    fn set_node(&mut self, index: T, value: SecretTreeNode) {
+        self.inner.insert(index, value);
+    }
+
+    /// Removes and returns the node stored at `index`, if any.
+    fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> {
+        self.inner.remove(index)
+    }
+}
+
+#[cfg(not(feature = "std"))]
+impl<T: TreeIndex> TreeSecretsVec<T> {
+    /// Stores `value` at `index`, replacing the existing entry in place or
+    /// appending a new one.
+    fn set_node(&mut self, index: T, value: SecretTreeNode) {
+        match self.find_node(&index) {
+            Some(i) => self.inner[i] = (index, value),
+            None => self.inner.push((index, value)),
+        }
+    }
+
+    /// Removes and returns the node stored at `index`, if any.
+    ///
+    /// Uses `Vec::remove` (not `swap_remove`) so the remaining entries keep
+    /// their order, and with it the derived encoding of this type.
+    fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> {
+        self.find_node(index).map(|i| self.inner.remove(i).1)
+    }
+
+    /// Position in `inner` of the entry for `index` (linear scan).
+    fn find_node(&self, index: &T) -> Option<usize> {
+        self.inner.iter().position(|(i, _)| i == index)
+    }
+}
+
+/// Secret tree whose node secrets are derived lazily: only the nodes in
+/// `known_secrets` are stored, starting from the root secret given to
+/// `SecretTree::new`.
+///
+/// The field order is part of the derived MLS encoding of this type.
+#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct SecretTree<T: TreeIndex> {
+    known_secrets: TreeSecretsVec<T>,
+    leaf_count: T,
+}
+
+impl<T: TreeIndex> SecretTree<T> {
+    /// Returns a secret tree with zero leaves and no known secrets.
+    pub(crate) fn empty() -> SecretTree<T> {
+        let known_secrets = TreeSecretsVec::default();
+        let leaf_count = T::zero();
+
+        Self {
+            known_secrets,
+            leaf_count,
+        }
+    }
+}
+
+/// The two key ratchets of one leaf, one per [`KeyType`].
+///
+/// The field order is part of the derived MLS encoding of this type.
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct SecretRatchets {
+    pub application: SecretKeyRatchet,
+    pub handshake: SecretKeyRatchet,
+}
+
+impl SecretRatchets {
+    /// Returns the key for `generation` from the ratchet selected by `key_type`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn message_key_generation<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite_provider: &P,
+        generation: u32,
+        key_type: KeyType,
+    ) -> Result<MessageKeyData, MlsError> {
+        let selected = match key_type {
+            KeyType::Handshake => &mut self.handshake,
+            KeyType::Application => &mut self.application,
+        };
+
+        selected
+            .get_message_key(cipher_suite_provider, generation)
+            .await
+    }
+
+    /// Returns the next key from the ratchet selected by `key_type`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn next_message_key<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite: &P,
+        key_type: KeyType,
+    ) -> Result<MessageKeyData, MlsError> {
+        let selected = match key_type {
+            KeyType::Handshake => &mut self.handshake,
+            KeyType::Application => &mut self.application,
+        };
+
+        selected.next_message_key(cipher_suite).await
+    }
+}
+
+impl<T: TreeIndex> SecretTree<T> {
+    /// Creates a secret tree for `leaf_count` leaves whose only known node is
+    /// the root, holding `encryption_secret`. Other node secrets are derived
+    /// when a leaf is first used.
+    pub fn new(leaf_count: T, encryption_secret: Zeroizing<Vec<u8>>) -> SecretTree<T> {
+        let mut known_secrets = TreeSecretsVec::default();
+
+        let root_secret = SecretTreeNode::Secret(TreeSecret::from(encryption_secret));
+        known_secrets.set_node(leaf_count.root(), root_secret);
+
+        Self {
+            known_secrets,
+            leaf_count,
+        }
+    }
+
+    /// Removes the node at `index` and, if it held a tree secret, stores the
+    /// secrets of its two children in its place. The children's secrets are
+    /// derived with `kdf_expand_with_label(secret, "tree", "left" / "right")`.
+    ///
+    /// If `index` holds no secret, nothing is derived, and any non-secret node
+    /// stored there is dropped.
+    /// NOTE(review): callers presumably pass only parent (non-leaf) indexes;
+    /// otherwise a leaf's ratchets could be discarded here — confirm.
+    ///
+    /// # Errors
+    ///
+    /// `LeafNodeNoChildren` if `index` holds a secret but has no children.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn consume_node<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite_provider: &P,
+        index: &T,
+    ) -> Result<(), MlsError> {
+        let node = self.known_secrets.take_node(index);
+
+        if let Some(secret) = node.and_then(|n| n.into_secret()) {
+            let left_index = index.left().ok_or(MlsError::LeafNodeNoChildren)?;
+            let right_index = index.right().ok_or(MlsError::LeafNodeNoChildren)?;
+
+            let left_secret =
+                kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"left", None)
+                    .await?;
+
+            let right_secret =
+                kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"right", None)
+                    .await?;
+
+            self.known_secrets
+                .set_node(left_index, SecretTreeNode::Secret(left_secret.into()));
+
+            self.known_secrets
+                .set_node(right_index, SecretTreeNode::Secret(right_secret.into()));
+        }
+
+        Ok(())
+    }
+
+    /// Removes and returns the ratchets for `leaf_index`.
+    ///
+    /// If nothing is stored for the leaf yet, the nodes on its path are
+    /// consumed from the root downwards until the leaf's secret appears. A
+    /// stored secret is turned into fresh application and handshake ratchets.
+    /// The caller must store the ratchets back, or the leaf becomes unusable.
+    ///
+    /// # Errors
+    ///
+    /// `InvalidLeafConsumption` if no node can be found or derived for the leaf.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn take_leaf_ratchet<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite: &P,
+        leaf_index: &T,
+    ) -> Result<SecretRatchets, MlsError> {
+        let node_index = leaf_index;
+
+        let node = match self.known_secrets.take_node(node_index) {
+            Some(node) => node,
+            None => {
+                // Start at the root node and work your way down consuming any intermediates needed
+                for i in node_index.direct_copath(&self.leaf_count).into_iter().rev() {
+                    self.consume_node(cipher_suite, &i.path).await?;
+                }
+
+                self.known_secrets
+                    .take_node(node_index)
+                    .ok_or(MlsError::InvalidLeafConsumption)?
+            }
+        };
+
+        Ok(match node {
+            SecretTreeNode::Ratchet(ratchet) => ratchet,
+            SecretTreeNode::Secret(secret) => SecretRatchets {
+                application: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Application)
+                    .await?,
+                handshake: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Handshake).await?,
+            },
+        })
+    }
+
+    /// Returns the next key of type `key_type` for `leaf_index`, advancing that
+    /// leaf's ratchet.
+    ///
+    /// The leaf's ratchets are written back to the tree even when key
+    /// derivation fails, so an error never leaves the leaf without ratchets.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn next_message_key<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite: &P,
+        leaf_index: T,
+        key_type: KeyType,
+    ) -> Result<MessageKeyData, MlsError> {
+        let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
+        let res = ratchet.next_message_key(cipher_suite, key_type).await;
+
+        // Restore the ratchets before reporting any error: `take_leaf_ratchet`
+        // removed them, and the leaf's ancestors are already consumed, so they
+        // could not be derived again.
+        self.known_secrets
+            .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));
+
+        res
+    }
+
+    /// Returns the key of type `key_type` for `generation` of `leaf_index`.
+    ///
+    /// The leaf's ratchets are written back to the tree even when the lookup
+    /// fails (e.g. for a generation the ratchet cannot provide). Otherwise a
+    /// single failed or replayed request would make every later request for
+    /// that leaf fail with `InvalidLeafConsumption`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn message_key_generation<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite: &P,
+        leaf_index: T,
+        key_type: KeyType,
+        generation: u32,
+    ) -> Result<MessageKeyData, MlsError> {
+        let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
+
+        let res = ratchet
+            .message_key_generation(cipher_suite, generation, key_type)
+            .await;
+
+        self.known_secrets
+            .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));
+
+        res
+    }
+}
+
+/// Selects which of a leaf's two symmetric ratchets to use.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum KeyType {
+    /// The ratchet derived with the `"handshake"` label.
+    Handshake,
+    /// The ratchet derived with the `"application"` label.
+    Application,
+}
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// AEAD key derived by the MLS secret tree.
+///
+/// Bundles the AEAD key and nonce produced for one generation of a leaf's handshake or
+/// application ratchet. Both byte buffers are wrapped in `Zeroizing`.
+pub struct MessageKeyData {
+ /// AEAD nonce; `aead_nonce_size()` bytes when produced by `SecretKeyRatchet`.
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ pub(crate) nonce: Zeroizing<Vec<u8>>,
+ /// AEAD key; `aead_key_size()` bytes when produced by `SecretKeyRatchet`.
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ pub(crate) key: Zeroizing<Vec<u8>>,
+ /// Ratchet generation this key/nonce pair belongs to (also the lookup key in the
+ /// out-of-order history).
+ pub(crate) generation: u32,
+}
+
+/// Renders the byte fields through `mls_rs_core::debug::pretty_bytes` rather than
+/// printing the raw `Vec`s.
+impl Debug for MessageKeyData {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let Self {
+            nonce,
+            key,
+            generation,
+        } = self;
+
+        let mut out = f.debug_struct("MessageKeyData");
+        out.field("nonce", &mls_rs_core::debug::pretty_bytes(nonce));
+        out.field("key", &mls_rs_core::debug::pretty_bytes(key));
+        out.field("generation", generation);
+        out.finish()
+    }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl MessageKeyData {
+    /// Nonce to use with the AEAD for this generation.
+    #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
+    pub fn nonce(&self) -> &[u8] {
+        self.nonce.as_slice()
+    }
+
+    /// Symmetric AEAD key for this generation.
+    #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
+    pub fn key(&self) -> &[u8] {
+        self.key.as_slice()
+    }
+
+    /// Ratchet generation at which this key/nonce pair was derived.
+    #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
+    pub fn generation(&self) -> u32 {
+        let generation = self.generation;
+        generation
+    }
+}
+
+/// One symmetric ratchet (handshake or application) belonging to a single leaf.
+#[derive(Debug, Clone, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct SecretKeyRatchet {
+ /// Current ratchet secret; replaced each time `next_message_key` succeeds.
+ secret: TreeSecret,
+ /// Generation whose key the next call to `next_message_key` will produce.
+ generation: u32,
+ /// With `out_of_order`: keys of generations that were skipped over, keyed by generation.
+ /// `get_message_key` removes an entry when it is used, so each key is handed out once.
+ #[cfg(all(feature = "out_of_order", feature = "std"))]
+ history: HashMap<u32, MessageKeyData>,
+ #[cfg(all(feature = "out_of_order", not(feature = "std")))]
+ history: BTreeMap<u32, MessageKeyData>,
+}
+
+/// Encoded size: the secret as a byte vector, the generation, and (with `out_of_order`)
+/// the history values as a list.
+impl MlsSize for SecretKeyRatchet {
+    fn mls_encoded_len(&self) -> usize {
+        let total = mls_rs_codec::byte_vec::mls_encoded_len(&self.secret)
+            + self.generation.mls_encoded_len();
+
+        #[cfg(feature = "out_of_order")]
+        let total = total + mls_rs_codec::iter::mls_encoded_len(self.history.values());
+
+        total
+    }
+}
+
+/// Writes the secret, the generation and then the history keys as a list. Only the map's
+/// values are written; each value carries its own generation, which decoding re-keys on.
+#[cfg(feature = "out_of_order")]
+impl MlsEncode for SecretKeyRatchet {
+    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+        let Self {
+            secret,
+            generation,
+            history,
+        } = self;
+
+        mls_rs_codec::byte_vec::mls_encode(secret, writer)?;
+        generation.mls_encode(writer)?;
+        mls_rs_codec::iter::mls_encode(history.values(), writer)
+    }
+}
+
+/// Without `out_of_order` there is no history: only the secret and the generation are
+/// written.
+#[cfg(not(feature = "out_of_order"))]
+impl MlsEncode for SecretKeyRatchet {
+    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+        let Self { secret, generation } = self;
+
+        mls_rs_codec::byte_vec::mls_encode(secret, writer)?;
+        generation.mls_encode(writer)
+    }
+}
+
+/// Reads the fields in the order the `MlsEncode` impls write them. With `out_of_order`,
+/// the history list is turned back into a map keyed by each entry's `generation`. A later
+/// entry with the same generation replaces an earlier one.
+impl MlsDecode for SecretKeyRatchet {
+    fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
+        let secret: TreeSecret = mls_rs_codec::byte_vec::mls_decode(reader)?;
+        let generation = u32::mls_decode(reader)?;
+
+        #[cfg(all(feature = "std", feature = "out_of_order"))]
+        let history = mls_rs_codec::iter::mls_decode_collection(reader, |data| {
+            let mut by_generation = HashMap::default();
+
+            while !data.is_empty() {
+                let key_data = MessageKeyData::mls_decode(data)?;
+                by_generation.insert(key_data.generation, key_data);
+            }
+
+            Ok(by_generation)
+        })?;
+
+        #[cfg(all(not(feature = "std"), feature = "out_of_order"))]
+        let history = mls_rs_codec::iter::mls_decode_collection(reader, |data| {
+            let mut by_generation = alloc::collections::BTreeMap::default();
+
+            while !data.is_empty() {
+                let key_data = MessageKeyData::mls_decode(data)?;
+                by_generation.insert(key_data.generation, key_data);
+            }
+
+            Ok(by_generation)
+        })?;
+
+        Ok(Self {
+            secret,
+            generation,
+            #[cfg(feature = "out_of_order")]
+            history,
+        })
+    }
+}
+
+impl SecretKeyRatchet {
+    /// Create the `key_type` ratchet from a leaf's tree `secret`.
+    ///
+    /// The ratchet's initial secret is `kdf_expand_with_label(secret, label, "", None)`,
+    /// where the label is `"handshake"` or `"application"`. The ratchet starts at
+    /// generation 0 and, with `out_of_order`, an empty history.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn new<P: CipherSuiteProvider>(
+        cipher_suite_provider: &P,
+        secret: &[u8],
+        key_type: KeyType,
+    ) -> Result<Self, MlsError> {
+        let label = match key_type {
+            KeyType::Handshake => b"handshake".as_slice(),
+            KeyType::Application => b"application".as_slice(),
+        };
+
+        let secret = kdf_expand_with_label(cipher_suite_provider, secret, label, &[], None)
+            .await
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+        Ok(Self {
+            secret: TreeSecret::from(secret),
+            generation: 0,
+            #[cfg(feature = "out_of_order")]
+            history: Default::default(),
+        })
+    }
+
+    /// Return the key for `generation`, ratcheting forward as needed.
+    ///
+    /// * `generation < self.generation`: with `out_of_order`, the key is removed from the
+    ///   history, so it can be obtained only once. Otherwise, or if it is not in the
+    ///   history, `MlsError::KeyMissing` is returned.
+    /// * `generation > self.generation + MAX_RATCHET_BACK_HISTORY` (saturating at
+    ///   `u32::MAX`): `MlsError::InvalidFutureGeneration` is returned and the ratchet is
+    ///   unchanged.
+    /// * Otherwise the ratchet advances to `generation + 1`. Keys for generations that are
+    ///   skipped over are kept in the history with `out_of_order` and discarded without it.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn get_message_key<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite_provider: &P,
+        generation: u32,
+    ) -> Result<MessageKeyData, MlsError> {
+        #[cfg(feature = "out_of_order")]
+        if generation < self.generation {
+            return self
+                .history
+                .remove(&generation)
+                .ok_or(MlsError::KeyMissing(generation));
+        }
+
+        #[cfg(not(feature = "out_of_order"))]
+        if generation < self.generation {
+            return Err(MlsError::KeyMissing(generation));
+        }
+
+        // `saturating_add` avoids an overflow panic (debug) or wrap-around (release) when
+        // `self.generation` is close to `u32::MAX`.
+        let max_generation_allowed = self.generation.saturating_add(MAX_RATCHET_BACK_HISTORY);
+
+        if generation > max_generation_allowed {
+            return Err(MlsError::InvalidFutureGeneration(generation));
+        }
+
+        // Skipped keys are dropped. The `.await` is required for async builds: without it
+        // `?` would be applied to an unpolled future.
+        #[cfg(not(feature = "out_of_order"))]
+        while self.generation < generation {
+            self.next_message_key(cipher_suite_provider).await?;
+        }
+
+        #[cfg(feature = "out_of_order")]
+        while self.generation < generation {
+            let key_data = self.next_message_key(cipher_suite_provider).await?;
+            self.history.insert(key_data.generation, key_data);
+        }
+
+        self.next_message_key(cipher_suite_provider).await
+    }
+
+    /// Derive the nonce and key for the current generation, then replace the ratchet
+    /// secret with its `"secret"` derivation and increment the generation.
+    ///
+    /// If any derivation fails, the ratchet's secret and generation are left unchanged.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn next_message_key<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite_provider: &P,
+    ) -> Result<MessageKeyData, MlsError> {
+        let generation = self.generation;
+
+        let key = MessageKeyData {
+            nonce: self
+                .derive_secret(
+                    cipher_suite_provider,
+                    b"nonce",
+                    cipher_suite_provider.aead_nonce_size(),
+                )
+                .await?,
+            key: self
+                .derive_secret(
+                    cipher_suite_provider,
+                    b"key",
+                    cipher_suite_provider.aead_key_size(),
+                )
+                .await?,
+            generation,
+        };
+
+        self.secret = self
+            .derive_secret(
+                cipher_suite_provider,
+                b"secret",
+                cipher_suite_provider.kdf_extract_size(),
+            )
+            .await?
+            .into();
+
+        // NOTE(review): `generation + 1` is unchecked and overflows at `u32::MAX`
+        // (about 4e9 messages in one epoch).
+        self.generation = generation + 1;
+
+        Ok(key)
+    }
+
+    /// Compute `kdf_expand_with_label(self.secret, label, context, len)`, where `context`
+    /// is the current generation as 4 big-endian bytes.
+    ///
+    /// Provider errors are wrapped in `MlsError::CryptoProviderError`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn derive_secret<P: CipherSuiteProvider>(
+        &self,
+        cipher_suite_provider: &P,
+        label: &[u8],
+        len: usize,
+    ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+        kdf_expand_with_label(
+            cipher_suite_provider,
+            self.secret.as_ref(),
+            label,
+            &self.generation.to_be_bytes(),
+            Some(len),
+        )
+        .await
+        .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use alloc::{string::String, vec::Vec};
+    use mls_rs_core::crypto::CipherSuiteProvider;
+    use zeroize::Zeroizing;
+
+    use crate::{crypto::test_utils::try_test_cipher_suite_provider, tree_kem::math::TreeIndex};
+
+    use super::{KeyType, SecretKeyRatchet, SecretTree};
+
+    /// Build a `SecretTree` with `leaf_count` leaves whose root secret is `secret`.
+    pub(crate) fn get_test_tree<T: TreeIndex>(secret: Vec<u8>, leaf_count: T) -> SecretTree<T> {
+        let encryption_secret = Zeroizing::new(secret);
+        SecretTree::new(leaf_count, encryption_secret)
+    }
+
+    impl SecretTree<u32> {
+        /// Return a copy of the root secret without consuming it from `self`.
+        ///
+        /// Panics (via `unwrap`) if the root node is absent or `into_secret` yields `None`.
+        pub(crate) fn get_root_secret(&self) -> Vec<u8> {
+            let mut scratch = self.known_secrets.clone();
+            let root = self.leaf_count.root();
+            let root_node = scratch.take_node(&root).unwrap();
+
+            root_node.into_secret().unwrap().to_vec()
+        }
+    }
+
+    /// One `derive_tree_secret` entry of the `basic_crypto` test vectors.
+    #[derive(Debug, serde::Serialize, serde::Deserialize)]
+    pub struct RatchetInteropTestCase {
+        #[serde(with = "hex::serde")]
+        secret: Vec<u8>,
+        label: String,
+        generation: u32,
+        length: usize,
+        #[serde(with = "hex::serde")]
+        out: Vec<u8>,
+    }
+
+    /// Per-cipher-suite entry of the `basic_crypto` test vectors.
+    #[derive(Debug, serde::Serialize, serde::Deserialize)]
+    pub struct InteropTestCase {
+        cipher_suite: u16,
+        derive_tree_secret: RatchetInteropTestCase,
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_basic_crypto_test_vectors() {
+        let test_cases: Vec<InteropTestCase> =
+            load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());
+
+        for case in test_cases {
+            // Cipher suites the test crypto provider does not support are skipped.
+            let Some(cs) = try_test_cipher_suite_provider(case.cipher_suite) else {
+                continue;
+            };
+
+            case.derive_tree_secret.verify(&cs).await;
+        }
+    }
+
+    impl RatchetInteropTestCase {
+        /// Check `derive_secret` against this vector.
+        ///
+        /// Any ratchet is created, then its secret and generation are overwritten with the
+        /// vector's inputs, so only the `derive_secret` step is under test.
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
+            let mut ratchet = SecretKeyRatchet::new(cs, &self.secret, KeyType::Application)
+                .await
+                .unwrap();
+
+            ratchet.secret = self.secret.clone().into();
+            ratchet.generation = self.generation;
+
+            let derived = ratchet
+                .derive_secret(cs, self.label.as_bytes(), self.length)
+                .await
+                .unwrap();
+
+            assert_eq!(derived.to_vec(), self.out);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ // Unit tests for `SecretTree` / `SecretKeyRatchet`, plus the stored `secret_tree` test vectors.
+ use alloc::vec;
+
+ use crate::{
+ cipher_suite::CipherSuite,
+ client::test_utils::TEST_CIPHER_SUITE,
+ crypto::test_utils::{
+ test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
+ },
+ tree_kem::node::NodeIndex,
+ };
+
+ #[cfg(not(mls_build_async))]
+ use crate::group::test_utils::random_bytes;
+
+ use super::{test_utils::get_test_tree, *};
+
+ use assert_matches::assert_matches;
+
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+
+ // The 16-leaf case takes every even index and expects the tree to end up empty; the
+ // 2^62-leaf case only spot-checks a handful of indices.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_secret_tree() {
+ test_secret_tree_custom(16u32, (0..16).map(|i| 2 * i).collect(), true).await;
+ test_secret_tree_custom(1u64 << 62, (1..62).map(|i| 1u64 << i).collect(), false).await;
+ }
+
+ // For each supported cipher suite: take the ratchets for every index in `leaves_to_check`,
+ // optionally assert that nothing is left in the tree, and check the ratchets differ.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_secret_tree_custom<T: TreeIndex>(
+ leaf_count: T,
+ leaves_to_check: Vec<T>,
+ all_deleted: bool,
+ ) {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let cs_provider = test_cipher_suite_provider(cipher_suite);
+
+ let test_secret = vec![0u8; cs_provider.kdf_extract_size()];
+ let mut test_tree = get_test_tree(test_secret, leaf_count.clone());
+
+ let mut secrets = Vec::<SecretRatchets>::new();
+
+ for i in &leaves_to_check {
+ let secret = test_tree
+ .take_leaf_ratchet(&test_cipher_suite_provider(cipher_suite), i)
+ .await
+ .unwrap();
+
+ secrets.push(secret);
+ }
+
+ // Verify the tree is now completely empty
+ assert!(!all_deleted || test_tree.known_secrets.inner.is_empty());
+
+ // Verify that all the secrets are unique
+ // NOTE(review): `dedup` only drops *consecutive* duplicates, so this proves that
+ // ratchets taken one after another differ, not that all of them are distinct.
+ let count = secrets.len();
+ secrets.dedup();
+ assert_eq!(count, secrets.len());
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_secret_key_ratchet() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let mut app_ratchet = SecretKeyRatchet::new(
+ &provider,
+ &vec![0u8; provider.kdf_extract_size()],
+ KeyType::Application,
+ )
+ .await
+ .unwrap();
+
+ let mut handshake_ratchet = SecretKeyRatchet::new(
+ &provider,
+ &vec![0u8; provider.kdf_extract_size()],
+ KeyType::Handshake,
+ )
+ .await
+ .unwrap();
+
+ let app_key_one = app_ratchet.next_message_key(&provider).await.unwrap();
+ let app_key_two = app_ratchet.next_message_key(&provider).await.unwrap();
+ let app_keys = vec![app_key_one, app_key_two];
+
+ let handshake_key_one = handshake_ratchet.next_message_key(&provider).await.unwrap();
+ let handshake_key_two = handshake_ratchet.next_message_key(&provider).await.unwrap();
+ let handshake_keys = vec![handshake_key_one, handshake_key_two];
+
+ // Verify that the keys have different outcomes due to their different labels
+ assert_ne!(app_keys, handshake_keys);
+
+ // Verify that the keys at each generation are different
+ assert_ne!(handshake_keys[0], handshake_keys[1]);
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_get_key() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let mut ratchet = SecretKeyRatchet::new(
+ &test_cipher_suite_provider(cipher_suite),
+ &vec![0u8; provider.kdf_extract_size()],
+ KeyType::Application,
+ )
+ .await
+ .unwrap();
+
+ let mut ratchet_clone = ratchet.clone();
+
+ // This will generate keys 0 and 1 in ratchet_clone
+ let _ = ratchet_clone.next_message_key(&provider).await.unwrap();
+ let clone_2 = ratchet_clone.next_message_key(&provider).await.unwrap();
+
+ // Going back in time should result in an error
+ // (key 0 was produced by `next_message_key`, never stored in any history)
+ let res = ratchet_clone.get_message_key(&provider, 0).await;
+ assert!(res.is_err());
+
+ // Calling get key should be the same as calling next until hitting the desired generation
+ let second_key = ratchet
+ .get_message_key(&provider, ratchet_clone.generation - 1)
+ .await
+ .unwrap();
+
+ assert_eq!(clone_2, second_key)
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_secret_ratchet() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let mut ratchet = SecretKeyRatchet::new(
+ &provider,
+ &vec![0u8; provider.kdf_extract_size()],
+ KeyType::Application,
+ )
+ .await
+ .unwrap();
+
+ let original_secret = ratchet.secret.clone();
+ let _ = ratchet.next_message_key(&provider).await.unwrap();
+ let new_secret = ratchet.secret;
+ assert_ne!(original_secret, new_secret)
+ }
+ }
+
+ // Jumping straight to generation MAX_RATCHET_BACK_HISTORY must yield the same key as
+ // ratcheting in order, and the skipped keys must then be retrievable from the history.
+ #[cfg(feature = "out_of_order")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_out_of_order_keys() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
+ .await
+ .unwrap();
+ let mut ratchet_clone = ratchet.clone();
+
+ // Ask for all the keys in order from the original ratchet
+ let mut ordered_keys = Vec::<MessageKeyData>::new();
+
+ for i in 0..=MAX_RATCHET_BACK_HISTORY {
+ ordered_keys.push(ratchet.get_message_key(&provider, i).await.unwrap());
+ }
+
+ // Ask for a key at index MAX_RATCHET_BACK_HISTORY in the clone
+ let last_key = ratchet_clone
+ .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY)
+ .await
+ .unwrap();
+
+ assert_eq!(last_key, ordered_keys[ordered_keys.len() - 1]);
+
+ // Get all the other keys
+ // (generations 0..MAX_RATCHET_BACK_HISTORY - 1; the last skipped generation is not checked)
+ let mut back_history_keys = Vec::<MessageKeyData>::new();
+
+ for i in 0..MAX_RATCHET_BACK_HISTORY - 1 {
+ back_history_keys.push(ratchet_clone.get_message_key(&provider, i).await.unwrap());
+ }
+
+ assert_eq!(
+ back_history_keys,
+ ordered_keys[..(MAX_RATCHET_BACK_HISTORY as usize) - 1]
+ );
+ }
+
+ #[cfg(not(feature = "out_of_order"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn out_of_order_keys_should_throw_error() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
+ .await
+ .unwrap();
+
+ ratchet.get_message_key(&provider, 10).await.unwrap();
+ let res = ratchet.get_message_key(&provider, 9).await;
+ assert_matches!(res, Err(MlsError::KeyMissing(9)))
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_too_out_of_order() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
+ .await
+ .unwrap();
+
+ let res = ratchet
+ .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY + 1)
+ .await;
+
+ let invalid_generation = MAX_RATCHET_BACK_HISTORY + 1;
+
+ assert_matches!(
+ res,
+ Err(MlsError::InvalidFutureGeneration(invalid))
+ if invalid == invalid_generation
+ )
+ }
+
+ // Encoded `MessageKeyData` for the first 20 + 20 ratchet outputs of one leaf.
+ #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
+ struct Ratchet {
+ application_keys: Vec<Vec<u8>>,
+ handshake_keys: Vec<Vec<u8>>,
+ }
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ struct TestCase {
+ cipher_suite: u16,
+ #[serde(with = "hex::serde")]
+ encryption_secret: Vec<u8>,
+ ratchets: Vec<Ratchet>,
+ }
+
+ // Produces the data stored in (and compared against) the `secret_tree` test vectors for
+ // a 16-leaf tree; leaf `index` is taken at node index `index * 2`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn get_ratchet_data(
+ secret_tree: &mut SecretTree<NodeIndex>,
+ cipher_suite: CipherSuite,
+ ) -> Vec<Ratchet> {
+ let provider = test_cipher_suite_provider(cipher_suite);
+ let mut ratchet_data = Vec::new();
+
+ for index in 0..16 {
+ let mut ratchets = secret_tree
+ .take_leaf_ratchet(&provider, &(index * 2))
+ .await
+ .unwrap();
+
+ let mut application_keys = Vec::new();
+
+ // NOTE(review): this loop reads from `ratchets.handshake`, not `ratchets.application`,
+ // so `application_keys` holds handshake generations 0..20 and `handshake_keys` below
+ // holds generations 20..40. The stored `secret_tree` vectors were produced by this
+ // code, so fixing it requires regenerating them at the same time.
+ for _ in 0..20 {
+ let key = ratchets
+ .handshake
+ .next_message_key(&provider)
+ .await
+ .unwrap()
+ .mls_encode_to_vec()
+ .unwrap();
+
+ application_keys.push(key);
+ }
+
+ let mut handshake_keys = Vec::new();
+
+ for _ in 0..20 {
+ let key = ratchets
+ .handshake
+ .next_message_key(&provider)
+ .await
+ .unwrap()
+ .mls_encode_to_vec()
+ .unwrap();
+
+ handshake_keys.push(key);
+ }
+
+ ratchet_data.push(Ratchet {
+ application_keys,
+ handshake_keys,
+ });
+ }
+
+ ratchet_data
+ }
+
+ // Used by `load_test_case_json!` when the stored vector file is absent (sync builds only).
+ #[cfg(not(mls_build_async))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_test_vector() -> Vec<TestCase> {
+ CipherSuite::all()
+ .map(|cipher_suite| {
+ let provider = test_cipher_suite_provider(cipher_suite);
+ let encryption_secret = random_bytes(provider.kdf_extract_size());
+
+ let mut secret_tree =
+ SecretTree::new(16, Zeroizing::new(encryption_secret.clone()));
+
+ TestCase {
+ cipher_suite: cipher_suite.into(),
+ encryption_secret,
+ ratchets: get_ratchet_data(&mut secret_tree, cipher_suite),
+ }
+ })
+ .collect()
+ }
+
+ #[cfg(mls_build_async)]
+ fn generate_test_vector() -> Vec<TestCase> {
+ panic!("Tests cannot be generated in async mode");
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_secret_tree_test_vectors() {
+ let test_cases: Vec<TestCase> = load_test_case_json!(secret_tree, generate_test_vector());
+
+ for case in test_cases {
+ let Some(cs_provider) = try_test_cipher_suite_provider(case.cipher_suite) else {
+ continue;
+ };
+
+ let mut secret_tree = SecretTree::new(16, Zeroizing::new(case.encryption_secret));
+ let ratchet_data = get_ratchet_data(&mut secret_tree, cs_provider.cipher_suite()).await;
+
+ assert_eq!(ratchet_data, case.ratchets);
+ }
+ }
+}
+
+#[cfg(all(test, feature = "rfc_compliant", feature = "std"))]
+mod interop_tests {
+ // Checks the secret tree against the MLS interop `secret-tree` test vectors.
+ #[cfg(not(mls_build_async))]
+ use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};
+ use zeroize::Zeroizing;
+
+ use crate::{
+ crypto::test_utils::try_test_cipher_suite_provider,
+ group::{ciphertext_processor::InteropSenderData, secret_tree::KeyType},
+ };
+
+ use super::SecretTree;
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn interop_test_vector() {
+ // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/secret-tree.json
+ let test_cases = load_interop_test_cases();
+
+ for case in test_cases {
+ let Some(cs) = try_test_cipher_suite_provider(case.cipher_suite) else {
+ continue;
+ };
+
+ case.sender_data.verify(&cs).await;
+
+ let mut tree = SecretTree::new(
+ case.leaves.len() as u32,
+ Zeroizing::new(case.encryption_secret),
+ );
+
+ for (index, leaves) in case.leaves.iter().enumerate() {
+ for leaf in leaves.iter() {
+ // `index * 2` is the tree index used for leaf `index` (the same convention as
+ // the other secret-tree tests).
+ let key = tree
+ .message_key_generation(
+ &cs,
+ (index as u32) * 2,
+ KeyType::Application,
+ leaf.generation,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(key.key.to_vec(), leaf.application_key);
+ assert_eq!(key.nonce.to_vec(), leaf.application_nonce);
+
+ let key = tree
+ .message_key_generation(
+ &cs,
+ (index as u32) * 2,
+ KeyType::Handshake,
+ leaf.generation,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(key.key.to_vec(), leaf.handshake_key);
+ assert_eq!(key.nonce.to_vec(), leaf.handshake_nonce);
+ }
+ }
+ }
+ }
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ struct InteropTestCase {
+ cipher_suite: u16,
+ #[serde(with = "hex::serde")]
+ encryption_secret: Vec<u8>,
+ sender_data: InteropSenderData,
+ leaves: Vec<Vec<InteropLeaf>>,
+ }
+
+ // Expected key/nonce pairs for one leaf at one generation.
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ struct InteropLeaf {
+ generation: u32,
+ #[serde(with = "hex::serde")]
+ application_key: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ application_nonce: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ handshake_key: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ handshake_nonce: Vec<u8>,
+ }
+
+ fn load_interop_test_cases() -> Vec<InteropTestCase> {
+ load_test_case_json!(secret_tree_interop, generate_test_vector())
+ }
+
+ // For every supported suite and each tree size in `tree_sizes`, records the keys of every
+ // leaf at the generations in `gens` under a fresh random encryption secret.
+ #[cfg(not(mls_build_async))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_test_vector() -> Vec<InteropTestCase> {
+ let mut test_cases = vec![];
+
+ for cs in CipherSuite::all() {
+ let Some(cs) = try_test_cipher_suite_provider(*cs) else {
+ continue;
+ };
+
+ let gens = [0, 15];
+ let tree_sizes = [1, 8, 32];
+
+ for n_leaves in tree_sizes {
+ let encryption_secret = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();
+
+ let mut tree = SecretTree::new(n_leaves, Zeroizing::new(encryption_secret.clone()));
+
+ let leaves = (0..n_leaves)
+ .map(|leaf| {
+ gens.into_iter()
+ .map(|gen| {
+ let index = leaf * 2u32;
+
+ let handshake_key = tree
+ .message_key_generation(&cs, index, KeyType::Handshake, gen)
+ .unwrap();
+
+ let app_key = tree
+ .message_key_generation(&cs, index, KeyType::Application, gen)
+ .unwrap();
+
+ InteropLeaf {
+ generation: gen,
+ application_key: app_key.key.to_vec(),
+ application_nonce: app_key.nonce.to_vec(),
+ handshake_key: handshake_key.key.to_vec(),
+ handshake_nonce: handshake_key.nonce.to_vec(),
+ }
+ })
+ .collect()
+ })
+ .collect();
+
+ let case = InteropTestCase {
+ cipher_suite: *cs.cipher_suite(),
+ encryption_secret,
+ sender_data: InteropSenderData::new(&cs),
+ leaves,
+ };
+
+ test_cases.push(case);
+ }
+ }
+
+ test_cases
+ }
+
+ #[cfg(mls_build_async)]
+ fn generate_test_vector() -> Vec<InteropTestCase> {
+ panic!("Tests cannot be generated in async mode");
+ }
+}
diff --git a/src/group/snapshot.rs b/src/group/snapshot.rs
new file mode 100644
index 0000000..dca64f8
--- /dev/null
+++ b/src/group/snapshot.rs
@@ -0,0 +1,325 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::MlsError,
+ client_config::ClientConfig,
+ group::{
+ key_schedule::KeySchedule, CommitGeneration, ConfirmationTag, Group, GroupContext,
+ GroupState, InterimTranscriptHash, ReInitProposal, TreeKemPublic,
+ },
+ tree_kem::TreeKemPrivate,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::{
+ crypto::{HpkePublicKey, HpkeSecretKey},
+ group::ProposalRef,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use super::proposal_cache::{CachedProposal, ProposalCache};
+
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+use mls_rs_core::crypto::SignatureSecretKey;
+#[cfg(feature = "tree_index")]
+use mls_rs_core::identity::IdentityProvider;
+
+#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+use std::collections::HashMap;
+
+#[cfg(all(feature = "by_ref_proposal", not(feature = "std")))]
+use alloc::vec::Vec;
+
+use super::{cipher_suite_provider, epoch::EpochSecrets, state_repo::GroupStateRepository};
+
+/// Serializable image of a group's full local state.
+///
+/// Produced by `Group::snapshot` and turned back into a group by `Group::from_snapshot`.
+#[derive(Debug, PartialEq, Clone, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct Snapshot {
+ /// Snapshot format version. `Group::snapshot` writes `1`; `Group::from_snapshot` does
+ /// not currently read it.
+ version: u16,
+ /// Public group state (context, proposals, tree, transcript hash, ...).
+ pub(crate) state: RawGroupState,
+ private_tree: TreeKemPrivate,
+ epoch_secrets: EpochSecrets,
+ key_schedule: KeySchedule,
+ /// Secrets for this member's own pending update proposals, keyed by HPKE public key.
+ /// A `HashMap` with `std`, a `Vec` of pairs without it.
+ #[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+ pending_updates: HashMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>,
+ #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
+ pending_updates: Vec<(HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>))>,
+ pending_commit: Option<CommitGeneration>,
+ /// Signing key the group was using when the snapshot was taken.
+ signer: SignatureSecretKey,
+}
+
+/// Serializable form of `GroupState`, built by `RawGroupState::export` and converted back
+/// with `RawGroupState::import`.
+#[derive(Debug, MlsEncode, MlsDecode, MlsSize, PartialEq, Clone)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct RawGroupState {
+ pub(crate) context: GroupContext,
+ /// Raw contents of the proposal cache. A `HashMap` with `std`, a `Vec` of pairs without it.
+ #[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+ pub(crate) proposals: HashMap<ProposalRef, CachedProposal>,
+ #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
+ pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>,
+ pub(crate) public_tree: TreeKemPublic,
+ pub(crate) interim_transcript_hash: InterimTranscriptHash,
+ pub(crate) pending_reinit: Option<ReInitProposal>,
+ pub(crate) confirmation_tag: ConfirmationTag,
+}
+
+impl RawGroupState {
+    /// Copy the persistable parts of `state`.
+    ///
+    /// With `tree_index` the public tree is cloned as-is. Without it, only the tree's
+    /// `nodes` are copied into a fresh `TreeKemPublic::new()`.
+    pub(crate) fn export(state: &GroupState) -> Self {
+        #[cfg(not(feature = "tree_index"))]
+        let public_tree = {
+            let mut nodes_only = TreeKemPublic::new();
+            nodes_only.nodes = state.public_tree.nodes.clone();
+            nodes_only
+        };
+
+        #[cfg(feature = "tree_index")]
+        let public_tree = state.public_tree.clone();
+
+        RawGroupState {
+            context: state.context.clone(),
+            #[cfg(feature = "by_ref_proposal")]
+            proposals: state.proposals.proposals.clone(),
+            public_tree,
+            interim_transcript_hash: state.interim_transcript_hash.clone(),
+            pending_reinit: state.pending_reinit.clone(),
+            confirmation_tag: state.confirmation_tag.clone(),
+        }
+    }
+
+    /// Rebuild a `GroupState`, restoring the proposal cache and letting the public tree
+    /// initialize its index (via `initialize_index_if_necessary`) with `identity_provider`.
+    #[cfg(feature = "tree_index")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn import<C>(self, identity_provider: &C) -> Result<GroupState, MlsError>
+    where
+        C: IdentityProvider,
+    {
+        let RawGroupState {
+            context,
+            #[cfg(feature = "by_ref_proposal")]
+            proposals,
+            mut public_tree,
+            interim_transcript_hash,
+            pending_reinit,
+            confirmation_tag,
+        } = self;
+
+        #[cfg(feature = "by_ref_proposal")]
+        let proposals =
+            ProposalCache::import(context.protocol_version, context.group_id.clone(), proposals);
+
+        public_tree
+            .initialize_index_if_necessary(identity_provider, &context.extensions)
+            .await?;
+
+        Ok(GroupState {
+            #[cfg(feature = "by_ref_proposal")]
+            proposals,
+            context,
+            public_tree,
+            interim_transcript_hash,
+            pending_reinit,
+            confirmation_tag,
+        })
+    }
+
+    /// Rebuild a `GroupState`, restoring the proposal cache. The tree is used as stored.
+    #[cfg(not(feature = "tree_index"))]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn import(self) -> Result<GroupState, MlsError> {
+        let RawGroupState {
+            context,
+            #[cfg(feature = "by_ref_proposal")]
+            proposals,
+            public_tree,
+            interim_transcript_hash,
+            pending_reinit,
+            confirmation_tag,
+        } = self;
+
+        #[cfg(feature = "by_ref_proposal")]
+        let proposals =
+            ProposalCache::import(context.protocol_version, context.group_id.clone(), proposals);
+
+        Ok(GroupState {
+            #[cfg(feature = "by_ref_proposal")]
+            proposals,
+            context,
+            public_tree,
+            interim_transcript_hash,
+            pending_reinit,
+            confirmation_tag,
+        })
+    }
+}
+
+impl<C> Group<C>
+where
+    C: ClientConfig + Clone,
+{
+    /// Persist the group's current state through the
+    /// [`GroupStateStorage`](crate::GroupStateStorage) currently used by the group.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn write_to_storage(&mut self) -> Result<(), MlsError> {
+        let snapshot = self.snapshot();
+        self.state_repo.write_to_storage(snapshot).await
+    }
+
+    /// Capture everything needed to rebuild this group later with `from_snapshot`.
+    /// The snapshot is tagged with format version `1`.
+    pub(crate) fn snapshot(&self) -> Snapshot {
+        Snapshot {
+            version: 1,
+            state: RawGroupState::export(&self.state),
+            private_tree: self.private_tree.clone(),
+            epoch_secrets: self.epoch_secrets.clone(),
+            key_schedule: self.key_schedule.clone(),
+            #[cfg(feature = "by_ref_proposal")]
+            pending_updates: self.pending_updates.clone(),
+            pending_commit: self.pending_commit.clone(),
+            signer: self.signer.clone(),
+        }
+    }
+
+    /// Rebuild a group from `snapshot` using `config` for crypto, identity and storage.
+    ///
+    /// The snapshot's `version` is not inspected.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the snapshot's cipher suite is unavailable from the crypto provider, if the
+    /// state repository cannot be created, or if importing the raw group state fails.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn from_snapshot(config: C, snapshot: Snapshot) -> Result<Self, MlsError> {
+        let Snapshot {
+            version: _,
+            state,
+            private_tree,
+            epoch_secrets,
+            key_schedule,
+            #[cfg(feature = "by_ref_proposal")]
+            pending_updates,
+            pending_commit,
+            signer,
+        } = snapshot;
+
+        let cipher_suite_provider =
+            cipher_suite_provider(config.crypto_provider(), state.context.cipher_suite)?;
+
+        #[cfg(feature = "tree_index")]
+        let identity_provider = config.identity_provider();
+
+        let state_repo = GroupStateRepository::new(
+            #[cfg(feature = "prior_epoch")]
+            state.context.group_id.clone(),
+            config.group_state_storage(),
+            config.key_package_repo(),
+            None,
+        )?;
+
+        let state = state
+            .import(
+                #[cfg(feature = "tree_index")]
+                &identity_provider,
+            )
+            .await?;
+
+        Ok(Group {
+            config,
+            state,
+            private_tree,
+            key_schedule,
+            #[cfg(feature = "by_ref_proposal")]
+            pending_updates,
+            pending_commit,
+            #[cfg(test)]
+            commit_modifiers: Default::default(),
+            epoch_secrets,
+            state_repo,
+            cipher_suite_provider,
+            #[cfg(feature = "psk")]
+            previous_psk: None,
+            signer,
+        })
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use alloc::vec;
+
+    use crate::{
+        cipher_suite::CipherSuite,
+        crypto::test_utils::test_cipher_suite_provider,
+        group::{
+            confirmation_tag::ConfirmationTag, epoch::test_utils::get_test_epoch_secrets,
+            key_schedule::test_utils::get_test_key_schedule, test_utils::get_test_group_context,
+            transcript_hash::InterimTranscriptHash,
+        },
+        tree_kem::{node::LeafIndex, TreeKemPrivate},
+    };
+
+    use super::{RawGroupState, Snapshot};
+
+    /// Build a minimal `Snapshot` for `cipher_suite` at epoch `epoch_id`.
+    ///
+    /// It has an empty public tree, an empty transcript hash and confirmation tag, a
+    /// private tree for leaf 0, and an empty signer key.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn get_test_snapshot(cipher_suite: CipherSuite, epoch_id: u64) -> Snapshot {
+        let context = get_test_group_context(epoch_id, cipher_suite).await;
+
+        let confirmation_tag =
+            ConfirmationTag::empty(&test_cipher_suite_provider(cipher_suite)).await;
+
+        let state = RawGroupState {
+            context,
+            #[cfg(feature = "by_ref_proposal")]
+            proposals: Default::default(),
+            public_tree: Default::default(),
+            interim_transcript_hash: InterimTranscriptHash::from(vec![]),
+            pending_reinit: None,
+            confirmation_tag,
+        };
+
+        Snapshot {
+            version: 1,
+            state,
+            private_tree: TreeKemPrivate::new(LeafIndex(0)),
+            epoch_secrets: get_test_epoch_secrets(cipher_suite),
+            key_schedule: get_test_key_schedule(cipher_suite),
+            #[cfg(feature = "by_ref_proposal")]
+            pending_updates: Default::default(),
+            pending_commit: None,
+            signer: vec![].into(),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use alloc::vec;
+
+    use crate::{
+        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+        group::{
+            test_utils::{test_group, TestGroup},
+            Group,
+        },
+    };
+
+    /// Snapshot `group`, rebuild a group from that snapshot with the same config, and
+    /// check that the rebuilt state matches the original. With `tree_index`, the public
+    /// tree's internals are compared as well.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn snapshot_restore(group: TestGroup) {
+        let original = &group.group;
+        let snapshot = original.snapshot();
+
+        let restored = Group::from_snapshot(original.config.clone(), snapshot)
+            .await
+            .unwrap();
+
+        assert!(Group::equal_group_state(original, &restored));
+
+        #[cfg(feature = "tree_index")]
+        assert!(restored
+            .state
+            .public_tree
+            .equal_internals(&original.state.public_tree));
+    }
+
+    // NOTE(review): the `*_serialized_to_json` tests below only exercise the in-memory
+    // snapshot/restore round trip; JSON is only exercised by the `serde` test.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn snapshot_with_pending_commit_can_be_serialized_to_json() {
+        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        // Creating a commit presumably leaves it pending on the group.
+        group.group.commit(vec![]).await.unwrap();
+
+        snapshot_restore(group).await
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn snapshot_with_pending_updates_can_be_serialized_to_json() {
+        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        // Creating the update proposal will add it to pending updates
+        let update = group.update_proposal().await;
+
+        // This will insert the proposal into the internal proposal cache
+        let _ = group.group.proposal_message(update, vec![]).await;
+
+        snapshot_restore(group).await
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn snapshot_can_be_serialized_to_json_with_internals() {
+        snapshot_restore(test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await).await
+    }
+
+    #[cfg(feature = "serde")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn serde() {
+        let snapshot = super::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, 5).await;
+
+        let encoded = serde_json::to_string_pretty(&snapshot).unwrap();
+        let decoded = serde_json::from_str(&encoded).unwrap();
+
+        assert_eq!(snapshot, decoded);
+    }
+}
diff --git a/src/group/state.rs b/src/group/state.rs
new file mode 100644
index 0000000..4d97a04
--- /dev/null
+++ b/src/group/state.rs
@@ -0,0 +1,43 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::{
+ confirmation_tag::ConfirmationTag, proposal::ReInitProposal,
+ transcript_hash::InterimTranscriptHash,
+};
+use crate::group::{GroupContext, TreeKemPublic};
+
+/// In-memory state held by a group (reached as `group.state` in tests). It holds the group
+/// context, the public ratchet tree, transcript/confirmation data and any pending proposals
+/// or ReInit.
+#[derive(Clone, Debug, PartialEq)]
+#[non_exhaustive]
+pub struct GroupState {
+    /// Proposal cache. Only compiled with the `by_ref_proposal` feature; `new` creates it
+    /// from the context's protocol version and group id.
+    #[cfg(feature = "by_ref_proposal")]
+    pub(crate) proposals: crate::group::ProposalCache,
+    /// Group context (protocol version, cipher suite, group id, epoch, hashes, extensions).
+    pub(crate) context: GroupContext,
+    /// Public view of the ratchet tree.
+    pub(crate) public_tree: TreeKemPublic,
+    /// Interim transcript hash stored with this state.
+    pub(crate) interim_transcript_hash: InterimTranscriptHash,
+    /// ReInit proposal, if one is pending; `new` initializes this to `None`.
+    pub(crate) pending_reinit: Option<ReInitProposal>,
+    /// Confirmation tag stored with this state.
+    pub(crate) confirmation_tag: ConfirmationTag,
+}
+
+impl GroupState {
+    /// Assembles a group state from its parts.
+    ///
+    /// No ReInit is pending. When `by_ref_proposal` is enabled, an empty proposal cache is
+    /// created for the context's protocol version and group id.
+    pub(crate) fn new(
+        context: GroupContext,
+        current_tree: TreeKemPublic,
+        interim_transcript_hash: InterimTranscriptHash,
+        confirmation_tag: ConfirmationTag,
+    ) -> Self {
+        // Built before `context` is moved into the struct below.
+        #[cfg(feature = "by_ref_proposal")]
+        let proposals =
+            crate::group::ProposalCache::new(context.protocol_version, context.group_id.clone());
+
+        Self {
+            #[cfg(feature = "by_ref_proposal")]
+            proposals,
+            public_tree: current_tree,
+            pending_reinit: None,
+            interim_transcript_hash,
+            confirmation_tag,
+            context,
+        }
+    }
+}
diff --git a/src/group/state_repo.rs b/src/group/state_repo.rs
new file mode 100644
index 0000000..6e33b0a
--- /dev/null
+++ b/src/group/state_repo.rs
@@ -0,0 +1,573 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::{group::PriorEpoch, key_package::KeyPackageRef};
+
+use alloc::collections::VecDeque;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode};
+use mls_rs_core::group::{EpochRecord, GroupState};
+use mls_rs_core::{error::IntoAnyError, group::GroupStateStorage, key_package::KeyPackageStorage};
+
+use super::snapshot::Snapshot;
+
+#[cfg(feature = "psk")]
+use crate::group::ResumptionPsk;
+
+#[cfg(feature = "psk")]
+use mls_rs_core::psk::PreSharedKey;
+
+/// A set of changes to apply to a GroupStateStorage implementation. These changes MUST
+/// be made in a single transaction to avoid creating invalid states.
+#[derive(Default, Clone, Debug)]
+struct EpochStorageCommit {
+    /// Newly created prior epochs. `GroupStateRepository::insert` only appends the next
+    /// consecutive epoch id, so the queue is in ascending, gap-free id order.
+    pub(crate) inserts: VecDeque<PriorEpoch>,
+    /// Already-stored prior epochs that were loaded through `get_epoch_mut` (and may have
+    /// been modified). They are written back as updates.
+    pub(crate) updates: Vec<PriorEpoch>,
+}
+
+/// Stages prior-epoch changes for a single group in memory. `write_to_storage` persists them,
+/// together with the group snapshot, in one `GroupStateStorage::write` call.
+#[derive(Clone)]
+pub(crate) struct GroupStateRepository<S, K>
+where
+    S: GroupStateStorage,
+    K: KeyPackageStorage,
+{
+    /// Epoch inserts/updates not yet written to `storage`.
+    pending_commit: EpochStorageCommit,
+    /// Key package to delete from `key_package_repo` when writing. Per the comment on `new`,
+    /// it is `Some` when joining a group and `None` when restoring from a snapshot.
+    pending_key_package_removal: Option<KeyPackageRef>,
+    /// Id of the group whose epochs are managed; `insert` rejects epochs of other groups.
+    group_id: Vec<u8>,
+    /// Backing store for the group snapshot and the epoch records.
+    storage: S,
+    /// Store from which the used key package is deleted.
+    key_package_repo: K,
+}
+
+impl<S, K> Debug for GroupStateRepository<S, K>
+where
+    S: GroupStateStorage + Debug,
+    K: KeyPackageStorage + Debug,
+{
+    /// Prints every field. The raw group id is shown through `pretty_group_id` so it stays
+    /// readable.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let readable_group_id = mls_rs_core::debug::pretty_group_id(&self.group_id);
+
+        let mut builder = f.debug_struct("GroupStateRepository");
+        builder.field("pending_commit", &self.pending_commit);
+        builder.field("pending_key_package_removal", &self.pending_key_package_removal);
+        builder.field("group_id", &readable_group_id);
+        builder.field("storage", &self.storage);
+        builder.field("key_package_repo", &self.key_package_repo);
+        builder.finish()
+    }
+}
+
+impl<S, K> GroupStateRepository<S, K>
+where
+    S: GroupStateStorage,
+    K: KeyPackageStorage,
+{
+    /// Creates a repository for `group_id` with nothing staged.
+    ///
+    /// `key_package_to_remove` is deleted from `key_package_repo` by the next successful
+    /// [`Self::write_to_storage`]. This constructor currently never returns an error.
+    pub fn new(
+        group_id: Vec<u8>,
+        storage: S,
+        key_package_repo: K,
+        // Set to `None` if restoring from snapshot; set to `Some` when joining a group.
+        key_package_to_remove: Option<KeyPackageRef>,
+    ) -> Result<GroupStateRepository<S, K>, MlsError> {
+        Ok(GroupStateRepository {
+            group_id,
+            storage,
+            pending_key_package_removal: key_package_to_remove,
+            pending_commit: Default::default(),
+            key_package_repo,
+        })
+    }
+
+    /// Returns the highest known epoch id of this group.
+    ///
+    /// This is the newest staged insert if there is one; otherwise it is whatever
+    /// `storage.max_epoch_id` reports.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn find_max_id(&self) -> Result<Option<u64>, MlsError> {
+        if let Some(max) = self.pending_commit.inserts.back().map(|e| e.epoch_id()) {
+            Ok(Some(max))
+        } else {
+            self.storage
+                .max_epoch_id(&self.group_id)
+                .await
+                .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))
+        }
+    }
+
+    /// Looks up the resumption secret identified by `psk_id`.
+    ///
+    /// Staged inserts and updates only ever contain epochs of `self.group_id`. They are
+    /// therefore consulted only when the PSK names this group. Otherwise, or when the epoch
+    /// is not staged, the epoch is read from `storage` under `psk_id.psk_group_id`.
+    /// Returns `Ok(None)` when the epoch cannot be found.
+    ///
+    /// # Errors
+    /// Storage failures and epoch decoding failures.
+    #[cfg(feature = "psk")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn resumption_secret(
+        &self,
+        psk_id: &ResumptionPsk,
+    ) -> Result<Option<PreSharedKey>, MlsError> {
+        // The local caches are keyed by epoch id only. Without this check, a PSK naming a
+        // different group could be answered with this group's secret.
+        if psk_id.psk_group_id.0 == self.group_id {
+            // Search the local inserts cache
+            if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) {
+                if psk_id.psk_epoch >= min {
+                    // `try_from` instead of `as`: on 32-bit targets a truncating cast could
+                    // map a huge epoch id onto an unrelated staged epoch.
+                    return Ok(usize::try_from(psk_id.psk_epoch - min)
+                        .ok()
+                        .and_then(|index| self.pending_commit.inserts.get(index))
+                        .map(|e| e.secrets.resumption_secret.clone()));
+                }
+            }
+
+            // Search the local updates cache
+            if let Some(pending) = self.find_pending(psk_id.psk_epoch) {
+                return Ok(Some(
+                    self.pending_commit.updates[pending]
+                        .secrets
+                        .resumption_secret
+                        .clone(),
+                ));
+            }
+        }
+
+        // Search the stored cache
+        self.storage
+            .epoch(&psk_id.psk_group_id.0, psk_id.psk_epoch)
+            .await
+            .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
+            .map(|e| Ok(PriorEpoch::mls_decode(&mut &*e)?.secrets.resumption_secret))
+            .transpose()
+    }
+
+    /// Returns a mutable reference to prior epoch `epoch_id` of this group, or `Ok(None)`.
+    ///
+    /// Staged inserts and updates are searched first. An epoch found only in `storage` is
+    /// decoded and appended to the staged updates. Changes made through the returned
+    /// reference are therefore persisted by the next [`Self::write_to_storage`].
+    ///
+    /// # Errors
+    /// Storage failures and epoch decoding failures.
+    #[cfg(feature = "private_message")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn get_epoch_mut(
+        &mut self,
+        epoch_id: u64,
+    ) -> Result<Option<&mut PriorEpoch>, MlsError> {
+        // Search the local inserts cache
+        if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) {
+            if epoch_id >= min {
+                // `try_from` instead of `as`: on 32-bit targets a truncating cast could map a
+                // huge epoch id onto an unrelated staged epoch.
+                return Ok(match usize::try_from(epoch_id - min) {
+                    Ok(index) => self.pending_commit.inserts.get_mut(index),
+                    Err(_) => None,
+                });
+            }
+        }
+
+        // Look in the cached updates map, and if not found look in disk storage
+        // and insert into the updates map for future caching
+        match self.find_pending(epoch_id) {
+            Some(i) => self.pending_commit.updates.get_mut(i).map(Ok),
+            None => self
+                .storage
+                .epoch(&self.group_id, epoch_id)
+                .await
+                .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
+                .and_then(|epoch| {
+                    PriorEpoch::mls_decode(&mut &*epoch)
+                        .map(|epoch| {
+                            self.pending_commit.updates.push(epoch);
+                            self.pending_commit.updates.last_mut()
+                        })
+                        .transpose()
+                }),
+        }
+        .transpose()
+        .map_err(Into::into)
+    }
+
+    /// Stages `epoch` as the newest prior epoch of this group.
+    ///
+    /// # Errors
+    /// - `GroupIdMismatch` if `epoch` belongs to another group.
+    /// - `InvalidEpoch` if its id is not exactly one greater than the highest known id. Any
+    ///   id is accepted when none is known yet.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn insert(&mut self, epoch: PriorEpoch) -> Result<(), MlsError> {
+        if epoch.group_id() != self.group_id {
+            return Err(MlsError::GroupIdMismatch);
+        }
+
+        let epoch_id = epoch.epoch_id();
+
+        if let Some(expected_id) = self.find_max_id().await?.map(|id| id + 1) {
+            if epoch_id != expected_id {
+                return Err(MlsError::InvalidEpoch);
+            }
+        }
+
+        self.pending_commit.inserts.push_back(epoch);
+
+        Ok(())
+    }
+
+    /// Persists `group_snapshot` and all staged epoch inserts/updates in one `storage.write`
+    /// call, then deletes the pending key package, if any.
+    ///
+    /// The staged epochs are dropped as soon as `storage.write` succeeds. A later failure to
+    /// delete the key package therefore cannot cause the same epochs to be written twice.
+    /// The key-package removal stays pending until its deletion succeeds, and is not
+    /// repeated after that.
+    ///
+    /// # Errors
+    /// Encoding failures, `GroupStorageError`, or `KeyPackageRepoError`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> {
+        let inserts = self
+            .pending_commit
+            .inserts
+            .iter()
+            .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?)))
+            .collect::<Result<_, MlsError>>()?;
+
+        let updates = self
+            .pending_commit
+            .updates
+            .iter()
+            .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?)))
+            .collect::<Result<_, MlsError>>()?;
+
+        let group_state = GroupState {
+            data: group_snapshot.mls_encode_to_vec()?,
+            id: group_snapshot.state.context.group_id,
+        };
+
+        self.storage
+            .write(group_state, inserts, updates)
+            .await
+            .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?;
+
+        // The epochs have been handed to `storage`; forget them before anything else can fail.
+        self.pending_commit.inserts.clear();
+        self.pending_commit.updates.clear();
+
+        if let Some(ref key_package_ref) = self.pending_key_package_removal {
+            self.key_package_repo
+                .delete(key_package_ref)
+                .await
+                .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;
+        }
+
+        // Deleted successfully (or nothing to delete): don't re-issue it on every later write.
+        self.pending_key_package_removal = None;
+
+        Ok(())
+    }
+
+    /// Index of staged update `epoch_id` in `pending_commit.updates`, if present.
+    #[cfg(any(feature = "psk", feature = "private_message"))]
+    fn find_pending(&self, epoch_id: u64) -> Option<usize> {
+        self.pending_commit
+            .updates
+            .iter()
+            .position(|ep| ep.context.epoch == epoch_id)
+    }
+}
+
+// Tests for `GroupStateRepository`: staging of epoch inserts/updates, their persistence via
+// `write_to_storage`, retention behaviour of the in-memory storage, and key package removal.
+#[cfg(test)]
+mod tests {
+    use alloc::vec;
+    use mls_rs_codec::MlsEncode;
+
+    use crate::{
+        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+        group::{
+            epoch::{test_utils::get_test_epoch_with_id, SenderDataSecret},
+            test_utils::{random_bytes, test_member, TEST_GROUP},
+            PskGroupId, ResumptionPSKUsage,
+        },
+        storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage},
+    };
+
+    use super::*;
+
+    /// Repository for `TEST_GROUP` over fresh in-memory storages, with no key package
+    /// pending removal. `retention_limit` is passed to
+    /// `InMemoryGroupStateStorage::with_max_epoch_retention`.
+    fn test_group_state_repo(
+        retention_limit: usize,
+    ) -> GroupStateRepository<InMemoryGroupStateStorage, InMemoryKeyPackageStorage> {
+        GroupStateRepository::new(
+            TEST_GROUP.to_vec(),
+            InMemoryGroupStateStorage::new()
+                .with_max_epoch_retention(retention_limit)
+                .unwrap(),
+            InMemoryKeyPackageStorage::default(),
+            None,
+        )
+        .unwrap()
+    }
+
+    /// Test `PriorEpoch` of `TEST_GROUP` with id `epoch_id`.
+    fn test_epoch(epoch_id: u64) -> PriorEpoch {
+        get_test_epoch_with_id(TEST_GROUP.to_vec(), TEST_CIPHER_SUITE, epoch_id)
+    }
+
+    /// Test group snapshot for `epoch_id`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn test_snapshot(epoch_id: u64) -> Snapshot {
+        crate::group::snapshot::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await
+    }
+
+    /// An inserted epoch is staged in memory only. It can be read back through
+    /// `resumption_secret` and `get_epoch_mut`, and `write_to_storage` moves it into storage.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_epoch_inserts() {
+        let mut test_repo = test_group_state_repo(1);
+        let test_epoch = test_epoch(0);
+
+        test_repo.insert(test_epoch.clone()).await.unwrap();
+
+        // Check the in-memory state
+        assert_eq!(
+            test_repo.pending_commit.inserts.back().unwrap(),
+            &test_epoch
+        );
+
+        assert!(test_repo.pending_commit.updates.is_empty());
+
+        // Nothing has reached storage yet. The lock API differs by feature: with `std` it
+        // returns a `Result`, without `std` the guard is returned directly.
+        #[cfg(feature = "std")]
+        assert!(test_repo.storage.inner.lock().unwrap().is_empty());
+        #[cfg(not(feature = "std"))]
+        assert!(test_repo.storage.inner.lock().is_empty());
+
+        let psk_id = ResumptionPsk {
+            psk_epoch: 0,
+            psk_group_id: PskGroupId(test_repo.group_id.clone()),
+            usage: ResumptionPSKUsage::Application,
+        };
+
+        // Make sure you can recall an epoch sitting as a pending insert
+        let resumption = test_repo.resumption_secret(&psk_id).await.unwrap();
+        let prior_epoch = test_repo.get_epoch_mut(0).await.unwrap().cloned();
+
+        assert_eq!(
+            prior_epoch.clone().unwrap().secrets.resumption_secret,
+            resumption.unwrap()
+        );
+
+        assert_eq!(prior_epoch.unwrap(), test_epoch);
+
+        // Write to the storage
+        let snapshot = test_snapshot(test_epoch.epoch_id()).await;
+        test_repo.write_to_storage(snapshot.clone()).await.unwrap();
+
+        // Make sure the memory cache cleared
+        assert!(test_repo.pending_commit.inserts.is_empty());
+        assert!(test_repo.pending_commit.updates.is_empty());
+
+        // Make sure the storage was written
+        #[cfg(feature = "std")]
+        let storage = test_repo.storage.inner.lock().unwrap();
+        #[cfg(not(feature = "std"))]
+        let storage = test_repo.storage.inner.lock();
+
+        assert_eq!(storage.len(), 1);
+
+        let stored = storage.get(TEST_GROUP).unwrap();
+
+        // The stored group state is exactly the encoded snapshot...
+        assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());
+
+        // ...and the staged epoch became the only stored epoch record.
+        assert_eq!(stored.epoch_data.len(), 1);
+
+        assert_eq!(
+            stored.epoch_data.back().unwrap(),
+            &EpochRecord::new(
+                test_epoch.epoch_id(),
+                test_epoch.mls_encode_to_vec().unwrap()
+            )
+        );
+    }
+
+    /// An epoch loaded from storage through `get_epoch_mut` is staged as an update. Modifying
+    /// it and writing again replaces the stored record.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_updates() {
+        let mut test_repo = test_group_state_repo(2);
+        let test_epoch_0 = test_epoch(0);
+
+        test_repo.insert(test_epoch_0.clone()).await.unwrap();
+
+        test_repo
+            .write_to_storage(test_snapshot(0).await)
+            .await
+            .unwrap();
+
+        // Update the stored epoch
+        let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap();
+        assert_eq!(to_update, &test_epoch_0);
+
+        let new_sender_secret = random_bytes(32);
+        to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret);
+        let to_update = to_update.clone();
+
+        // Loading from storage staged exactly one update and no inserts.
+        assert_eq!(test_repo.pending_commit.updates.len(), 1);
+        assert!(test_repo.pending_commit.inserts.is_empty());
+
+        assert_eq!(
+            test_repo.pending_commit.updates.first().unwrap(),
+            &to_update
+        );
+
+        // Make sure you can access an epoch pending update
+        let psk_id = ResumptionPsk {
+            psk_epoch: 0,
+            psk_group_id: PskGroupId(test_repo.group_id.clone()),
+            usage: ResumptionPSKUsage::Application,
+        };
+
+        let owned = test_repo.resumption_secret(&psk_id).await.unwrap();
+        assert_eq!(owned.as_ref(), Some(&to_update.secrets.resumption_secret));
+
+        // Write the update to storage
+        let snapshot = test_snapshot(1).await;
+        test_repo.write_to_storage(snapshot.clone()).await.unwrap();
+
+        assert!(test_repo.pending_commit.updates.is_empty());
+        assert!(test_repo.pending_commit.inserts.is_empty());
+
+        // Make sure the storage was written
+        #[cfg(feature = "std")]
+        let storage = test_repo.storage.inner.lock().unwrap();
+        #[cfg(not(feature = "std"))]
+        let storage = test_repo.storage.inner.lock();
+
+        assert_eq!(storage.len(), 1);
+
+        let stored = storage.get(TEST_GROUP).unwrap();
+
+        assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());
+
+        // The update replaced the existing record instead of adding a second one.
+        assert_eq!(stored.epoch_data.len(), 1);
+
+        assert_eq!(
+            stored.epoch_data.back().unwrap(),
+            &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap())
+        );
+    }
+
+    /// A staged update and a new insert are persisted together by one `write_to_storage`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_insert_and_update() {
+        let mut test_repo = test_group_state_repo(2);
+        let test_epoch_0 = test_epoch(0);
+
+        test_repo.insert(test_epoch_0).await.unwrap();
+
+        test_repo
+            .write_to_storage(test_snapshot(0).await)
+            .await
+            .unwrap();
+
+        // Update the stored epoch
+        let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap();
+        let new_sender_secret = random_bytes(32);
+        to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret);
+        let to_update = to_update.clone();
+
+        // Insert another epoch
+        let test_epoch_1 = test_epoch(1);
+        test_repo.insert(test_epoch_1.clone()).await.unwrap();
+
+        test_repo
+            .write_to_storage(test_snapshot(1).await)
+            .await
+            .unwrap();
+
+        assert!(test_repo.pending_commit.inserts.is_empty());
+        assert!(test_repo.pending_commit.updates.is_empty());
+
+        // Make sure the storage was written
+        #[cfg(feature = "std")]
+        let storage = test_repo.storage.inner.lock().unwrap();
+        #[cfg(not(feature = "std"))]
+        let storage = test_repo.storage.inner.lock();
+
+        assert_eq!(storage.len(), 1);
+
+        let stored = storage.get(TEST_GROUP).unwrap();
+
+        // Retention is 2, so both the updated epoch 0 and the new epoch 1 are kept, in order.
+        assert_eq!(stored.epoch_data.len(), 2);
+
+        assert_eq!(
+            stored.epoch_data.front().unwrap(),
+            &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap())
+        );
+
+        assert_eq!(
+            stored.epoch_data.back().unwrap(),
+            &EpochRecord::new(
+                test_epoch_1.epoch_id(),
+                test_epoch_1.mls_encode_to_vec().unwrap()
+            )
+        );
+    }
+
+    /// Ten epochs written in one batch (retention 10) can all be read back by id.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_many_epochs_in_storage() {
+        let epochs = (0..10).map(test_epoch).collect::<Vec<_>>();
+
+        let mut test_repo = test_group_state_repo(10);
+
+        for epoch in epochs.iter().cloned() {
+            test_repo.insert(epoch).await.unwrap()
+        }
+
+        test_repo
+            .write_to_storage(test_snapshot(9).await)
+            .await
+            .unwrap();
+
+        for mut epoch in epochs {
+            let res = test_repo.get_epoch_mut(epoch.epoch_id()).await.unwrap();
+
+            assert_eq!(res, Some(&mut epoch));
+        }
+    }
+
+    /// After a write, the storage lists exactly this group's id.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_stored_groups_list() {
+        let mut test_repo = test_group_state_repo(2);
+        let test_epoch_0 = test_epoch(0);
+
+        test_repo.insert(test_epoch_0.clone()).await.unwrap();
+
+        test_repo
+            .write_to_storage(test_snapshot(0).await)
+            .await
+            .unwrap();
+
+        assert_eq!(
+            test_repo.storage.stored_groups(),
+            vec![test_epoch_0.context.group_id]
+        )
+    }
+
+    /// With retention 1, epoch 0 is no longer reachable after epochs 0 and 1 were written.
+    // NOTE(review): the new repository reuses `repo.storage`, which was already created with a
+    // retention limit of 1, so no limit is actually reduced here. The test effectively checks
+    // that epoch 0 was trimmed when writing. Confirm this is the intended coverage.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn reducing_retention_limit_takes_effect_on_epoch_access() {
+        let mut repo = test_group_state_repo(1);
+
+        repo.insert(test_epoch(0)).await.unwrap();
+        repo.insert(test_epoch(1)).await.unwrap();
+
+        repo.write_to_storage(test_snapshot(0).await).await.unwrap();
+
+        // Fresh repository (empty caches) over the same storage.
+        let mut repo = GroupStateRepository {
+            storage: repo.storage,
+            ..test_group_state_repo(1)
+        };
+
+        let res = repo.get_epoch_mut(0).await.unwrap();
+
+        assert!(res.is_none());
+    }
+
+    /// With retention 1, two successive writes leave a single stored epoch record.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn in_memory_storage_obeys_retention_limit_after_saving() {
+        let mut repo = test_group_state_repo(1);
+
+        repo.insert(test_epoch(0)).await.unwrap();
+        repo.write_to_storage(test_snapshot(0).await).await.unwrap();
+        repo.insert(test_epoch(1)).await.unwrap();
+        repo.write_to_storage(test_snapshot(1).await).await.unwrap();
+
+        #[cfg(feature = "std")]
+        let lock = repo.storage.inner.lock().unwrap();
+        #[cfg(not(feature = "std"))]
+        let lock = repo.storage.inner.lock();
+
+        assert_eq!(lock.get(TEST_GROUP).unwrap().epoch_data.len(), 1);
+    }
+
+    /// The key package passed to `new` is deleted from the key package storage by
+    /// `write_to_storage`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn used_key_package_is_deleted() {
+        let key_package_repo = InMemoryKeyPackageStorage::default();
+
+        let key_package = test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"member")
+            .await
+            .0;
+
+        let (id, data) = key_package.to_storage().unwrap();
+
+        key_package_repo.insert(id, data);
+
+        let mut repo = GroupStateRepository::new(
+            TEST_GROUP.to_vec(),
+            InMemoryGroupStateStorage::new(),
+            key_package_repo,
+            Some(key_package.reference.clone()),
+        )
+        .unwrap();
+
+        // Precondition: the key package is present before the write.
+        repo.key_package_repo.get(&key_package.reference).unwrap();
+
+        repo.write_to_storage(test_snapshot(4).await).await.unwrap();
+
+        assert!(repo.key_package_repo.get(&key_package.reference).is_none());
+    }
+}
diff --git a/src/group/state_repo_light.rs b/src/group/state_repo_light.rs
new file mode 100644
index 0000000..76d1fb6
--- /dev/null
+++ b/src/group/state_repo_light.rs
@@ -0,0 +1,132 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::key_package::KeyPackageRef;
+
+use alloc::vec::Vec;
+use mls_rs_codec::MlsEncode;
+use mls_rs_core::{
+ error::IntoAnyError,
+ group::{GroupState, GroupStateStorage},
+ key_package::KeyPackageStorage,
+};
+
+use super::snapshot::Snapshot;
+
+/// Lightweight group state repository that stores only the group snapshot. It keeps no prior
+/// epochs: `write_to_storage` always passes empty epoch insert/update lists.
+#[derive(Debug, Clone)]
+pub(crate) struct GroupStateRepository<S, K>
+where
+    S: GroupStateStorage,
+    K: KeyPackageStorage,
+{
+    /// Key package to delete from `key_package_repo` when writing. Per `new`, it is `Some`
+    /// when joining a group and `None` when restoring from a snapshot.
+    pending_key_package_removal: Option<KeyPackageRef>,
+    /// Backing store for the group snapshot.
+    storage: S,
+    /// Store from which the used key package is deleted.
+    key_package_repo: K,
+}
+
+impl<S, K> GroupStateRepository<S, K>
+where
+    S: GroupStateStorage,
+    K: KeyPackageStorage,
+{
+    /// Creates a repository over `storage` and `key_package_repo`.
+    ///
+    /// `key_package_to_remove` is deleted by the next successful
+    /// [`Self::write_to_storage`]. This constructor currently never returns an error.
+    pub fn new(
+        storage: S,
+        key_package_repo: K,
+        // Set to `None` if restoring from snapshot; set to `Some` when joining a group.
+        key_package_to_remove: Option<KeyPackageRef>,
+    ) -> Result<GroupStateRepository<S, K>, MlsError> {
+        Ok(GroupStateRepository {
+            storage,
+            pending_key_package_removal: key_package_to_remove,
+            key_package_repo,
+        })
+    }
+
+    /// Persists `group_snapshot` with `storage.write`, passing no epoch records, then deletes
+    /// the pending key package, if any.
+    ///
+    /// The removal is cleared only after the deletion succeeds. A failed deletion is
+    /// therefore retried on the next call, while a successful one is not repeated on every
+    /// later write.
+    ///
+    /// # Errors
+    /// Encoding failures, `GroupStorageError`, or `KeyPackageRepoError`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> {
+        let group_state = GroupState {
+            data: group_snapshot.mls_encode_to_vec()?,
+            id: group_snapshot.state.context.group_id,
+        };
+
+        self.storage
+            .write(group_state, Vec::new(), Vec::new())
+            .await
+            .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?;
+
+        if let Some(ref key_package_ref) = self.pending_key_package_removal {
+            self.key_package_repo
+                .delete(key_package_ref)
+                .await
+                .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;
+        }
+
+        // Deleted successfully (or nothing to delete): don't re-issue it on every later write.
+        self.pending_key_package_removal = None;
+
+        Ok(())
+    }
+}
+
+// Tests for the lightweight repository: group registration on write and key package removal.
+#[cfg(test)]
+mod tests {
+    use crate::{
+        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+        group::{
+            snapshot::{test_utils::get_test_snapshot, Snapshot},
+            test_utils::{test_member, TEST_GROUP},
+        },
+        storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage},
+    };
+
+    use alloc::vec;
+
+    use super::GroupStateRepository;
+
+    /// Test group snapshot for `epoch_id`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn test_snapshot(epoch_id: u64) -> Snapshot {
+        get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await
+    }
+
+    /// Writing a snapshot registers its group id with the group state storage.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_stored_groups_list() {
+        let group_storage = InMemoryGroupStateStorage::default();
+        let key_package_storage = InMemoryKeyPackageStorage::default();
+
+        let mut repo = GroupStateRepository::new(group_storage, key_package_storage, None).unwrap();
+
+        let snapshot = test_snapshot(0).await;
+        repo.write_to_storage(snapshot).await.unwrap();
+
+        let expected_groups = vec![TEST_GROUP];
+        assert_eq!(repo.storage.stored_groups(), expected_groups)
+    }
+
+    /// A key package handed to `new` is removed from the key package storage on write.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn used_key_package_is_deleted() {
+        let (key_package, _) =
+            test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"member").await;
+        let reference = key_package.reference.clone();
+
+        let key_package_repo = InMemoryKeyPackageStorage::default();
+        let (id, data) = key_package.to_storage().unwrap();
+        key_package_repo.insert(id, data);
+
+        let mut repo = GroupStateRepository::new(
+            InMemoryGroupStateStorage::default(),
+            key_package_repo,
+            Some(reference.clone()),
+        )
+        .unwrap();
+
+        // Present before the write...
+        assert!(repo.key_package_repo.get(&reference).is_some());
+
+        repo.write_to_storage(test_snapshot(4).await).await.unwrap();
+
+        // ...and gone afterwards.
+        assert!(repo.key_package_repo.get(&reference).is_none());
+    }
+}
diff --git a/src/group/test_utils.rs b/src/group/test_utils.rs
new file mode 100644
index 0000000..764d5e6
--- /dev/null
+++ b/src/group/test_utils.rs
@@ -0,0 +1,521 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use core::ops::{Deref, DerefMut};
+
+use alloc::format;
+use rand::RngCore;
+
+use super::*;
+use crate::{
+ client::{
+ test_utils::{
+ test_client_with_key_pkg, test_client_with_key_pkg_custom, TEST_CIPHER_SUITE,
+ TEST_PROTOCOL_VERSION,
+ },
+ MlsError,
+ },
+ client_builder::test_utils::{TestClientBuilder, TestClientConfig},
+ crypto::test_utils::test_cipher_suite_provider,
+ extension::ExtensionType,
+ identity::basic::BasicIdentityProvider,
+ identity::test_utils::get_test_signing_identity,
+ key_package::{KeyPackageGeneration, KeyPackageGenerator},
+ mls_rules::{CommitOptions, DefaultMlsRules},
+ tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
+};
+
+use crate::extension::RequiredCapabilitiesExt;
+
+#[cfg(not(feature = "by_ref_proposal"))]
+use crate::crypto::HpkePublicKey;
+
+/// Group id shared by test fixtures (e.g. `test_group_custom`, `get_test_group_context`).
+pub const TEST_GROUP: &[u8] = b"group";
+
+/// Test wrapper around a `Group` built with `TestClientConfig`. It adds helpers for joining,
+/// proposing and processing messages.
+#[derive(Clone)]
+pub(crate) struct TestGroup {
+    /// The wrapped group; tests also use it directly (e.g. `test_group.group.commit(..)`).
+    pub group: Group<TestClientConfig>,
+}
+
+impl TestGroup {
+    /// Produces a proposal message for `proposal`, panicking on error.
+    #[cfg(feature = "external_client")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn propose(&mut self, proposal: Proposal) -> MlsMessage {
+        let message = self.group.proposal_message(proposal, vec![]).await;
+        message.unwrap()
+    }
+
+    /// Builds an update proposal with `Group::update_proposal(None, None)`, panicking on
+    /// error.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn update_proposal(&mut self) -> Proposal {
+        let proposal = self.group.update_proposal(None, None).await;
+        proposal.unwrap()
+    }
+
+    /// Adds a new member called `name` to this group.
+    ///
+    /// Returns the new member's view of the group and the commit message that added it. The
+    /// commit is built and applied to `self`, panicking on failure. `config` is then applied
+    /// to the new client's configuration before it joins from the welcome message. When
+    /// `custom_kp` is true, `config` is also handed to `test_client_with_key_pkg_custom` while
+    /// creating the member's client and key package.
+    ///
+    /// # Errors
+    /// The error returned by `Group::join`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn join_with_custom_config<F>(
+        &mut self,
+        name: &str,
+        custom_kp: bool,
+        mut config: F,
+    ) -> Result<(TestGroup, MlsMessage), MlsError>
+    where
+        F: FnMut(&mut TestClientConfig),
+    {
+        let version = self.group.protocol_version();
+        let suite = self.group.cipher_suite();
+
+        let (mut new_client, new_key_package) = if custom_kp {
+            test_client_with_key_pkg_custom(version, suite, name, &mut config).await
+        } else {
+            test_client_with_key_pkg(version, suite, name).await
+        };
+
+        // Add the new member to the group and apply the commit on the committer's side.
+        let commit_output = self
+            .group
+            .commit_builder()
+            .add_member(new_key_package)
+            .unwrap()
+            .build()
+            .await
+            .unwrap();
+
+        self.group.apply_pending_commit().await.unwrap();
+
+        config(&mut new_client.config);
+
+        // Group from the new member's perspective.
+        let (joined_group, _) = Group::join(
+            &commit_output.welcome_messages[0],
+            commit_output.ratchet_tree,
+            new_client.config.clone(),
+            new_client.signer.clone().unwrap(),
+        )
+        .await?;
+
+        Ok((TestGroup { group: joined_group }, commit_output.commit_message))
+    }
+
+    /// `join_with_custom_config` with a default key package and an untouched client config;
+    /// panics on error.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn join(&mut self, name: &str) -> (TestGroup, MlsMessage) {
+        let keep_config = |_: &mut TestClientConfig| ();
+
+        self.join_with_custom_config(name, false, keep_config)
+            .await
+            .unwrap()
+    }
+
+    /// Applies this group's pending commit.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn process_pending_commit(
+        &mut self,
+    ) -> Result<CommitMessageDescription, MlsError> {
+        let group = &mut self.group;
+        group.apply_pending_commit().await
+    }
+
+    /// Feeds `message` to the group's incoming-message processing.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn process_message(
+        &mut self,
+        message: MlsMessage,
+    ) -> Result<ReceivedMessage, MlsError> {
+        let group = &mut self.group;
+        group.process_incoming_message(message).await
+    }
+
+    /// Signs `content` as this member for the `PublicMessage` wire format and passes it
+    /// through `format_for_wire`, panicking on error.
+    #[cfg(feature = "private_message")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn make_plaintext(&mut self, content: Content) -> MlsMessage {
+        let sender = Sender::Member(*self.group.private_tree.self_index);
+
+        let signed = AuthenticatedContent::new_signed(
+            &self.group.cipher_suite_provider,
+            &self.group.state.context,
+            sender,
+            content,
+            &self.group.signer,
+            WireFormat::PublicMessage,
+            Vec::new(),
+        )
+        .await
+        .unwrap();
+
+        let message = self.group.format_for_wire(signed).await;
+        message.unwrap()
+    }
+}
+
+/// Builds a `GroupContext` for `TEST_GROUP` at `epoch`, using `TEST_PROTOCOL_VERSION` and no
+/// extensions. The tree hash and confirmed transcript hash are hashes of fixed byte strings.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn get_test_group_context(epoch: u64, cipher_suite: CipherSuite) -> GroupContext {
+    let provider = test_cipher_suite_provider(cipher_suite);
+
+    let tree_hash = provider.hash(&[1, 2, 3]).await.unwrap();
+    let transcript_hash = provider.hash(&[3, 2, 1]).await.unwrap();
+
+    GroupContext {
+        protocol_version: TEST_PROTOCOL_VERSION,
+        cipher_suite,
+        group_id: TEST_GROUP.to_vec(),
+        epoch,
+        tree_hash,
+        confirmed_transcript_hash: transcript_hash.into(),
+        extensions: ExtensionList::from(vec![]),
+    }
+}
+
+/// Builds a `GroupContext` for an arbitrary `group_id` and `epoch`, using
+/// `TEST_PROTOCOL_VERSION`, empty tree/transcript hashes and no extensions.
+#[cfg(feature = "prior_epoch")]
+pub(crate) fn get_test_group_context_with_id(
+    group_id: Vec<u8>,
+    epoch: u64,
+    cipher_suite: CipherSuite,
+) -> GroupContext {
+    let no_extensions = ExtensionList::from(vec![]);
+    let empty_transcript = ConfirmedTranscriptHash::from(vec![]);
+
+    GroupContext {
+        extensions: no_extensions,
+        confirmed_transcript_hash: empty_transcript,
+        tree_hash: Vec::new(),
+        epoch,
+        group_id,
+        cipher_suite,
+        protocol_version: TEST_PROTOCOL_VERSION,
+    }
+}
+
+/// Group context extensions used by the test groups: only a default
+/// `RequiredCapabilitiesExt`. Panics if it cannot be inserted.
+pub(crate) fn group_extensions() -> ExtensionList {
+    let mut list = ExtensionList::new();
+    list.set_from(RequiredCapabilitiesExt::default()).unwrap();
+    list
+}
+
+pub(crate) fn lifetime() -> Lifetime {
+ Lifetime::years(1).unwrap()
+}
+
+/// Generates a key package for the member `identifier` and returns it with the signing key.
+///
+/// The key package uses the basic identity provider, the test capabilities, empty extension
+/// lists and `lifetime()`. Panics if generation fails.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn test_member(
+    protocol_version: ProtocolVersion,
+    cipher_suite: CipherSuite,
+    identifier: &[u8],
+) -> (KeyPackageGeneration, SignatureSecretKey) {
+    let (signing_identity, signing_key) = get_test_signing_identity(cipher_suite, identifier).await;
+    let provider = test_cipher_suite_provider(cipher_suite);
+
+    let generator = KeyPackageGenerator {
+        protocol_version,
+        cipher_suite_provider: &provider,
+        signing_identity: &signing_identity,
+        signing_key: &signing_key,
+        identity_provider: &BasicIdentityProvider,
+    };
+
+    let generated = generator
+        .generate(
+            lifetime(),
+            get_test_capabilities(),
+            ExtensionList::default(),
+            ExtensionList::default(),
+        )
+        .await
+        .unwrap();
+
+    (generated, signing_key)
+}
+
+/// Creates a single-member `TEST_GROUP` (member identity `b"member"`).
+///
+/// - `extension_types` is passed to the client builder's `extension_types`.
+/// - `leaf_extensions` (default: empty) goes into the member's leaf node.
+/// - `commit_options` (default: `CommitOptions::default()`) configures `DefaultMlsRules`.
+///
+/// All `ProtocolVersion::all()` versions are allowed and `protocol_version` is used. Panics
+/// if the group cannot be created.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn test_group_custom(
+    protocol_version: ProtocolVersion,
+    cipher_suite: CipherSuite,
+    extension_types: Vec<ExtensionType>,
+    leaf_extensions: Option<ExtensionList>,
+    commit_options: Option<CommitOptions>,
+) -> TestGroup {
+    let leaf_extensions = leaf_extensions.unwrap_or_default();
+    let commit_options = commit_options.unwrap_or_default();
+
+    let (signing_identity, secret_key) = get_test_signing_identity(cipher_suite, b"member").await;
+
+    // `signing_identity` has no later use, so it is moved instead of cloned.
+    let group = TestClientBuilder::new_for_test()
+        .leaf_node_extensions(leaf_extensions)
+        .mls_rules(DefaultMlsRules::default().with_commit_options(commit_options))
+        .extension_types(extension_types)
+        .protocol_versions(ProtocolVersion::all())
+        .used_protocol_version(protocol_version)
+        .signing_identity(signing_identity, secret_key, cipher_suite)
+        .build()
+        .create_group_with_id(TEST_GROUP.to_vec(), group_extensions())
+        .await
+        .unwrap();
+
+    TestGroup { group }
+}
+
+/// Default single-member test group: `test_group_custom` with no extra extension types,
+/// leaf extensions or commit options.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn test_group(
+    protocol_version: ProtocolVersion,
+    cipher_suite: CipherSuite,
+) -> TestGroup {
+    let no_extension_types: Vec<ExtensionType> = Vec::new();
+
+    test_group_custom(protocol_version, cipher_suite, no_extension_types, None, None).await
+}
+
+/// Creates a single-member `TEST_GROUP` (member identity `b"member"`) after letting `custom`
+/// adjust the client builder.
+///
+/// The signing identity is set after `custom` runs, so `custom` cannot replace it. Panics if
+/// the group cannot be created.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn test_group_custom_config<F>(
+    protocol_version: ProtocolVersion,
+    cipher_suite: CipherSuite,
+    custom: F,
+) -> TestGroup
+where
+    F: FnOnce(TestClientBuilder) -> TestClientBuilder,
+{
+    let (signing_identity, secret_key) = get_test_signing_identity(cipher_suite, b"member").await;
+
+    let client_builder = TestClientBuilder::new_for_test().used_protocol_version(protocol_version);
+
+    // `signing_identity` has no later use, so it is moved instead of cloned.
+    let group = custom(client_builder)
+        .signing_identity(signing_identity, secret_key, cipher_suite)
+        .build()
+        .create_group_with_id(TEST_GROUP.to_vec(), group_extensions())
+        .await
+        .unwrap();
+
+    TestGroup { group }
+}
+
+/// Builds a group with `num_members` members.
+///
+/// Member 0 creates the group and adds each further member (named `name {i}`). Every other
+/// existing member processes each commit. Entry `i` of the result is the view of the `i`-th
+/// member created.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn test_n_member_group(
+    protocol_version: ProtocolVersion,
+    cipher_suite: CipherSuite,
+    num_members: usize,
+) -> Vec<TestGroup> {
+    let mut members = vec![test_group(protocol_version, cipher_suite).await];
+
+    for index in 1..num_members {
+        let creator = &mut members[0];
+        let (joined, commit) = creator.join(&format!("name {index}")).await;
+
+        // Member 0 already applied its own commit.
+        process_commit(&mut members, commit, 0).await;
+        members.push(joined);
+    }
+
+    members
+}
+
+/// Delivers `commit` to every group in `groups` except the member whose index is `excluded`
+/// (e.g. the committer, as in `test_n_member_group`). Panics if any delivery fails.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn process_commit(groups: &mut [TestGroup], commit: MlsMessage, excluded: u32) {
+    for receiver in groups.iter_mut() {
+        if receiver.group.current_member_index() == excluded {
+            continue;
+        }
+
+        receiver.process_message(commit.clone()).await.unwrap();
+    }
+}
+
+/// Dummy 32-byte HPKE public key (the size of an X25519 key) in which every byte is
+/// `key_byte`.
+pub(crate) fn get_test_25519_key(key_byte: u8) -> HpkePublicKey {
+    let bytes: Vec<u8> = core::iter::repeat(key_byte).take(32).collect();
+    bytes.into()
+}
+
+/// Creates `n` clients and returns each client's view of one shared group, in client order.
+///
+/// Each client has identity `member{i}`, extension type 999 enabled and `leaf_extensions` in
+/// its leaf node. Client 0 creates a group with id `b"TEST GROUP"` and context `extensions`,
+/// then adds the others one at a time. Members that have already joined process each commit.
+/// Panics on any failure, including `n == 0`.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn get_test_groups_with_features(
+    n: usize,
+    extensions: ExtensionList,
+    leaf_extensions: ExtensionList,
+) -> Vec<Group<TestClientConfig>> {
+    let mut clients = Vec::new();
+
+    for index in 0..n {
+        let name = format!("member{index}");
+        let (identity, secret_key) =
+            get_test_signing_identity(TEST_CIPHER_SUITE, name.as_bytes()).await;
+
+        let client = TestClientBuilder::new_for_test()
+            .extension_type(999.into())
+            .leaf_node_extensions(leaf_extensions.clone())
+            .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
+            .build();
+
+        clients.push(client);
+    }
+
+    let creator = &clients[0];
+    let founding_group = creator
+        .create_group_with_id(b"TEST GROUP".to_vec(), extensions)
+        .await
+        .unwrap();
+
+    let mut groups = vec![founding_group];
+
+    for joiner in &clients[1..] {
+        let key_package = joiner.generate_key_package_message().await.unwrap();
+
+        let commit_output = groups[0]
+            .commit_builder()
+            .add_member(key_package)
+            .unwrap()
+            .build()
+            .await
+            .unwrap();
+
+        groups[0].apply_pending_commit().await.unwrap();
+
+        // Every member that joined earlier (not the committer) processes the commit.
+        for existing in &mut groups[1..] {
+            existing
+                .process_incoming_message(commit_output.commit_message.clone())
+                .await
+                .unwrap();
+        }
+
+        let joined = joiner
+            .join_group(None, &commit_output.welcome_messages[0])
+            .await
+            .unwrap()
+            .0;
+
+        groups.push(joined);
+    }
+
+    groups
+}
+
+/// Returns `count` bytes filled from `rand::thread_rng()`.
+pub fn random_bytes(count: usize) -> Vec<u8> {
+    let mut rng = rand::thread_rng();
+    let mut bytes = vec![0u8; count];
+    rng.fill_bytes(&mut bytes);
+    bytes
+}
+
+/// Test wrapper around a [`Group`] whose `MessageProcessor::update_key_schedule`
+/// does not forward to the wrapped group: it only records the secrets and the
+/// provisional state it is handed, so tests can inspect them.
+pub(crate) struct GroupWithoutKeySchedule {
+    inner: Group<TestClientConfig>,
+    /// Path secrets captured by the last `update_key_schedule` call.
+    pub secrets: Option<(TreeKemPrivate, PathSecret)>,
+    /// Provisional state captured by the last `update_key_schedule` call.
+    pub provisional_public_state: Option<ProvisionalState>,
+}
+
+impl Deref for GroupWithoutKeySchedule {
+    type Target = Group<TestClientConfig>;
+
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn deref(&self) -> &Group<TestClientConfig> {
+        let Self { inner, .. } = self;
+        inner
+    }
+}
+
+impl DerefMut for GroupWithoutKeySchedule {
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn deref_mut(&mut self) -> &mut Group<TestClientConfig> {
+        let Self { inner, .. } = self;
+        inner
+    }
+}
+
+#[cfg(feature = "rfc_compliant")]
+impl GroupWithoutKeySchedule {
+    /// Wraps a fresh test group for cipher suite `cs`, with nothing recorded yet.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn new(cs: CipherSuite) -> Self {
+        let wrapped = test_group(TEST_PROTOCOL_VERSION, cs).await;
+
+        Self {
+            secrets: None,
+            provisional_public_state: None,
+            inner: wrapped.group,
+        }
+    }
+}
+
+// `MessageProcessor` for the test wrapper: every hook forwards to the wrapped
+// `Group`, except `update_key_schedule`, which only stores the secrets and
+// provisional state it receives on `self` instead of forwarding them.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
+#[cfg_attr(
+ all(not(target_arch = "wasm32"), mls_build_async),
+ maybe_async::must_be_async
+)]
+impl MessageProcessor for GroupWithoutKeySchedule {
+ // All associated types are borrowed from the wrapped `Group`'s implementation.
+ type CipherSuiteProvider = <Group<TestClientConfig> as MessageProcessor>::CipherSuiteProvider;
+ type OutputType = <Group<TestClientConfig> as MessageProcessor>::OutputType;
+ type PreSharedKeyStorage = <Group<TestClientConfig> as MessageProcessor>::PreSharedKeyStorage;
+ type IdentityProvider = <Group<TestClientConfig> as MessageProcessor>::IdentityProvider;
+ type MlsRules = <Group<TestClientConfig> as MessageProcessor>::MlsRules;
+
+ fn group_state(&self) -> &GroupState {
+ self.inner.group_state()
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn group_state_mut(&mut self) -> &mut GroupState {
+ self.inner.group_state_mut()
+ }
+
+ fn mls_rules(&self) -> Self::MlsRules {
+ self.inner.mls_rules()
+ }
+
+ fn identity_provider(&self) -> Self::IdentityProvider {
+ self.inner.identity_provider()
+ }
+
+ fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider {
+ self.inner.cipher_suite_provider()
+ }
+
+ fn psk_storage(&self) -> Self::PreSharedKeyStorage {
+ self.inner.psk_storage()
+ }
+
+ fn can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool {
+ self.inner.can_continue_processing(provisional_state)
+ }
+
+ #[cfg(feature = "private_message")]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn min_epoch_available(&self) -> Option<u64> {
+ self.inner.min_epoch_available()
+ }
+
+ // Forwarded: the update path is still applied by the wrapped group, and the
+ // resulting secrets are what later reach `update_key_schedule` below.
+ async fn apply_update_path(
+ &mut self,
+ sender: LeafIndex,
+ update_path: &ValidatedUpdatePath,
+ provisional_state: &mut ProvisionalState,
+ ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
+ self.inner
+ .apply_update_path(sender, update_path, provisional_state)
+ .await
+ }
+
+ #[cfg(feature = "private_message")]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn process_ciphertext(
+ &mut self,
+ cipher_text: &PrivateMessage,
+ ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+ self.inner.process_ciphertext(cipher_text).await
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn verify_plaintext_authentication(
+ &self,
+ message: PublicMessage,
+ ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+ self.inner.verify_plaintext_authentication(message).await
+ }
+
+ // Deliberately NOT forwarded to `self.inner`: the interim transcript hash and
+ // confirmation tag are ignored, and the secrets / provisional state are only
+ // recorded on `self` so a test can inspect them. Always returns `Ok(())`.
+ async fn update_key_schedule(
+ &mut self,
+ secrets: Option<(TreeKemPrivate, PathSecret)>,
+ _interim_transcript_hash: InterimTranscriptHash,
+ _confirmation_tag: &ConfirmationTag,
+ provisional_public_state: ProvisionalState,
+ ) -> Result<(), MlsError> {
+ self.provisional_public_state = Some(provisional_public_state);
+ self.secrets = secrets;
+ Ok(())
+ }
+
+ // Fully-qualified call, presumably so the trait method is chosen over any
+ // inherent `self_index` on `Group` — NOTE(review): confirm.
+ #[cfg(feature = "private_message")]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn self_index(&self) -> Option<LeafIndex> {
+ <Group<TestClientConfig> as MessageProcessor>::self_index(&self.inner)
+ }
+}
diff --git a/src/group/transcript_hash.rs b/src/group/transcript_hash.rs
new file mode 100644
index 0000000..c336dfa
--- /dev/null
+++ b/src/group/transcript_hash.rs
@@ -0,0 +1,293 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{crypto::CipherSuiteProvider, error::IntoAnyError};
+
+use crate::{
+ client::MlsError,
+ group::{framing::FramedContent, MessageSignature},
+ WireFormat,
+};
+
+use super::{AuthenticatedContent, ConfirmationTag};
+
+/// Confirmed transcript hash: the hash of the previous interim transcript
+/// hash followed by the encoded `ConfirmedTranscriptHashInput` (wire format,
+/// framed content, signature) of a commit.
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct ConfirmedTranscriptHash(
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+    Vec<u8>,
+);
+
+impl Debug for ConfirmedTranscriptHash {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let pretty = mls_rs_core::debug::pretty_bytes(&self.0).named("ConfirmedTranscriptHash");
+        Debug::fmt(&pretty, f)
+    }
+}
+
+impl Deref for ConfirmedTranscriptHash {
+    type Target = Vec<u8>;
+
+    fn deref(&self) -> &Vec<u8> {
+        let Self(bytes) = self;
+        bytes
+    }
+}
+
+impl From<Vec<u8>> for ConfirmedTranscriptHash {
+    fn from(bytes: Vec<u8>) -> Self {
+        ConfirmedTranscriptHash(bytes)
+    }
+}
+
+impl ConfirmedTranscriptHash {
+    /// Computes `Hash(interim_transcript_hash || ConfirmedTranscriptHashInput)`
+    /// for `content`, where the input is the encoding of its wire format,
+    /// framed content and signature.
+    ///
+    /// # Errors
+    ///
+    /// Propagates any error from encoding the hash input, and maps hashing
+    /// failures to [`MlsError::CryptoProviderError`].
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn create<P: CipherSuiteProvider>(
+        cipher_suite_provider: &P,
+        interim_transcript_hash: &InterimTranscriptHash,
+        content: &AuthenticatedContent,
+    ) -> Result<Self, MlsError> {
+        // Wire layout of the hash input; fields are encoded in declaration order.
+        #[derive(Debug, MlsSize, MlsEncode)]
+        struct ConfirmedTranscriptHashInput<'a> {
+            wire_format: WireFormat,
+            content: &'a FramedContent,
+            signature: &'a MessageSignature,
+        }
+
+        let encoded_input = ConfirmedTranscriptHashInput {
+            wire_format: content.wire_format,
+            content: &content.content,
+            signature: &content.auth.signature,
+        }
+        .mls_encode_to_vec()?;
+
+        let mut hash_input =
+            Vec::with_capacity(interim_transcript_hash.len() + encoded_input.len());
+        hash_input.extend_from_slice(interim_transcript_hash.as_slice());
+        hash_input.extend_from_slice(&encoded_input);
+
+        let digest = cipher_suite_provider
+            .hash(&hash_input)
+            .await
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+        Ok(Self::from(digest))
+    }
+}
+
+/// Interim transcript hash: the hash of a confirmed transcript hash followed
+/// by the encoded confirmation tag. It is the prefix hashed into the next
+/// confirmed transcript hash.
+///
+/// Derives `Eq` (like `ConfirmedTranscriptHash`) since equality on the
+/// underlying byte vector is a total equivalence relation.
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct InterimTranscriptHash(
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+    Vec<u8>,
+);
+
+impl Debug for InterimTranscriptHash {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let pretty = mls_rs_core::debug::pretty_bytes(&self.0).named("InterimTranscriptHash");
+        Debug::fmt(&pretty, f)
+    }
+}
+
+impl Deref for InterimTranscriptHash {
+    type Target = Vec<u8>;
+
+    fn deref(&self) -> &Vec<u8> {
+        let Self(bytes) = self;
+        bytes
+    }
+}
+
+impl From<Vec<u8>> for InterimTranscriptHash {
+    fn from(bytes: Vec<u8>) -> Self {
+        InterimTranscriptHash(bytes)
+    }
+}
+
+impl InterimTranscriptHash {
+    /// Computes `Hash(confirmed || InterimTranscriptHashInput)`, where the input
+    /// is the encoding of `confirmation_tag`.
+    ///
+    /// # Errors
+    ///
+    /// Propagates any error from encoding the tag, and maps hashing failures
+    /// to [`MlsError::CryptoProviderError`].
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn create<P: CipherSuiteProvider>(
+        cipher_suite_provider: &P,
+        confirmed: &ConfirmedTranscriptHash,
+        confirmation_tag: &ConfirmationTag,
+    ) -> Result<Self, MlsError> {
+        #[derive(Debug, MlsSize, MlsEncode)]
+        struct InterimTranscriptHashInput<'a> {
+            confirmation_tag: &'a ConfirmationTag,
+        }
+
+        let encoded_tag = InterimTranscriptHashInput { confirmation_tag }.mls_encode_to_vec()?;
+
+        let mut hash_input = Vec::with_capacity(confirmed.len() + encoded_tag.len());
+        hash_input.extend_from_slice(confirmed.as_slice());
+        hash_input.extend_from_slice(&encoded_tag);
+
+        let digest = cipher_suite_provider
+            .hash(&hash_input)
+            .await
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+        Ok(InterimTranscriptHash(digest))
+    }
+}
+
+// Test vectors come from the MLS interop repository and contain a proposal by reference.
+#[cfg(feature = "by_ref_proposal")]
+#[cfg(test)]
+mod tests {
+ use alloc::vec::Vec;
+
+ use mls_rs_codec::MlsDecode;
+
+ use crate::{
+ crypto::test_utils::try_test_cipher_suite_provider,
+ group::{framing::ContentType, message_signature::AuthenticatedContent, transcript_hashes},
+ };
+
+ #[cfg(not(mls_build_async))]
+ use alloc::{boxed::Box, vec};
+
+ #[cfg(not(mls_build_async))]
+ use crate::{
+ crypto::test_utils::test_cipher_suite_provider,
+ group::{
+ confirmation_tag::ConfirmationTag,
+ framing::Content,
+ proposal::{Proposal, ProposalOrRef, RemoveProposal},
+ test_utils::get_test_group_context,
+ Commit, LeafIndex, Sender,
+ },
+ mls_rs_codec::MlsEncode,
+ CipherSuite, CipherSuiteProvider, WireFormat,
+ };
+
+ #[cfg(not(mls_build_async))]
+ use super::{ConfirmedTranscriptHash, InterimTranscriptHash};
+
+ // One entry of the transcript-hash test vector; every byte field is
+ // hex-encoded in the JSON file.
+ #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+ struct TestCase {
+ pub cipher_suite: u16,
+
+ #[serde(with = "hex::serde")]
+ pub confirmation_key: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub authenticated_content: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub interim_transcript_hash_before: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ pub confirmed_transcript_hash_after: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub interim_transcript_hash_after: Vec<u8>,
+ }
+
+ // For every cipher suite the test provider supports (others are skipped):
+ // decodes the commit, checks its confirmation tag against the expected
+ // confirmed transcript hash, then recomputes both transcript hashes from
+ // `interim_transcript_hash_before` and compares them with the vector.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn transcript_hash() {
+ let test_cases: Vec<TestCase> =
+ load_test_case_json!(interop_transcript_hashes, generate_test_vector());
+
+ for test_case in test_cases.into_iter() {
+ let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
+ continue;
+ };
+
+ let auth_content =
+ AuthenticatedContent::mls_decode(&mut &*test_case.authenticated_content).unwrap();
+
+ assert!(auth_content.content.content_type() == ContentType::Commit);
+
+ let conf_key = &test_case.confirmation_key;
+ let conf_hash_after = test_case.confirmed_transcript_hash_after.into();
+ let conf_tag = auth_content.auth.confirmation_tag.clone().unwrap();
+
+ let matches = conf_tag
+ .matches(conf_key, &conf_hash_after, &cs)
+ .await
+ .unwrap();
+
+ assert!(matches);
+
+ // NOTE(review): despite their names, `expected_interim` / `expected_conf`
+ // hold the values computed here; the expectations come from `test_case`.
+ let (expected_interim, expected_conf) = transcript_hashes(
+ &cs,
+ &test_case.interim_transcript_hash_before.into(),
+ &auth_content,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(*expected_interim, test_case.interim_transcript_hash_after);
+ assert_eq!(expected_conf, conf_hash_after);
+ }
+ }
+
+ // Generates one vector entry per cipher suite from random keys and hashes.
+ // NOTE(review): the locally generated commit carries a by-value Remove
+ // proposal (`ProposalOrRef::Proposal`), unlike the interop vectors described
+ // at the top of this module.
+ #[cfg(not(mls_build_async))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_test_vector() -> Vec<TestCase> {
+ CipherSuite::all().fold(vec![], |mut test_cases, cs| {
+ let cs = test_cipher_suite_provider(cs);
+
+ let context = get_test_group_context(0x3456, cs.cipher_suite());
+
+ let proposal = Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(1),
+ });
+
+ let proposal = ProposalOrRef::Proposal(Box::new(proposal));
+
+ let commit = Commit {
+ proposals: vec![proposal],
+ path: None,
+ };
+
+ let signer = cs.signature_key_generate().unwrap().0;
+
+ // Signed before the confirmation tag is attached (see below); the
+ // confirmed transcript hash covers the signature, not the tag.
+ let mut auth_content = AuthenticatedContent::new_signed(
+ &cs,
+ &context,
+ Sender::Member(0),
+ Content::Commit(alloc::boxed::Box::new(commit)),
+ &signer,
+ WireFormat::PublicMessage,
+ vec![],
+ )
+ .unwrap();
+
+ let interim_hash_before = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap().into();
+
+ let conf_hash_after =
+ ConfirmedTranscriptHash::create(&cs, &interim_hash_before, &auth_content).unwrap();
+
+ let conf_key = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();
+ let conf_tag = ConfirmationTag::create(&conf_key, &conf_hash_after, &cs).unwrap();
+
+ let interim_hash_after =
+ InterimTranscriptHash::create(&cs, &conf_hash_after, &conf_tag).unwrap();
+
+ auth_content.auth.confirmation_tag = Some(conf_tag);
+
+ let test_case = TestCase {
+ cipher_suite: cs.cipher_suite().into(),
+
+ confirmation_key: conf_key,
+ authenticated_content: auth_content.mls_encode_to_vec().unwrap(),
+ interim_transcript_hash_before: interim_hash_before.0,
+
+ confirmed_transcript_hash_after: conf_hash_after.0,
+ interim_transcript_hash_after: interim_hash_after.0,
+ };
+
+ test_cases.push(test_case);
+ test_cases
+ })
+ }
+
+ // Generation uses the synchronous API, so async builds cannot create vectors.
+ #[cfg(mls_build_async)]
+ fn generate_test_vector() -> Vec<TestCase> {
+ panic!("Tests cannot be generated in async mode");
+ }
+}
diff --git a/src/group/util.rs b/src/group/util.rs
new file mode 100644
index 0000000..dadfafa
--- /dev/null
+++ b/src/group/util.rs
@@ -0,0 +1,202 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs_core::{
+ error::IntoAnyError, identity::IdentityProvider, key_package::KeyPackageStorage,
+};
+
+use crate::{
+ cipher_suite::CipherSuite,
+ client::MlsError,
+ extension::RatchetTreeExt,
+ key_package::KeyPackageGeneration,
+ protocol_version::ProtocolVersion,
+ signer::Signable,
+ tree_kem::{node::LeafIndex, tree_validator::TreeValidator, TreeKemPublic},
+ CipherSuiteProvider, CryptoProvider,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::extension::ExternalSendersExt;
+
+use super::{
+ framing::Sender, message_signature::AuthenticatedContent,
+ transcript_hash::InterimTranscriptHash, ConfirmedTranscriptHash, EncryptedGroupSecrets,
+ ExportedTree, GroupInfo, GroupState,
+};
+
+use super::message_processor::ProvisionalState;
+
+/// Checks that `group_info` uses the message's protocol version and the
+/// provider's cipher suite, then verifies its signature with the signature
+/// key of the leaf at `group_info.signer` in `tree`.
+///
+/// # Errors
+///
+/// [`MlsError::ProtocolVersionMismatch`], [`MlsError::CipherSuiteMismatch`],
+/// or any error from looking up the signer's leaf or verifying the signature.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn validate_group_info_common<C: CipherSuiteProvider>(
+    msg_version: ProtocolVersion,
+    group_info: &GroupInfo,
+    tree: &TreeKemPublic,
+    cs: &C,
+) -> Result<(), MlsError> {
+    let context = &group_info.group_context;
+
+    if msg_version != context.protocol_version {
+        return Err(MlsError::ProtocolVersionMismatch);
+    }
+
+    if context.cipher_suite != cs.cipher_suite() {
+        return Err(MlsError::CipherSuiteMismatch);
+    }
+
+    let signer_leaf = tree.get_leaf_node(group_info.signer)?;
+    let signer_key = &signer_leaf.signing_identity.signature_key;
+
+    group_info.verify(cs, signer_key, &()).await?;
+
+    Ok(())
+}
+
+/// Validates a `GroupInfo` received by an existing member.
+///
+/// Runs [`validate_group_info_common`] against the member's own tree. It then
+/// requires that any embedded ratchet tree, the group context, and the
+/// confirmation tag all equal the member's current state.
+///
+/// # Errors
+///
+/// Any error from the common checks or extension decoding, or
+/// [`MlsError::InvalidGroupInfo`] on any mismatch.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn validate_group_info_member<C: CipherSuiteProvider>(
+    self_state: &GroupState,
+    msg_version: ProtocolVersion,
+    group_info: &GroupInfo,
+    cs: &C,
+) -> Result<(), MlsError> {
+    validate_group_info_common(msg_version, group_info, &self_state.public_tree, cs).await?;
+
+    let own_tree = ExportedTree::new_borrowed(&self_state.public_tree.nodes);
+
+    if let Some(ratchet_tree) = group_info.extensions.get_as::<RatchetTreeExt>()? {
+        if ratchet_tree.tree_data != own_tree {
+            return Err(MlsError::InvalidGroupInfo);
+        }
+    }
+
+    if group_info.group_context != self_state.context
+        || group_info.confirmation_tag != self_state.confirmation_tag
+    {
+        return Err(MlsError::InvalidGroupInfo);
+    }
+
+    Ok(())
+}
+
+/// Validates a `GroupInfo` received by a prospective joiner and returns the
+/// imported, validated ratchet tree.
+///
+/// The tree comes from the `GroupInfo`'s ratchet tree extension when present,
+/// otherwise from the caller-supplied `tree`. The tree is validated, external
+/// senders (if configured) are checked with `id_provider`, and finally the
+/// common version/cipher-suite/signature checks run against the imported tree.
+///
+/// # Errors
+///
+/// [`MlsError::RatchetTreeNotFound`] if no tree is available, any import or
+/// tree-validation error, [`MlsError::IdentityProviderError`] for rejected
+/// external senders, or any error from [`validate_group_info_common`].
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn validate_group_info_joiner<C, I>(
+    msg_version: ProtocolVersion,
+    group_info: &GroupInfo,
+    tree: Option<ExportedTree<'_>>,
+    id_provider: &I,
+    cs: &C,
+) -> Result<TreeKemPublic, MlsError>
+where
+    C: CipherSuiteProvider,
+    I: IdentityProvider,
+{
+    // Prefer the tree embedded in the GroupInfo; fall back to the caller's copy.
+    let exported = if let Some(ext) = group_info.extensions.get_as::<RatchetTreeExt>()? {
+        ext.tree_data
+    } else {
+        tree.ok_or(MlsError::RatchetTreeNotFound)?
+    };
+
+    let context = &group_info.group_context;
+
+    let mut public_tree =
+        TreeKemPublic::import_node_data(exported.into(), id_provider, &context.extensions).await?;
+
+    // Verify the integrity of the ratchet tree
+    TreeValidator::new(cs, context, id_provider)
+        .validate(&mut public_tree)
+        .await?;
+
+    #[cfg(feature = "by_ref_proposal")]
+    if let Some(ext_senders) = context.extensions.get_as::<ExternalSendersExt>()? {
+        // TODO do joiners verify group against current time??
+        ext_senders
+            .verify_all(id_provider, None, &context.extensions)
+            .await
+            .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+    }
+
+    validate_group_info_common(msg_version, group_info, &public_tree, cs).await?;
+
+    Ok(public_tree)
+}
+
+/// Resolves the leaf index of the sender of a commit.
+///
+/// A member commits from its own leaf. A new member committing externally
+/// uses `provisional_state.external_init_index`, which must be set. External
+/// senders and new-member proposals are rejected because they cannot commit.
+pub(crate) fn commit_sender(
+    sender: &Sender,
+    provisional_state: &ProvisionalState,
+) -> Result<LeafIndex, MlsError> {
+    match *sender {
+        Sender::NewMemberCommit => provisional_state
+            .external_init_index
+            .ok_or(MlsError::ExternalCommitMissingExternalInit),
+        Sender::Member(index) => Ok(LeafIndex(index)),
+        #[cfg(feature = "by_ref_proposal")]
+        Sender::NewMemberProposal => Err(MlsError::ExpectedAddProposalForNewMemberProposal),
+        #[cfg(feature = "by_ref_proposal")]
+        Sender::External(_) => Err(MlsError::ExternalSenderCannotCommit),
+    }
+}
+
+/// Computes the transcript hashes produced by the commit in `content`.
+///
+/// The confirmed transcript hash is derived from `prev_interim_transcript_hash`
+/// and `content`. The new interim transcript hash is then derived from that
+/// and the commit's confirmation tag.
+///
+/// Returns `(interim, confirmed)`.
+///
+/// # Errors
+///
+/// Any hashing/encoding error, or [`MlsError::InvalidConfirmationTag`] if
+/// `content` carries no confirmation tag.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(super) async fn transcript_hashes<P: CipherSuiteProvider>(
+    cipher_suite_provider: &P,
+    prev_interim_transcript_hash: &InterimTranscriptHash,
+    content: &AuthenticatedContent,
+) -> Result<(InterimTranscriptHash, ConfirmedTranscriptHash), MlsError> {
+    let confirmed =
+        ConfirmedTranscriptHash::create(cipher_suite_provider, prev_interim_transcript_hash, content)
+            .await?;
+
+    // Checked after hashing, so a hashing error takes precedence over a missing tag.
+    let Some(confirmation_tag) = content.auth.confirmation_tag.as_ref() else {
+        return Err(MlsError::InvalidConfirmationTag);
+    };
+
+    let interim =
+        InterimTranscriptHash::create(cipher_suite_provider, &confirmed, confirmation_tag).await?;
+
+    Ok((interim, confirmed))
+}
+
+/// Returns the first entry of `secrets` whose `new_member` identifier is found
+/// in `key_package_repo`, together with the stored key package generation.
+///
+/// # Errors
+///
+/// [`MlsError::KeyPackageRepoError`] if a storage lookup fails, any error from
+/// `KeyPackageGeneration::from_storage`, or
+/// [`MlsError::WelcomeKeyPackageNotFound`] if no entry is found.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn find_key_package_generation<'a, K: KeyPackageStorage>(
+    key_package_repo: &K,
+    secrets: &'a [EncryptedGroupSecrets],
+) -> Result<(&'a EncryptedGroupSecrets, KeyPackageGeneration), MlsError> {
+    for secret in secrets {
+        let stored = key_package_repo
+            .get(&secret.new_member)
+            .await
+            .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;
+
+        let Some(data) = stored else {
+            continue;
+        };
+
+        let generation = KeyPackageGeneration::from_storage(secret.new_member.to_vec(), data)?;
+
+        return Ok((secret, generation));
+    }
+
+    Err(MlsError::WelcomeKeyPackageNotFound)
+}
+
+/// Obtains the provider for `cipher_suite` from `crypto`.
+///
+/// # Errors
+///
+/// [`MlsError::UnsupportedCipherSuite`] if `crypto` has no provider for it.
+pub(crate) fn cipher_suite_provider<P>(
+    crypto: P,
+    cipher_suite: CipherSuite,
+) -> Result<P::CipherSuiteProvider, MlsError>
+where
+    P: CryptoProvider,
+{
+    let supported = crypto.cipher_suite_provider(cipher_suite);
+    supported.ok_or_else(|| MlsError::UnsupportedCipherSuite(cipher_suite))
+}
diff --git a/src/hash_reference.rs b/src/hash_reference.rs
new file mode 100644
index 0000000..41cb156
--- /dev/null
+++ b/src/hash_reference.rs
@@ -0,0 +1,166 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+
+use crate::client::MlsError;
+use crate::CipherSuiteProvider;
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+
+/// Encoded input of `RefHash(label, value)`: the label followed by the value,
+/// each encoded with `mls_rs_codec::byte_vec`.
+#[derive(MlsSize, MlsEncode)]
+struct RefHashInput<'a> {
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    pub label: &'a [u8],
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    pub value: &'a [u8],
+}
+
+impl Debug for RefHashInput<'_> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let label = mls_rs_core::debug::pretty_bytes(self.label);
+        let value = mls_rs_core::debug::pretty_bytes(self.value);
+
+        f.debug_struct("RefHashInput")
+            .field("label", &label)
+            .field("value", &value)
+            .finish()
+    }
+}
+
+/// A hash-based reference (the output of `RefHash`) stored as raw bytes.
+#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct HashReference(
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+    Vec<u8>,
+);
+
+impl Debug for HashReference {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let pretty = mls_rs_core::debug::pretty_bytes(&self.0).named("HashReference");
+        Debug::fmt(&pretty, f)
+    }
+}
+
+impl Deref for HashReference {
+    type Target = [u8];
+
+    fn deref(&self) -> &[u8] {
+        self.0.as_slice()
+    }
+}
+
+impl AsRef<[u8]> for HashReference {
+    fn as_ref(&self) -> &[u8] {
+        self.0.as_slice()
+    }
+}
+
+impl From<Vec<u8>> for HashReference {
+    fn from(bytes: Vec<u8>) -> Self {
+        HashReference(bytes)
+    }
+}
+
+impl HashReference {
+    /// Computes `RefHash(label, value)`: the hash, under `cipher_suite`, of the
+    /// encoded `RefHashInput { label, value }`.
+    ///
+    /// # Errors
+    ///
+    /// Propagates encoding errors and maps hashing failures to
+    /// [`MlsError::CryptoProviderError`].
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn compute<P: CipherSuiteProvider>(
+        value: &[u8],
+        label: &[u8],
+        cipher_suite: &P,
+    ) -> Result<HashReference, MlsError> {
+        let encoded = RefHashInput { label, value }.mls_encode_to_vec()?;
+
+        let digest = cipher_suite
+            .hash(&encoded)
+            .await
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+        Ok(HashReference::from(digest))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::crypto::test_utils::try_test_cipher_suite_provider;
+
+ #[cfg(not(mls_build_async))]
+ use crate::{cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider};
+
+ use super::*;
+ use alloc::string::String;
+ use serde::{Deserialize, Serialize};
+
+ #[cfg(not(mls_build_async))]
+ use alloc::string::ToString;
+
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+
+ // The `ref_hash` case of the crypto-basics vector; byte fields are hex-encoded.
+ #[derive(Debug, Deserialize, Serialize)]
+ struct HashRefTestCase {
+ label: String,
+ #[serde(with = "hex::serde")]
+ value: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ out: Vec<u8>,
+ }
+
+ // One cipher suite's entry in the crypto-basics vector; only the `ref_hash`
+ // section is modelled here.
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ pub struct InteropTestCase {
+ cipher_suite: u16,
+ ref_hash: HashRefTestCase,
+ }
+
+ // Generates one entry per cipher suite by hashing a fixed label and value.
+ #[cfg(not(mls_build_async))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_test_vector() -> Vec<InteropTestCase> {
+ CipherSuite::all()
+ .map(|cipher_suite| {
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let input = b"test input";
+ let label = "test label";
+
+ let output = HashReference::compute(input, label.as_bytes(), &provider).unwrap();
+
+ let ref_hash = HashRefTestCase {
+ label: label.to_string(),
+ value: input.to_vec(),
+ out: output.to_vec(),
+ };
+
+ InteropTestCase {
+ cipher_suite: cipher_suite.into(),
+ ref_hash,
+ }
+ })
+ .collect()
+ }
+
+ // Generation uses the synchronous API, so async builds cannot create vectors.
+ #[cfg(mls_build_async)]
+ fn generate_test_vector() -> Vec<InteropTestCase> {
+ panic!("Tests cannot be generated in async mode");
+ }
+
+ // Recomputes each `ref_hash` output for every supported cipher suite
+ // (unsupported suites are skipped) and compares it with the vector.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_basic_crypto_test_vectors() {
+ // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/crypto-basics.json
+ let test_cases: Vec<InteropTestCase> =
+ load_test_case_json!(basic_crypto, generate_test_vector());
+
+ for test_case in test_cases {
+ if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) {
+ let label = test_case.ref_hash.label.as_bytes();
+ let value = &test_case.ref_hash.value;
+ let computed = HashReference::compute(value, label, &cs).await.unwrap();
+ assert_eq!(&*computed, &test_case.ref_hash.out);
+ }
+ }
+ }
+}
diff --git a/src/identity.rs b/src/identity.rs
new file mode 100644
index 0000000..5de7a11
--- /dev/null
+++ b/src/identity.rs
@@ -0,0 +1,182 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+/// Basic credential identity provider.
+pub mod basic;
+
+/// X.509 certificate identity provider.
+#[cfg(feature = "x509")]
+pub mod x509 {
+ // Re-exports the whole `mls_rs_identity_x509` crate under this module path.
+ pub use mls_rs_identity_x509::*;
+}
+
+// Re-export the core identity types so they are reachable through this module.
+pub use mls_rs_core::identity::{
+ Credential, CredentialType, CustomCredential, MlsCredential, SigningIdentity,
+};
+
+// `RosterUpdate` is defined in `mls_rs_core::group` but re-exported here as well.
+pub use mls_rs_core::group::RosterUpdate;
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use alloc::boxed::Box;
+ use alloc::vec;
+ use alloc::vec::Vec;
+ use mls_rs_core::{
+ crypto::{CipherSuite, CipherSuiteProvider, SignatureSecretKey},
+ error::IntoAnyError,
+ extension::ExtensionList,
+ identity::{Credential, CredentialType, IdentityProvider, SigningIdentity},
+ time::MlsTime,
+ };
+
+ use crate::crypto::test_utils::test_cipher_suite_provider;
+
+ use super::basic::{BasicCredential, BasicIdentityProvider};
+
+ /// Error from [`BasicWithCustomProvider`] carrying the credential type that
+ /// was neither basic nor an accepted custom type.
+ #[derive(Debug)]
+ #[cfg_attr(feature = "std", derive(thiserror::Error))]
+ #[cfg_attr(
+ feature = "std",
+ error("expected basic or custom credential type 42 found: {0:?}")
+ )]
+ pub struct BasicWithCustomProviderError(CredentialType);
+
+ impl IntoAnyError for BasicWithCustomProviderError {
+ #[cfg(feature = "std")]
+ fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
+ Ok(self.into())
+ }
+ }
+
+ /// Test identity provider that accepts basic credentials (via the wrapped
+ /// [`BasicIdentityProvider`]) and custom credentials of type
+ /// `CUSTOM_CREDENTIAL_TYPE`, or of any custom type when `allow_any_custom`
+ /// is set. `supported_cred_types` only feeds `supported_types`; it is not
+ /// consulted when resolving identities.
+ #[derive(Debug, Clone)]
+ pub struct BasicWithCustomProvider {
+ pub(crate) basic: BasicIdentityProvider,
+ pub(crate) allow_any_custom: bool,
+ supported_cred_types: Vec<CredentialType>,
+ }
+
+ impl BasicWithCustomProvider {
+ pub const CUSTOM_CREDENTIAL_TYPE: u16 = 42;
+
+ /// Creates a provider reporting the basic and custom (42) credential types.
+ pub fn new(basic: BasicIdentityProvider) -> BasicWithCustomProvider {
+ BasicWithCustomProvider {
+ basic,
+ allow_any_custom: false,
+ supported_cred_types: vec![
+ CredentialType::BASIC,
+ Self::CUSTOM_CREDENTIAL_TYPE.into(),
+ ],
+ }
+ }
+
+ /// Adds `cred_type` to the list reported by `supported_types`.
+ pub fn with_credential_type(mut self, cred_type: CredentialType) -> Self {
+ self.supported_cred_types.push(cred_type);
+ self
+ }
+
+ // Returns the basic identifier if the basic provider accepts the
+ // credential; otherwise the raw data of an accepted custom credential.
+ // Any other credential yields an error carrying its type.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn resolve_custom_identity(
+ &self,
+ signing_id: &SigningIdentity,
+ ) -> Result<Vec<u8>, BasicWithCustomProviderError> {
+ self.basic
+ .identity(signing_id, &Default::default())
+ .await
+ .or_else(|_| {
+ signing_id
+ .credential
+ .as_custom()
+ .map(|c| {
+ if c.credential_type
+ == CredentialType::from(Self::CUSTOM_CREDENTIAL_TYPE)
+ || self.allow_any_custom
+ {
+ Ok(c.data.to_vec())
+ } else {
+ Err(BasicWithCustomProviderError(c.credential_type))
+ }
+ })
+ .transpose()?
+ .ok_or_else(|| {
+ BasicWithCustomProviderError(signing_id.credential.credential_type())
+ })
+ })
+ }
+ }
+
+ impl Default for BasicWithCustomProvider {
+ fn default() -> Self {
+ Self::new(BasicIdentityProvider::new())
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl IdentityProvider for BasicWithCustomProvider {
+ type Error = BasicWithCustomProviderError;
+
+ // Accepts every member unconditionally.
+ async fn validate_member(
+ &self,
+ _signing_identity: &SigningIdentity,
+ _timestamp: Option<MlsTime>,
+ _extensions: Option<&ExtensionList>,
+ ) -> Result<(), Self::Error> {
+ //TODO: Is it actually beneficial to check the key, or does that already happen elsewhere before
+ //this point?
+ Ok(())
+ }
+
+ // Accepts every external sender unconditionally.
+ async fn validate_external_sender(
+ &self,
+ _signing_identity: &SigningIdentity,
+ _timestamp: Option<MlsTime>,
+ _extensions: Option<&ExtensionList>,
+ ) -> Result<(), Self::Error> {
+ //TODO: Is it actually beneficial to check the key, or does that already happen elsewhere before
+ //this point?
+ Ok(())
+ }
+
+ async fn identity(
+ &self,
+ signing_id: &SigningIdentity,
+ _extensions: &ExtensionList,
+ ) -> Result<Vec<u8>, Self::Error> {
+ self.resolve_custom_identity(signing_id).await
+ }
+
+ // A successor is valid iff both identities resolve to the same bytes.
+ async fn valid_successor(
+ &self,
+ predecessor: &SigningIdentity,
+ successor: &SigningIdentity,
+ _extensions: &ExtensionList,
+ ) -> Result<bool, Self::Error> {
+ let predecessor = self.resolve_custom_identity(predecessor).await?;
+ let successor = self.resolve_custom_identity(successor).await?;
+
+ Ok(predecessor == successor)
+ }
+
+ fn supported_types(&self) -> Vec<CredentialType> {
+ self.supported_cred_types.clone()
+ }
+ }
+
+ /// Generates a fresh signature key pair for `cipher_suite` and pairs the
+ /// public key with a basic credential whose identifier is `identity`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_test_signing_identity(
+ cipher_suite: CipherSuite,
+ identity: &[u8],
+ ) -> (SigningIdentity, SignatureSecretKey) {
+ let provider = test_cipher_suite_provider(cipher_suite);
+ let (secret_key, public_key) = provider.signature_key_generate().await.unwrap();
+
+ let basic = get_test_basic_credential(identity.to_vec());
+
+ (SigningIdentity::new(basic, public_key), secret_key)
+ }
+
+ /// Wraps `identity` in a basic credential.
+ pub fn get_test_basic_credential(identity: Vec<u8>) -> Credential {
+ BasicCredential::new(identity).into_credential()
+ }
+}
diff --git a/src/identity/basic.rs b/src/identity/basic.rs
new file mode 100644
index 0000000..b93ab6a
--- /dev/null
+++ b/src/identity/basic.rs
@@ -0,0 +1,99 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{identity::CredentialType, identity::SigningIdentity, time::MlsTime};
+use alloc::vec;
+use alloc::vec::Vec;
+pub use mls_rs_core::identity::BasicCredential;
+use mls_rs_core::{error::IntoAnyError, extension::ExtensionList, identity::IdentityProvider};
+
+#[derive(Debug)]
+#[cfg_attr(feature = "std", derive(thiserror::Error))]
+#[cfg_attr(feature = "std", error("unsupported credential type found: {0:?}"))]
+/// Error returned in the event that a non-basic
+/// credential is passed to a [`BasicIdentityProvider`].
+pub struct BasicIdentityProviderError(CredentialType);
+
+impl IntoAnyError for BasicIdentityProviderError {
+ #[cfg(feature = "std")]
+ fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
+ Ok(self.into())
+ }
+}
+
+impl BasicIdentityProviderError {
+ pub fn credential_type(&self) -> CredentialType {
+ self.0
+ }
+}
+
+/// An always-valid identity provider that works with [`BasicCredential`].
+///
+/// # Warning
+///
+/// `validate_member` and `validate_external_sender` succeed for any
+/// [`SigningIdentity`] whose credential is a [`BasicCredential`]; no other
+/// check is performed. It is only recommended to use this provider for
+/// testing purposes.
+#[derive(Clone, Debug, Default)]
+pub struct BasicIdentityProvider;
+
+impl BasicIdentityProvider {
+    /// Creates the provider; equivalent to [`Default::default`].
+    pub fn new() -> Self {
+        BasicIdentityProvider
+    }
+}
+
+/// Returns the [`BasicCredential`] held by `signing_id`, or an error carrying
+/// the actual credential type when it is not a basic credential.
+fn resolve_basic_identity(
+    signing_id: &SigningIdentity,
+) -> Result<&BasicCredential, BasicIdentityProviderError> {
+    let credential = &signing_id.credential;
+
+    match credential.as_basic() {
+        Some(basic) => Ok(basic),
+        None => Err(BasicIdentityProviderError(credential.credential_type())),
+    }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+impl IdentityProvider for BasicIdentityProvider {
+    type Error = BasicIdentityProviderError;
+
+    /// Succeeds iff the member's credential is basic.
+    async fn validate_member(
+        &self,
+        signing_identity: &SigningIdentity,
+        _timestamp: Option<MlsTime>,
+        _extensions: Option<&ExtensionList>,
+    ) -> Result<(), Self::Error> {
+        resolve_basic_identity(signing_identity)?;
+        Ok(())
+    }
+
+    /// Succeeds iff the external sender's credential is basic.
+    async fn validate_external_sender(
+        &self,
+        signing_identity: &SigningIdentity,
+        _timestamp: Option<MlsTime>,
+        _extensions: Option<&ExtensionList>,
+    ) -> Result<(), Self::Error> {
+        resolve_basic_identity(signing_identity)?;
+        Ok(())
+    }
+
+    /// The identity is the basic credential's identifier bytes.
+    async fn identity(
+        &self,
+        signing_identity: &SigningIdentity,
+        _extensions: &ExtensionList,
+    ) -> Result<Vec<u8>, Self::Error> {
+        let basic = resolve_basic_identity(signing_identity)?;
+        Ok(basic.identifier.to_vec())
+    }
+
+    /// A successor is valid iff both basic credentials are equal.
+    async fn valid_successor(
+        &self,
+        predecessor: &SigningIdentity,
+        successor: &SigningIdentity,
+        _extensions: &ExtensionList,
+    ) -> Result<bool, Self::Error> {
+        let before = resolve_basic_identity(predecessor)?;
+        let after = resolve_basic_identity(successor)?;
+        Ok(before == after)
+    }
+
+    fn supported_types(&self) -> Vec<CredentialType> {
+        vec![BasicCredential::credential_type()]
+    }
+}
diff --git a/src/iter.rs b/src/iter.rs
new file mode 100644
index 0000000..e37f162
--- /dev/null
+++ b/src/iter.rs
@@ -0,0 +1,96 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+mod sync_rayon {
+    // Parallel implementations used by sync builds with the `rayon` feature.
+    use rayon::iter::IterBridge;
+    use rayon::prelude::{
+        FromParallelIterator, IntoParallelIterator, ParallelBridge, ParallelIterator,
+    };
+
+    /// Converts `it` into its native rayon parallel iterator.
+    pub fn wrap_iter<I: IntoParallelIterator>(it: I) -> I::Iter {
+        IntoParallelIterator::into_par_iter(it)
+    }
+
+    /// Bridges any `Send` sequential iterator into rayon via `par_bridge`.
+    pub fn wrap_impl_iter<I>(it: I) -> IterBridge<I::IntoIter>
+    where
+        I: IntoIterator,
+        I::IntoIter: Send,
+        I::Item: Send,
+    {
+        ParallelBridge::par_bridge(IntoIterator::into_iter(it))
+    }
+
+    /// Collects a parallel iterator of `Result`s into `Result<A, Error>`,
+    /// yielding an error item if any item is an error.
+    pub trait ParallelIteratorExt {
+        type Ok: Send;
+        type Error: Send;
+
+        fn try_collect<A: FromParallelIterator<Self::Ok>>(self) -> Result<A, Self::Error>;
+    }
+
+    impl<I, T: Send, E: Send> ParallelIteratorExt for I
+    where
+        I: ParallelIterator<Item = Result<T, E>>,
+    {
+        type Ok = T;
+        type Error = E;
+
+        fn try_collect<A: FromParallelIterator<Self::Ok>>(self) -> Result<A, Self::Error> {
+            ParallelIterator::collect::<Result<A, E>>(self)
+        }
+    }
+}
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+pub use sync_rayon::{wrap_impl_iter, wrap_iter, ParallelIteratorExt};
+
+#[cfg(not(any(mls_build_async, feature = "rayon")))]
+mod sync {
+    // Plain sequential fallbacks: both wrappers are just `into_iter`.
+
+    /// Returns the ordinary sequential iterator for `it`.
+    pub fn wrap_iter<I: IntoIterator>(it: I) -> I::IntoIter {
+        IntoIterator::into_iter(it)
+    }
+
+    /// Identical to [`wrap_iter`] in sequential builds.
+    pub fn wrap_impl_iter<I: IntoIterator>(it: I) -> I::IntoIter {
+        wrap_iter(it)
+    }
+}
+
+#[cfg(not(any(mls_build_async, feature = "rayon")))]
+pub use sync::{wrap_impl_iter, wrap_iter};
+
+#[cfg(mls_build_async)]
+mod async_ {
+    // Async builds expose iterables as `futures` streams.
+    use futures::stream::{self, Iter};
+
+    /// Wraps `it` in a `futures` stream that yields its items in order.
+    pub fn wrap_iter<I: IntoIterator>(it: I) -> Iter<I::IntoIter> {
+        stream::iter(it)
+    }
+
+    /// Identical to [`wrap_iter`] in async builds.
+    pub fn wrap_impl_iter<I: IntoIterator>(it: I) -> Iter<I::IntoIter> {
+        wrap_iter(it)
+    }
+}
+
+#[cfg(mls_build_async)]
+pub use async_::{wrap_impl_iter, wrap_iter};
diff --git a/src/key_package/generator.rs b/src/key_package/generator.rs
new file mode 100644
index 0000000..4d71094
--- /dev/null
+++ b/src/key_package/generator.rs
@@ -0,0 +1,339 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode};
+use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider, key_package::KeyPackageData};
+
+use crate::client::MlsError;
+use crate::{
+ crypto::{HpkeSecretKey, SignatureSecretKey},
+ group::framing::MlsMessagePayload,
+ identity::SigningIdentity,
+ protocol_version::ProtocolVersion,
+ signer::Signable,
+ tree_kem::{
+ leaf_node::{ConfigProperties, LeafNode},
+ Capabilities, Lifetime,
+ },
+ CipherSuiteProvider, ExtensionList, MlsMessage,
+};
+
+use super::{KeyPackage, KeyPackageRef};
+
+/// Produces signed [`KeyPackage`]s for one signing identity and cipher suite.
+///
+/// Every field is borrowed. `identity_provider` is carried along but is not
+/// consulted by `generate`.
+#[derive(Clone, Debug)]
+pub struct KeyPackageGenerator<'a, IP: IdentityProvider, CP: CipherSuiteProvider> {
+    /// Protocol version written into every generated package.
+    pub protocol_version: ProtocolVersion,
+    /// Supplies the cipher suite plus KEM key generation, signing and hashing.
+    pub cipher_suite_provider: &'a CP,
+    /// Identity embedded in each generated leaf node.
+    pub signing_identity: &'a SigningIdentity,
+    /// Secret key used to sign the leaf node and the key package.
+    pub signing_key: &'a SignatureSecretKey,
+    /// Identity provider associated with this generator.
+    pub identity_provider: &'a IP,
+}
+
+/// Output of `KeyPackageGenerator::generate`: the signed [`KeyPackage`], its
+/// reference, and the two HPKE secret keys created alongside it.
+///
+/// Persist with `to_storage` and restore with `from_storage`.
+#[derive(Clone, Debug)]
+pub struct KeyPackageGeneration {
+ // Reference of `key_package`, computed after signing.
+ pub(crate) reference: KeyPackageRef,
+ pub(crate) key_package: KeyPackage,
+ // Secret half of the KEM key pair whose public half is `key_package.hpke_init_key`.
+ pub(crate) init_secret_key: HpkeSecretKey,
+ // Secret key returned by `LeafNode::generate` for the package's leaf node.
+ pub(crate) leaf_node_secret_key: HpkeSecretKey,
+}
+
+impl KeyPackageGeneration {
+    /// Splits this generation into a storage key (the raw reference bytes)
+    /// and a [`KeyPackageData`] record holding the encoded package, both
+    /// secret keys and the package's expiration.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the package cannot be encoded, or if its leaf node has no
+    /// key-package lifetime (see `KeyPackage::expiration`).
+    pub fn to_storage(&self) -> Result<(Vec<u8>, KeyPackageData), MlsError> {
+        let key_package_bytes = self.key_package.mls_encode_to_vec()?;
+        let expiration = self.key_package.expiration()?;
+
+        let record = KeyPackageData::new(
+            key_package_bytes,
+            self.init_secret_key.clone(),
+            self.leaf_node_secret_key.clone(),
+            expiration,
+        );
+
+        Ok((self.reference.to_vec(), record))
+    }
+
+    /// Rebuilds a generation from a pair previously produced by `to_storage`.
+    ///
+    /// `id` is taken as the reference verbatim; it is not recomputed or
+    /// checked against the decoded package.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the stored key package bytes do not decode.
+    pub fn from_storage(id: Vec<u8>, data: KeyPackageData) -> Result<Self, MlsError> {
+        let reference = KeyPackageRef::from(id);
+        let mut encoded: &[u8] = &data.key_package_bytes;
+        let key_package = KeyPackage::mls_decode(&mut encoded)?;
+
+        Ok(Self {
+            reference,
+            key_package,
+            init_secret_key: data.init_key,
+            leaf_node_secret_key: data.leaf_node_key,
+        })
+    }
+
+    /// Wraps a copy of the key package in an [`MlsMessage`] tagged with the
+    /// package's own protocol version.
+    pub fn key_package_message(&self) -> MlsMessage {
+        let payload = MlsMessagePayload::KeyPackage(self.key_package.clone());
+        MlsMessage::new(self.key_package.version, payload)
+    }
+}
+
+impl<'a, IP, CP> KeyPackageGenerator<'a, IP, CP>
+where
+    IP: IdentityProvider,
+    CP: CipherSuiteProvider,
+{
+    /// Signs `package` in place with this generator's signing key.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(super) async fn sign(&self, package: &mut KeyPackage) -> Result<(), MlsError> {
+        let (provider, key) = (self.cipher_suite_provider, self.signing_key);
+        package.sign(provider, key, &()).await
+    }
+
+    /// Generates a new signed [`KeyPackage`] together with its secrets.
+    ///
+    /// Steps:
+    /// 1. create a fresh HPKE init key pair;
+    /// 2. generate the leaf node carrying `capabilities`,
+    ///    `leaf_node_extensions` and `lifetime`;
+    /// 3. assemble the package with `key_package_extensions` and apply GREASE;
+    /// 4. sign it;
+    /// 5. compute its reference, which therefore covers the signature.
+    ///
+    /// # Errors
+    ///
+    /// KEM key-generation failures are wrapped in
+    /// `MlsError::CryptoProviderError`. Leaf node generation, greasing,
+    /// signing and reference computation propagate their own errors.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn generate(
+        &self,
+        lifetime: Lifetime,
+        capabilities: Capabilities,
+        key_package_extensions: ExtensionList,
+        leaf_node_extensions: ExtensionList,
+    ) -> Result<KeyPackageGeneration, MlsError> {
+        let csp = self.cipher_suite_provider;
+
+        let (init_secret_key, hpke_init_key) = csp
+            .kem_generate()
+            .await
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+        let (leaf_node, leaf_node_secret_key) = LeafNode::generate(
+            csp,
+            ConfigProperties {
+                capabilities,
+                extensions: leaf_node_extensions,
+            },
+            self.signing_identity.clone(),
+            self.signing_key,
+            lifetime,
+        )
+        .await?;
+
+        let mut key_package = KeyPackage {
+            version: self.protocol_version,
+            cipher_suite: csp.cipher_suite(),
+            hpke_init_key,
+            leaf_node,
+            extensions: key_package_extensions,
+            // Filled in by `sign` below.
+            signature: vec![],
+        };
+
+        // GREASE must be applied before signing so the signature covers it.
+        key_package.grease(csp)?;
+        self.sign(&mut key_package).await?;
+
+        let reference = key_package.to_reference(csp).await?;
+
+        Ok(KeyPackageGeneration {
+            reference,
+            key_package,
+            init_secret_key,
+            leaf_node_secret_key,
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use alloc::vec::Vec;
+    use assert_matches::assert_matches;
+    use mls_rs_core::crypto::CipherSuiteProvider;
+
+    use crate::{
+        crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
+        extension::test_utils::TestExtension,
+        group::test_utils::random_bytes,
+        identity::basic::BasicIdentityProvider,
+        identity::test_utils::get_test_signing_identity,
+        key_package::validate_key_package_properties,
+        protocol_version::ProtocolVersion,
+        tree_kem::{
+            leaf_node::{test_utils::get_test_capabilities, LeafNodeSource},
+            leaf_node_validator::{LeafNodeValidator, ValidationContext},
+            Lifetime,
+        },
+        ExtensionList, KeyPackage,
+    };
+
+    use super::KeyPackageGenerator;
+
+    /// Number of additional key packages `test_randomness` generates per
+    /// (protocol version, cipher suite) pair.
+    const RANDOMNESS_ROUNDS: usize = 100;
+
+    /// Extension list holding a single [`TestExtension`] built from `val`.
+    ///
+    /// Shared by the key-package-level and leaf-node-level extension fixtures.
+    fn test_extensions(val: u8) -> ExtensionList {
+        let mut ext_list = ExtensionList::new();
+        ext_list.set_from(TestExtension::from(val)).unwrap();
+        ext_list
+    }
+
+    fn test_lifetime() -> Lifetime {
+        Lifetime::years(1).unwrap()
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_key_generation() {
+        for (protocol_version, cipher_suite) in ProtocolVersion::all().flat_map(|p| {
+            TestCryptoProvider::all_supported_cipher_suites()
+                .into_iter()
+                .map(move |cs| (p, cs))
+        }) {
+            let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+            let (signing_identity, signing_key) =
+                get_test_signing_identity(cipher_suite, b"foo").await;
+
+            let key_package_ext = test_extensions(32);
+            let leaf_node_ext = test_extensions(42);
+            let lifetime = test_lifetime();
+
+            let test_generator = KeyPackageGenerator {
+                protocol_version,
+                cipher_suite_provider: &cipher_suite_provider,
+                signing_identity: &signing_identity,
+                signing_key: &signing_key,
+                identity_provider: &BasicIdentityProvider,
+            };
+
+            let mut capabilities = get_test_capabilities();
+            capabilities.extensions.push(42.into());
+            capabilities.extensions.push(43.into());
+            capabilities.extensions.push(32.into());
+
+            let generated = test_generator
+                .generate(
+                    lifetime.clone(),
+                    capabilities.clone(),
+                    key_package_ext.clone(),
+                    leaf_node_ext.clone(),
+                )
+                .await
+                .unwrap();
+
+            assert_matches!(generated.key_package.leaf_node.leaf_node_source,
+                LeafNodeSource::KeyPackage(ref lt) if lt == &lifetime);
+
+            assert_eq!(
+                generated.key_package.leaf_node.ungreased_capabilities(),
+                capabilities
+            );
+
+            assert_eq!(
+                generated.key_package.leaf_node.ungreased_extensions(),
+                leaf_node_ext
+            );
+
+            assert_eq!(
+                generated.key_package.ungreased_extensions(),
+                key_package_ext
+            );
+
+            assert_ne!(
+                generated.key_package.hpke_init_key.as_ref(),
+                generated.key_package.leaf_node.public_key.as_ref()
+            );
+
+            assert_eq!(generated.key_package.cipher_suite, cipher_suite);
+            assert_eq!(generated.key_package.version, protocol_version);
+
+            // Verify that the hpke key pair generated will work
+            let test_data = random_bytes(32);
+
+            let sealed = cipher_suite_provider
+                .hpke_seal(&generated.key_package.hpke_init_key, &[], None, &test_data)
+                .await
+                .unwrap();
+
+            let opened = cipher_suite_provider
+                .hpke_open(
+                    &sealed,
+                    &generated.init_secret_key,
+                    &generated.key_package.hpke_init_key,
+                    &[],
+                    None,
+                )
+                .await
+                .unwrap();
+
+            assert_eq!(opened, test_data);
+
+            let validator =
+                LeafNodeValidator::new(&cipher_suite_provider, &BasicIdentityProvider, None);
+
+            validator
+                .check_if_valid(
+                    &generated.key_package.leaf_node,
+                    ValidationContext::Add(None),
+                )
+                .await
+                .unwrap();
+
+            validate_key_package_properties(
+                &generated.key_package,
+                protocol_version,
+                &cipher_suite_provider,
+            )
+            .await
+            .unwrap();
+        }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_randomness() {
+        for (protocol_version, cipher_suite) in ProtocolVersion::all().flat_map(|p| {
+            TestCryptoProvider::all_supported_cipher_suites()
+                .into_iter()
+                .map(move |cs| (p, cs))
+        }) {
+            let (signing_identity, signing_key) =
+                get_test_signing_identity(cipher_suite, b"foo").await;
+
+            let test_generator = KeyPackageGenerator {
+                protocol_version,
+                cipher_suite_provider: &test_cipher_suite_provider(cipher_suite),
+                signing_identity: &signing_identity,
+                signing_key: &signing_key,
+                identity_provider: &BasicIdentityProvider,
+            };
+
+            // Compare every new package against all earlier ones (not just the
+            // first) so a collision between any two rounds is detected.
+            let mut previous: Vec<KeyPackage> = Vec::new();
+
+            for _ in 0..=RANDOMNESS_ROUNDS {
+                let key_package = test_generator
+                    .generate(
+                        test_lifetime(),
+                        get_test_capabilities(),
+                        ExtensionList::default(),
+                        ExtensionList::default(),
+                    )
+                    .await
+                    .unwrap()
+                    .key_package;
+
+                for earlier in &previous {
+                    assert_ne!(earlier.hpke_init_key, key_package.hpke_init_key);
+
+                    assert_ne!(
+                        earlier.leaf_node.public_key,
+                        key_package.leaf_node.public_key
+                    );
+                }
+
+                previous.push(key_package);
+            }
+        }
+    }
+}
diff --git a/src/key_package/mod.rs b/src/key_package/mod.rs
new file mode 100644
index 0000000..b3ef83b
--- /dev/null
+++ b/src/key_package/mod.rs
@@ -0,0 +1,332 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::cipher_suite::CipherSuite;
+use crate::client::MlsError;
+use crate::crypto::HpkePublicKey;
+use crate::hash_reference::HashReference;
+use crate::identity::SigningIdentity;
+use crate::protocol_version::ProtocolVersion;
+use crate::signer::Signable;
+use crate::tree_kem::leaf_node::{LeafNode, LeafNodeSource};
+use crate::CipherSuiteProvider;
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::MlsDecode;
+use mls_rs_codec::MlsEncode;
+use mls_rs_codec::MlsSize;
+use mls_rs_core::extension::ExtensionList;
+
+mod validator;
+pub(crate) use validator::*;
+
+pub(crate) mod generator;
+pub(crate) use generator::*;
+
+/// A pre-computed, signed MLS key package that allows its creator to be
+/// added to a group (see the crate-level docs on key packages).
+///
+/// Identified by the [`KeyPackageRef`] returned from
+/// [`KeyPackage::to_reference`].
+#[non_exhaustive]
+#[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct KeyPackage {
+ /// Protocol version this package was generated for.
+ pub version: ProtocolVersion,
+ /// Cipher suite this package was generated for; `to_reference` rejects
+ /// providers for any other suite.
+ pub cipher_suite: CipherSuite,
+ /// HPKE public init key. Must differ from the leaf node's `public_key`
+ /// (checked by `validate_key_package_properties`).
+ pub hpke_init_key: HpkePublicKey,
+ // Leaf node of the package's creator: carries the signing identity and,
+ // via `LeafNodeSource::KeyPackage`, the package lifetime.
+ pub(crate) leaf_node: LeafNode,
+ /// Key-package-level extensions (distinct from the leaf node's extensions).
+ pub extensions: ExtensionList,
+ /// Signature over the `KeyPackageTBS` content, i.e. every field above.
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub signature: Vec<u8>,
+}
+
+impl Debug for KeyPackage {
+    /// Debug view that prints the signature via `pretty_bytes` rather than as
+    /// a raw byte vector.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Exhaustive destructuring: adding a field to `KeyPackage` without
+        // updating this impl becomes a compile error.
+        let Self {
+            version,
+            cipher_suite,
+            hpke_init_key,
+            leaf_node,
+            extensions,
+            signature,
+        } = self;
+
+        let mut out = f.debug_struct("KeyPackage");
+        out.field("version", version);
+        out.field("cipher_suite", cipher_suite);
+        out.field("hpke_init_key", hpke_init_key);
+        out.field("leaf_node", leaf_node);
+        out.field("extensions", extensions);
+        out.field("signature", &mls_rs_core::debug::pretty_bytes(signature));
+        out.finish()
+    }
+}
+
+/// Hash-based identifier of a [`KeyPackage`].
+///
+/// Computed by [`KeyPackage::to_reference`] over the package's full MLS
+/// encoding with the label `"MLS 1.0 KeyPackage Reference"`. Dereferences to
+/// the raw hash bytes.
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+pub struct KeyPackageRef(HashReference);
+
+impl Deref for KeyPackageRef {
+    type Target = [u8];
+
+    /// Exposes the wrapped hash reference as raw bytes.
+    fn deref(&self) -> &[u8] {
+        &*self.0
+    }
+}
+
+impl From<Vec<u8>> for KeyPackageRef {
+    /// Wraps raw reference bytes (e.g. a storage key) without validating them.
+    fn from(bytes: Vec<u8>) -> Self {
+        KeyPackageRef(bytes.into())
+    }
+}
+
+// To-be-signed view of a `KeyPackage`: every field except `signature`,
+// borrowed so signing does not clone the package. Encoded by
+// `Signable::signable_content` under the label "KeyPackageTBS".
+//
+// NOTE: unrelated to `mls_rs_core::key_package::KeyPackageData` (the storage
+// record built by `KeyPackageGeneration::to_storage`). The derived encoding
+// follows field declaration order, so do not reorder these fields.
+#[derive(MlsSize, MlsEncode)]
+struct KeyPackageData<'a> {
+ pub version: ProtocolVersion,
+ pub cipher_suite: CipherSuite,
+ // Encoded through `mls_rs_codec::byte_vec`.
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub hpke_init_key: &'a HpkePublicKey,
+ pub leaf_node: &'a LeafNode,
+ pub extensions: &'a ExtensionList,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl KeyPackage {
+    /// Protocol version of this package (accessor exposed for FFI builds).
+    #[cfg(feature = "ffi")]
+    pub fn version(&self) -> ProtocolVersion {
+        self.version
+    }
+
+    /// Cipher suite of this package (accessor exposed for FFI builds).
+    #[cfg(feature = "ffi")]
+    pub fn cipher_suite(&self) -> CipherSuite {
+        self.cipher_suite
+    }
+
+    /// Signing identity carried by this package's leaf node.
+    pub fn signing_identity(&self) -> &SigningIdentity {
+        let leaf = &self.leaf_node;
+        &leaf.signing_identity
+    }
+
+    /// Computes the [`KeyPackageRef`] identifying this package: a hash
+    /// reference over its full MLS encoding, labelled
+    /// `"MLS 1.0 KeyPackage Reference"`.
+    ///
+    /// # Errors
+    ///
+    /// `MlsError::CipherSuiteMismatch` if `cipher_suite_provider` is for a
+    /// different suite than this package; otherwise any encoding or hashing
+    /// error.
+    #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn to_reference<CP: CipherSuiteProvider>(
+        &self,
+        cipher_suite_provider: &CP,
+    ) -> Result<KeyPackageRef, MlsError> {
+        let suite_matches = self.cipher_suite == cipher_suite_provider.cipher_suite();
+
+        if !suite_matches {
+            return Err(MlsError::CipherSuiteMismatch);
+        }
+
+        let encoded = self.mls_encode_to_vec()?;
+
+        let digest = HashReference::compute(
+            &encoded,
+            b"MLS 1.0 KeyPackage Reference",
+            cipher_suite_provider,
+        )
+        .await?;
+
+        Ok(KeyPackageRef(digest))
+    }
+
+    /// Expiration of this package: the `not_after` bound of its leaf node's
+    /// key-package lifetime.
+    ///
+    /// # Errors
+    ///
+    /// `MlsError::InvalidLeafNodeSource` if the leaf node's source is not
+    /// `LeafNodeSource::KeyPackage`.
+    pub fn expiration(&self) -> Result<u64, MlsError> {
+        let LeafNodeSource::KeyPackage(lifetime) = &self.leaf_node.leaf_node_source else {
+            return Err(MlsError::InvalidLeafNodeSource);
+        };
+
+        Ok(lifetime.not_after)
+    }
+}
+
+impl<'a> Signable<'a> for KeyPackage {
+    const SIGN_LABEL: &'static str = "KeyPackageTBS";
+
+    type SigningContext = ();
+
+    /// Borrowed view of the stored signature bytes.
+    fn signature(&self) -> &[u8] {
+        self.signature.as_slice()
+    }
+
+    /// MLS-encodes every field except `signature` (the `KeyPackageTBS` content).
+    fn signable_content(
+        &self,
+        _context: &Self::SigningContext,
+    ) -> Result<Vec<u8>, mls_rs_codec::Error> {
+        let to_be_signed = KeyPackageData {
+            version: self.version,
+            cipher_suite: self.cipher_suite,
+            hpke_init_key: &self.hpke_init_key,
+            leaf_node: &self.leaf_node,
+            extensions: &self.extensions,
+        };
+
+        to_be_signed.mls_encode_to_vec()
+    }
+
+    /// Replaces the stored signature with `signature`.
+    fn write_signature(&mut self, signature: Vec<u8>) {
+        self.signature = signature;
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use super::*;
+    use crate::{
+        crypto::test_utils::test_cipher_suite_provider,
+        group::framing::MlsMessagePayload,
+        identity::basic::BasicIdentityProvider,
+        identity::test_utils::get_test_signing_identity,
+        tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
+        MlsMessage,
+    };
+
+    use mls_rs_core::crypto::SignatureSecretKey;
+
+    /// Fresh key package for test identity `id`; the signing key is discarded.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn test_key_package(
+        protocol_version: ProtocolVersion,
+        cipher_suite: CipherSuite,
+        id: &str,
+    ) -> KeyPackage {
+        let (key_package, _signing_key) =
+            test_key_package_with_signer(protocol_version, cipher_suite, id).await;
+
+        key_package
+    }
+
+    /// Fresh key package for test identity `id` (one-year lifetime, test
+    /// capabilities, no extensions) together with the key that signed it.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn test_key_package_with_signer(
+        protocol_version: ProtocolVersion,
+        cipher_suite: CipherSuite,
+        id: &str,
+    ) -> (KeyPackage, SignatureSecretKey) {
+        let (signing_identity, secret_key) =
+            get_test_signing_identity(cipher_suite, id.as_bytes()).await;
+
+        let provider = test_cipher_suite_provider(cipher_suite);
+
+        let generator = KeyPackageGenerator {
+            protocol_version,
+            cipher_suite_provider: &provider,
+            signing_identity: &signing_identity,
+            signing_key: &secret_key,
+            identity_provider: &BasicIdentityProvider,
+        };
+
+        let generation = generator
+            .generate(
+                Lifetime::years(1).unwrap(),
+                get_test_capabilities(),
+                ExtensionList::default(),
+                ExtensionList::default(),
+            )
+            .await
+            .unwrap();
+
+        (generation.key_package, secret_key)
+    }
+
+    /// [`test_key_package`] wrapped in an [`MlsMessage`] of `protocol_version`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn test_key_package_message(
+        protocol_version: ProtocolVersion,
+        cipher_suite: CipherSuite,
+        id: &str,
+    ) -> MlsMessage {
+        let key_package = test_key_package(protocol_version, cipher_suite, id).await;
+
+        MlsMessage::new(protocol_version, MlsMessagePayload::KeyPackage(key_package))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::{
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
+ };
+
+ use super::{test_utils::test_key_package, *};
+ use alloc::format;
+ use assert_matches::assert_matches;
+
+ /// One `key_package_ref` test vector: an MLS-encoded key package (`input`)
+ /// and its expected reference bytes (`output`), hex-encoded in JSON.
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct TestCase {
+ // Raw cipher suite id; suites the test provider does not support are
+ // skipped by `test_key_package_ref`.
+ cipher_suite: u16,
+ #[serde(with = "hex::serde")]
+ input: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ output: Vec<u8>,
+ }
+
+ impl TestCase {
+ /// Builds one vector per (protocol version, cipher suite) pair.
+ ///
+ /// Only run by `load_test_case_json!` when the JSON file is missing.
+ /// Identities are named `alice{i}` by enumeration index, so changing the
+ /// iteration order changes any regenerated vectors.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn generate() -> Vec<TestCase> {
+ let mut test_cases = Vec::new();
+
+ for (i, (protocol_version, cipher_suite)) in ProtocolVersion::all()
+ .flat_map(|p| CipherSuite::all().map(move |cs| (p, cs)))
+ .enumerate()
+ {
+ let pkg =
+ test_key_package(protocol_version, cipher_suite, &format!("alice{i}")).await;
+
+ let pkg_ref = pkg
+ .to_reference(&test_cipher_suite_provider(cipher_suite))
+ .await
+ .unwrap();
+
+ let case = TestCase {
+ cipher_suite: cipher_suite.into(),
+ input: pkg.mls_encode_to_vec().unwrap(),
+ output: pkg_ref.to_vec(),
+ };
+
+ test_cases.push(case);
+ }
+
+ test_cases
+ }
+ }
+
+ // Two cfg variants because `TestCase::generate` is only async under
+ // `mls_build_async`.
+ #[cfg(mls_build_async)]
+ async fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(key_package_ref, TestCase::generate().await)
+ }
+
+ #[cfg(not(mls_build_async))]
+ fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(key_package_ref, TestCase::generate())
+ }
+
+ // Decoding each stored package and recomputing its reference must reproduce
+ // the stored reference exactly.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_key_package_ref() {
+ let cases = load_test_cases().await;
+
+ for one_case in cases {
+ let Some(provider) = try_test_cipher_suite_provider(one_case.cipher_suite) else {
+ continue;
+ };
+
+ let key_package = KeyPackage::mls_decode(&mut one_case.input.as_slice()).unwrap();
+
+ let key_package_ref = key_package.to_reference(&provider).await.unwrap();
+
+ let expected_out = KeyPackageRef::from(one_case.output);
+ assert_eq!(expected_out, key_package_ref);
+ }
+ }
+
+ // `to_reference` must refuse a provider for any other available suite.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn key_package_ref_fails_invalid_cipher_suite() {
+ let key_package = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "test").await;
+
+ for another_cipher_suite in CipherSuite::all().filter(|cs| cs != &TEST_CIPHER_SUITE) {
+ // NOTE(review): `*another_cipher_suite` yields the raw id accepted by
+ // `try_test_cipher_suite_provider` (the same argument type as
+ // `TestCase::cipher_suite`, a `u16`) — relies on `CipherSuite: Deref`.
+ if let Some(cs) = try_test_cipher_suite_provider(*another_cipher_suite) {
+ let res = key_package.to_reference(&cs).await;
+
+ assert_matches!(res, Err(MlsError::CipherSuiteMismatch));
+ }
+ }
+ }
+}
diff --git a/src/key_package/validator.rs b/src/key_package/validator.rs
new file mode 100644
index 0000000..9cf1dae
--- /dev/null
+++ b/src/key_package/validator.rs
@@ -0,0 +1,39 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs_core::{crypto::CipherSuiteProvider, protocol_version::ProtocolVersion};
+
+use crate::{client::MlsError, signer::Signable, KeyPackage};
+
+/// Checks the properties required of a received [`KeyPackage`] against the
+/// expected protocol `version` and the cipher suite of `cs`.
+///
+/// Checks, in order:
+/// 1. the protocol version equals `version`;
+/// 2. the cipher suite equals `cs.cipher_suite()`;
+/// 3. the signature verifies under the leaf node's signature key;
+/// 4. `hpke_init_key` is a valid KEM public key for the suite;
+/// 5. `hpke_init_key` differs from the leaf node's `public_key`.
+///
+/// The cheap header checks run before signature verification, following
+/// RFC 9420 §10.1. A package for another version or suite is therefore
+/// rejected with a precise error, instead of whatever error verifying it
+/// with a mismatched provider produces, and no signature check is spent on it.
+///
+/// # Errors
+///
+/// `ProtocolVersionMismatch`, `CipherSuiteMismatch`, signature-verification
+/// errors, `InvalidInitKey` or `InitLeafKeyEquality`, respectively.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn validate_key_package_properties<CSP: CipherSuiteProvider>(
+    package: &KeyPackage,
+    version: ProtocolVersion,
+    cs: &CSP,
+) -> Result<(), MlsError> {
+    // Verify that the protocol version matches
+    if package.version != version {
+        return Err(MlsError::ProtocolVersionMismatch);
+    }
+
+    // Verify that the cipher suite matches before using `cs` for any crypto
+    if package.cipher_suite != cs.cipher_suite() {
+        return Err(MlsError::CipherSuiteMismatch);
+    }
+
+    package
+        .verify(cs, &package.leaf_node.signing_identity.signature_key, &())
+        .await?;
+
+    // Verify that the public init key is a valid format for this cipher suite
+    cs.kem_public_key_validate(&package.hpke_init_key)
+        .map_err(|_| MlsError::InvalidInitKey)?;
+
+    // Verify that the init key and the leaf node public key are different
+    if package.hpke_init_key.as_ref() == package.leaf_node.public_key.as_ref() {
+        return Err(MlsError::InitLeafKeyEquality);
+    }
+
+    Ok(())
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..115b3f8
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,218 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+//! An implementation of the [IETF Messaging Layer Security](https://messaginglayersecurity.rocks)
+//! end-to-end encryption (E2EE) protocol.
+//!
+//! ## What is MLS?
+//!
+//! MLS is a new IETF end-to-end encryption standard that is designed to
+//! provide transport-agnostic, asynchronous, and highly performant
+//! communication between a group of clients.
+//!
+//! ## MLS Protocol Features
+//!
+//! - Multi-party E2EE [group evolution](https://www.rfc-editor.org/rfc/rfc9420.html#name-cryptographic-state-and-evo)
+//! via a propose-then-commit mechanism.
+//! - Asynchronous by design with pre-computed [key packages](https://www.rfc-editor.org/rfc/rfc9420.html#name-key-packages),
+//! allowing members to be added to a group while offline.
+//! - Customizable credential system with built-in support for X.509 certificates.
+//! - [Extension system](https://www.rfc-editor.org/rfc/rfc9420.html#name-extensions)
+//! allowing for application specific data to be negotiated via the protocol.
+//! - Strong forward secrecy and post compromise security.
+//! - Crypto agility via support for multiple [cipher suites](https://www.rfc-editor.org/rfc/rfc9420.html#name-cipher-suites).
+//! - Pre-shared key support.
+//! - Subgroup branching.
+//! - Group reinitialization for breaking changes such as protocol upgrades.
+//!
+//! ## Features
+//!
+//! - Easy-to-use client interface that can manage multiple MLS identities and groups.
+//! - 100% RFC 9420 conformance with support for all default credential, proposal,
+//! and extension types.
+//! - Support for WASM builds.
+//! - Configurable storage for key packages, secrets and group state
+//! via traits along with provided "in memory" and SQLite implementations.
+//! - Support for custom user proposal and extension types.
+//! - Ability to create user-defined credentials with custom validation
+//! routines that can bridge to existing credential schemes.
+//! - OpenSSL and Rust Crypto based cipher suite implementations.
+//! - Crypto agility with support for user-defined cipher suites.
+//! - Extensive test suite including security and interop focused tests against
+//! pre-computed test vectors.
+//!
+//! ## Crypto Providers
+//!
+//! For cipher suite descriptions, see [RFC 9420, MLS Cipher Suites](https://www.rfc-editor.org/rfc/rfc9420.html#name-mls-cipher-suites).
+//!
+//! | Name | Cipher Suites | X509 Support |
+//! |------|---------------|--------------|
+//! | OpenSSL | 1-7 | Stable |
+//! | AWS-LC | 1,2,3,5,7 | Stable |
+//! | Rust Crypto | 1,2,3 | ⚠️ Experimental |
+//!
+//! ## Security Notice
+//!
+//! This library has been validated for conformance to the RFC 9420 specification but has not yet received a full security audit by a third party.
+
+// Crate-wide clippy allowances.
+// NOTE(review): the reasons are not recorded here. Presumably
+// `result_large_err` is allowed because `MlsError` is large, and
+// `enum_variant_names` because enums such as `MlsError` repeat a common
+// suffix — confirm and document `nonstandard_macro_braces` too.
+#![allow(clippy::enum_variant_names)]
+#![allow(clippy::result_large_err)]
+#![allow(clippy::nonstandard_macro_braces)]
+#![cfg_attr(not(feature = "std"), no_std)]
+#![cfg_attr(docsrs, feature(doc_cfg))]
+#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
+// Always linked so modules can use `alloc::` paths (e.g. `alloc::vec::Vec`)
+// in both `std` and `no_std` builds.
+extern crate alloc;
+
+// Wasm test runs execute in a browser.
+#[cfg(all(test, target_arch = "wasm32"))]
+wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
+
+// `crate::futures_test` is the async test attribute named in
+// `maybe_async::test(..., async(mls_build_async, crate::futures_test))`.
+#[cfg(all(test, target_arch = "wasm32"))]
+use wasm_bindgen_test::wasm_bindgen_test as futures_test;
+
+#[cfg(all(test, mls_build_async, not(target_arch = "wasm32")))]
+use futures_test::test as futures_test;
+
+/// Test helper: decodes a hex string literal into bytes, panicking with
+/// "invalid hex value" if the literal is not valid hex.
+#[cfg(test)]
+macro_rules! hex {
+    ($hex_str:literal) => {
+        hex::decode($hex_str).expect("invalid hex value")
+    };
+}
+
+/// Loads test vectors from `$CARGO_MANIFEST_DIR/test_data/<name>.json`.
+///
+/// On native `std` builds, the file is first generated from `$generate` and
+/// serialized with `serde_json::$to_json` (default `to_vec_pretty`) if it does
+/// not exist yet. On wasm32 or `no_std` builds, the file is embedded with
+/// `include_bytes!` and `$generate` is never executed.
+#[cfg(test)]
+macro_rules! load_test_case_json {
+ ($name:ident, $generate:expr) => {
+ load_test_case_json!($name, $generate, to_vec_pretty)
+ };
+ ($name:ident, $generate:expr, $to_json:ident) => {{
+ #[cfg(any(target_arch = "wasm32", not(feature = "std")))]
+ {
+ // Do not remove `async`! (The goal of this line is to remove warnings
+ // about `$generate` not being used. Actually calling it will make tests fail.)
+ let _ = async { $generate };
+ serde_json::from_slice(include_bytes!(concat!(
+ env!("CARGO_MANIFEST_DIR"),
+ "/test_data/",
+ stringify!($name),
+ ".json"
+ )))
+ .unwrap()
+ }
+
+ #[cfg(all(not(target_arch = "wasm32"), feature = "std"))]
+ {
+ let path = concat!(
+ env!("CARGO_MANIFEST_DIR"),
+ "/test_data/",
+ stringify!($name),
+ ".json"
+ );
+ // Generate-once: an existing file is never regenerated.
+ // NOTE(review): `fs::write` does not create parent directories, so
+ // this panics if `test_data/` itself is missing.
+ if !std::path::Path::new(path).exists() {
+ std::fs::write(path, serde_json::$to_json(&$generate).unwrap()).unwrap();
+ }
+ serde_json::from_slice(&std::fs::read(path).unwrap()).unwrap()
+ }
+ }};
+}
+
+// Crate-internal module paths for core types. Other modules import
+// `crate::cipher_suite::CipherSuite` and
+// `crate::protocol_version::ProtocolVersion`; the same types are also
+// re-exported at the crate root.
+mod cipher_suite {
+ pub use mls_rs_core::crypto::CipherSuite;
+}
+
+pub use cipher_suite::CipherSuite;
+
+mod protocol_version {
+ pub use mls_rs_core::protocol_version::ProtocolVersion;
+}
+
+pub use protocol_version::ProtocolVersion;
+
+/// The MLS [`Client`] and the crate-wide [`MlsError`](client::MlsError) type.
+pub mod client;
+/// [`ClientBuilder`](client_builder::ClientBuilder) and its configuration types.
+pub mod client_builder;
+mod client_config;
+/// Dependencies of [`CryptoProvider`] and [`CipherSuiteProvider`].
+pub mod crypto;
+/// Extension utilities and built-in extension types.
+pub mod extension;
+/// Tools to observe groups without being a member, useful
+/// for server implementations.
+#[cfg(feature = "external_client")]
+#[cfg_attr(docsrs, doc(cfg(feature = "external_client")))]
+pub mod external_client;
+mod grease;
+/// E2EE group created by a [`Client`].
+pub mod group;
+mod hash_reference;
+/// Identity providers to use with [`ClientBuilder`](client_builder::ClientBuilder).
+pub mod identity;
+// `wrap_iter` / `wrap_impl_iter`: sequential, rayon-parallel or async-stream
+// iteration depending on the build configuration.
+mod iter;
+// `KeyPackage`, `KeyPackageRef`, and key package generation and validation.
+mod key_package;
+/// Pre-shared key support.
+pub mod psk;
+mod signer;
+/// Storage providers to use with
+/// [`ClientBuilder`](client_builder::ClientBuilder).
+pub mod storage_provider;
+
+pub use mls_rs_core::{
+ crypto::{CipherSuiteProvider, CryptoProvider},
+ group::GroupStateStorage,
+ identity::IdentityProvider,
+ key_package::KeyPackageStorage,
+ psk::PreSharedKeyStorage,
+};
+
+/// Dependencies of [`MlsRules`].
+///
+/// Commit/encryption options and the proposal-filtering types passed to
+/// `MlsRules` implementations.
+pub mod mls_rules {
+    pub use crate::group::mls_rules::{
+        CommitDirection, CommitOptions, CommitSource, DefaultMlsRules, EncryptionOptions,
+    };
+    pub use crate::group::proposal_filter::{ProposalBundle, ProposalInfo, ProposalSource};
+
+    #[cfg(feature = "by_ref_proposal")]
+    pub use crate::group::proposal_ref::ProposalRef;
+}
+
+pub use mls_rs_core::extension::{Extension, ExtensionList};
+
+pub use crate::{
+ client::Client,
+ group::{
+ framing::{MlsMessage, WireFormat},
+ mls_rules::MlsRules,
+ Group,
+ },
+ key_package::{KeyPackage, KeyPackageRef},
+};
+
+/// Error types.
+///
+/// [`MlsError`](crate::error::MlsError) is this crate's error type; the rest
+/// are re-exported from `mls_rs_core`.
+pub mod error {
+    pub use crate::client::MlsError;
+    pub use mls_rs_core::{
+        error::{AnyError, IntoAnyError},
+        extension::ExtensionError,
+    };
+}
+
+/// WASM-compatible timestamp.
+///
+/// Re-exports the entire contents of `mls_rs_core::time`.
+pub mod time {
+ pub use mls_rs_core::time::*;
+}
+
+mod tree_kem;
+
+// Re-export of the codec crate used by this crate's types (e.g. the
+// `MlsEncode`/`MlsDecode` derives on `KeyPackage`). NOTE(review): presumably
+// so dependents can use the exact same codec version — confirm.
+pub use mls_rs_codec;
+
+// `Sealed` has no items and lives in a private module, so code outside this
+// crate cannot name it. NOTE(review): presumably used as a supertrait to seal
+// public traits (sealed-trait pattern) — confirm at its use sites.
+mod private {
+ pub trait Sealed {}
+}
+
+use private::Sealed;
+
+// Shared test helpers: compiled for this crate's tests, and for other crates
+// when the `test_util` feature is enabled (hidden from rustdoc).
+#[cfg(any(test, feature = "test_util"))]
+#[doc(hidden)]
+pub mod test_utils;
+
+#[cfg(feature = "ffi")]
+pub use safer_ffi_gen;
diff --git a/src/message.rs b/src/message.rs
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/message.rs
diff --git a/src/psk.rs b/src/psk.rs
new file mode 100644
index 0000000..5bf95c3
--- /dev/null
+++ b/src/psk.rs
@@ -0,0 +1,200 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+
+#[cfg(any(test, feature = "external_client"))]
+use alloc::vec;
+
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+#[cfg(any(test, feature = "external_client"))]
+use mls_rs_core::psk::PreSharedKeyStorage;
+
+#[cfg(any(test, feature = "external_client"))]
+use core::convert::Infallible;
+use core::fmt::{self, Debug};
+
+#[cfg(feature = "psk")]
+use crate::{client::MlsError, CipherSuiteProvider};
+
+#[cfg(feature = "psk")]
+use mls_rs_core::error::IntoAnyError;
+
+#[cfg(feature = "psk")]
+pub(crate) mod resolver;
+pub(crate) mod secret;
+
+pub use mls_rs_core::psk::{ExternalPskId, PreSharedKey};
+
+/// Identifies a pre-shared key (`key_id`) together with a random nonce
+/// (`psk_nonce`).
+#[derive(Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct PreSharedKeyID {
+    pub key_id: JustPreSharedKeyID,
+    pub psk_nonce: PskNonce,
+}
+
+impl PreSharedKeyID {
+    /// Pairs `key_id` with a fresh nonce drawn from `cs`.
+    ///
+    /// # Errors
+    ///
+    /// Returns `MlsError::CryptoProviderError` if the provider cannot produce
+    /// the random nonce bytes.
+    #[cfg(feature = "psk")]
+    pub(crate) fn new<P: CipherSuiteProvider>(
+        key_id: JustPreSharedKeyID,
+        cs: &P,
+    ) -> Result<Self, MlsError> {
+        let psk_nonce =
+            PskNonce::random(cs).map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+        Ok(Self { key_id, psk_nonce })
+    }
+}
+
+/// The kind of pre-shared key referenced by a [`PreSharedKeyID`].
+///
+/// NOTE(review): the explicit `u8` discriminants on this `#[repr(u8)]` enum
+/// presumably act as the tag in the derived `MlsEncode`/`MlsDecode` form —
+/// confirm before reordering or renumbering variants.
+#[derive(Clone, Debug, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+pub(crate) enum JustPreSharedKeyID {
+    /// A PSK looked up by id in a `PreSharedKeyStorage` (see `PskResolver`).
+    External(ExternalPskId) = 1u8,
+    /// A PSK taken from the resumption secret of a group epoch.
+    Resumption(ResumptionPsk) = 2u8,
+}
+
+/// Group id carried by a [`ResumptionPsk`]; resolution compares it with the
+/// current group's id.
+#[derive(Clone, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct PskGroupId(
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+    pub Vec<u8>,
+);
+
+// Prints the id as pretty bytes rather than a raw `Vec<u8>`.
+impl Debug for PskGroupId {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let pretty = mls_rs_core::debug::pretty_bytes(&self.0).named("PskGroupId");
+        Debug::fmt(&pretty, f)
+    }
+}
+
+/// Nonce bytes attached to a [`PreSharedKeyID`].
+#[derive(Clone, Eq, Hash, PartialEq, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct PskNonce(
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+    pub Vec<u8>,
+);
+
+// Prints the nonce as pretty bytes rather than a raw `Vec<u8>`.
+impl Debug for PskNonce {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let pretty = mls_rs_core::debug::pretty_bytes(&self.0).named("PskNonce");
+        Debug::fmt(&pretty, f)
+    }
+}
+
+#[cfg(feature = "psk")]
+impl PskNonce {
+    /// Draws a nonce of `kdf_extract_size()` random bytes from
+    /// `cipher_suite_provider`.
+    ///
+    /// # Errors
+    ///
+    /// Propagates the provider's error if random generation fails.
+    pub fn random<P: CipherSuiteProvider>(
+        cipher_suite_provider: &P,
+    ) -> Result<Self, <P as CipherSuiteProvider>::Error> {
+        let nonce_len = cipher_suite_provider.kdf_extract_size();
+        cipher_suite_provider
+            .random_bytes_vec(nonce_len)
+            .map(PskNonce)
+    }
+}
+
+/// Identifies a resumption PSK: the resumption secret of epoch `psk_epoch`
+/// of the group identified by `psk_group_id`.
+#[derive(Clone, Debug, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct ResumptionPsk {
+    pub usage: ResumptionPSKUsage,
+    pub psk_group_id: PskGroupId,
+    pub psk_epoch: u64,
+}
+
+/// The purpose a resumption PSK is used for.
+///
+/// NOTE(review): the explicit `u8` discriminants presumably serve as the
+/// encoded values for the derived codec — confirm before renumbering.
+#[derive(Clone, Debug, Eq, Hash, PartialEq, Ord, PartialOrd, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+pub(crate) enum ResumptionPSKUsage {
+    Application = 1u8,
+    Reinit = 2u8,
+    Branch = 3u8,
+}
+
+/// Context that `PskSecret::calculate` encodes and passes to
+/// `kdf_expand_with_label` (label "derived psk") for the PSK at position
+/// `index` out of `count`. Field order determines the encoding.
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)]
+struct PSKLabel<'a> {
+    id: &'a PreSharedKeyID,
+    index: u16,
+    count: u16,
+}
+
+/// A [`PreSharedKeyStorage`] whose lookups always succeed, returning an
+/// empty key for every id. Only compiled for tests and `external_client`.
+#[cfg(any(test, feature = "external_client"))]
+#[derive(Clone, Copy, Debug)]
+pub(crate) struct AlwaysFoundPskStorage;
+
+#[cfg(any(test, feature = "external_client"))]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+impl PreSharedKeyStorage for AlwaysFoundPskStorage {
+    type Error = Infallible;
+
+    async fn get(&self, _id: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error> {
+        let empty_key: PreSharedKey = vec![].into();
+        Ok(Some(empty_key))
+    }
+}
+
+#[cfg(feature = "psk")]
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use crate::crypto::test_utils::test_cipher_suite_provider;
+
+    use super::PskNonce;
+    use mls_rs_core::crypto::CipherSuite;
+
+    #[cfg(not(mls_build_async))]
+    use mls_rs_core::{crypto::CipherSuiteProvider, psk::ExternalPskId};
+
+    // Builds an external PSK id of `kdf_extract_size()` random bytes.
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    #[cfg(not(mls_build_async))]
+    pub(crate) fn make_external_psk_id<P: CipherSuiteProvider>(
+        cipher_suite_provider: &P,
+    ) -> ExternalPskId {
+        let id_len = cipher_suite_provider.kdf_extract_size();
+        let id_bytes = cipher_suite_provider.random_bytes_vec(id_len).unwrap();
+        ExternalPskId::new(id_bytes)
+    }
+
+    // Generates a random nonce for `cipher_suite`, panicking on failure.
+    pub(crate) fn make_nonce(cipher_suite: CipherSuite) -> PskNonce {
+        let provider = test_cipher_suite_provider(cipher_suite);
+        PskNonce::random(&provider).unwrap()
+    }
+}
+
+#[cfg(feature = "psk")]
+#[cfg(test)]
+mod tests {
+    use crate::crypto::test_utils::TestCryptoProvider;
+    use core::iter;
+
+    #[cfg(target_arch = "wasm32")]
+    use wasm_bindgen_test::wasm_bindgen_test as test;
+
+    use super::test_utils::make_nonce;
+
+    // For every supported suite, none of 1000 further nonces may collide with
+    // the first one generated.
+    #[test]
+    fn random_generation_of_nonces_is_random() {
+        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+            let first = make_nonce(cipher_suite);
+            let collided = iter::repeat_with(|| make_nonce(cipher_suite))
+                .take(1000)
+                .any(|candidate| candidate == first);
+
+            assert!(!collided);
+        }
+    }
+}
diff --git a/src/psk/resolver.rs b/src/psk/resolver.rs
new file mode 100644
index 0000000..0e3b7c9
--- /dev/null
+++ b/src/psk/resolver.rs
@@ -0,0 +1,95 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use mls_rs_core::{
+ crypto::CipherSuiteProvider,
+ error::IntoAnyError,
+ group::GroupStateStorage,
+ key_package::KeyPackageStorage,
+ psk::{ExternalPskId, PreSharedKey, PreSharedKeyStorage},
+};
+
+use crate::{
+ client::MlsError,
+ group::{epoch::EpochSecrets, state_repo::GroupStateRepository, GroupContext},
+ psk::secret::PskSecret,
+};
+
+use super::{secret::PskSecretInput, JustPreSharedKeyID, PreSharedKeyID, ResumptionPsk};
+
+/// Resolves [`PreSharedKeyID`]s to key material.
+///
+/// External PSKs are looked up in `psk_store`. Resumption PSKs come from the
+/// current epoch when `group_context` matches the requested group and epoch,
+/// or (with the `prior_epoch` feature) from `prior_epochs`.
+pub(crate) struct PskResolver<'a, GS, K, PS>
+where
+    GS: GroupStateStorage,
+    PS: PreSharedKeyStorage,
+    K: KeyPackageStorage,
+{
+    /// Context of the current epoch; `None` skips the current-epoch check.
+    pub group_context: Option<&'a GroupContext>,
+    /// Secrets of the current epoch, read when `group_context` matches.
+    pub current_epoch: Option<&'a EpochSecrets>,
+    /// Stored prior epochs; only consulted with the `prior_epoch` feature.
+    pub prior_epochs: Option<&'a GroupStateRepository<GS, K>>,
+    /// Storage queried for external PSKs.
+    pub psk_store: &'a PS,
+}
+
+impl<GS: GroupStateStorage, K: KeyPackageStorage, PS: PreSharedKeyStorage>
+    PskResolver<'_, GS, K, PS>
+{
+    /// Returns the resumption secret named by `psk_id`.
+    ///
+    /// The current epoch is used when `group_context` has the requested epoch
+    /// and group id. Otherwise, with the `prior_epoch` feature, stored prior
+    /// epochs are searched.
+    ///
+    /// # Errors
+    ///
+    /// `MlsError::OldGroupStateNotFound` if no matching epoch secret is
+    /// available; errors from the prior-epoch lookup are propagated.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn resolve_resumption(&self, psk_id: &ResumptionPsk) -> Result<PreSharedKey, MlsError> {
+        let targets_current_epoch = matches!(
+            self.group_context,
+            Some(ctx) if ctx.epoch == psk_id.psk_epoch && ctx.group_id == psk_id.psk_group_id.0
+        );
+
+        if targets_current_epoch {
+            return self
+                .current_epoch
+                .map(|epoch| epoch.resumption_secret.clone())
+                .ok_or(MlsError::OldGroupStateNotFound);
+        }
+
+        #[cfg(feature = "prior_epoch")]
+        if let Some(prior_epochs) = self.prior_epochs {
+            if let Some(secret) = prior_epochs.resumption_secret(psk_id).await? {
+                return Ok(secret);
+            }
+        }
+
+        Err(MlsError::OldGroupStateNotFound)
+    }
+
+    /// Looks up an external PSK in `psk_store`.
+    ///
+    /// # Errors
+    ///
+    /// `MlsError::PskStoreError` if the store fails, and
+    /// `MlsError::MissingRequiredPsk` if it has no entry for `psk_id`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn resolve_external(&self, psk_id: &ExternalPskId) -> Result<PreSharedKey, MlsError> {
+        let stored = self
+            .psk_store
+            .get(psk_id)
+            .await
+            .map_err(|e| MlsError::PskStoreError(e.into_any_error()))?;
+
+        stored.ok_or(MlsError::MissingRequiredPsk)
+    }
+
+    /// Resolves each id in order, stopping at the first failure.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn resolve(&self, id: &[PreSharedKeyID]) -> Result<Vec<PskSecretInput>, MlsError> {
+        let mut resolved = Vec::new();
+
+        for psk_id in id {
+            let psk = match &psk_id.key_id {
+                JustPreSharedKeyID::External(external) => self.resolve_external(external).await?,
+                JustPreSharedKeyID::Resumption(resumption) => {
+                    self.resolve_resumption(resumption).await?
+                }
+            };
+
+            resolved.push(PskSecretInput {
+                id: psk_id.clone(),
+                psk,
+            });
+        }
+
+        Ok(resolved)
+    }
+
+    /// Resolves every id in `id` and combines the results with
+    /// [`PskSecret::calculate`].
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn resolve_to_secret<P: CipherSuiteProvider>(
+        &self,
+        id: &[PreSharedKeyID],
+        cipher_suite_provider: &P,
+    ) -> Result<PskSecret, MlsError> {
+        let inputs = self.resolve(id).await?;
+        PskSecret::calculate(&inputs, cipher_suite_provider).await
+    }
+}
diff --git a/src/psk/secret.rs b/src/psk/secret.rs
new file mode 100644
index 0000000..4fe9cc8
--- /dev/null
+++ b/src/psk/secret.rs
@@ -0,0 +1,239 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_core::crypto::CipherSuiteProvider;
+use zeroize::Zeroizing;
+
+#[cfg(feature = "psk")]
+use mls_rs_codec::MlsEncode;
+
+#[cfg(feature = "psk")]
+use mls_rs_core::{error::IntoAnyError, psk::PreSharedKey};
+
+#[cfg(feature = "psk")]
+use crate::{
+ client::MlsError,
+ group::key_schedule::kdf_expand_with_label,
+ psk::{PSKLabel, PreSharedKeyID},
+};
+
+/// A resolved pre-shared key together with the id it was requested under.
+#[cfg(feature = "psk")]
+#[derive(Clone)]
+pub(crate) struct PskSecretInput {
+    pub id: PreSharedKeyID,
+    pub psk: PreSharedKey,
+}
+
+/// The combined PSK secret, held in a `Zeroizing` buffer.
+#[derive(PartialEq, Eq, Clone)]
+pub(crate) struct PskSecret(Zeroizing<Vec<u8>>);
+
+// Prints the secret as pretty bytes rather than a raw buffer.
+impl Debug for PskSecret {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let pretty = mls_rs_core::debug::pretty_bytes(&self.0).named("PskSecret");
+        Debug::fmt(&pretty, f)
+    }
+}
+
+#[cfg(test)]
+impl From<Vec<u8>> for PskSecret {
+    fn from(bytes: Vec<u8>) -> Self {
+        Self(Zeroizing::new(bytes))
+    }
+}
+
+// Exposes the secret bytes as a slice.
+impl Deref for PskSecret {
+    type Target = [u8];
+
+    fn deref(&self) -> &[u8] {
+        self.0.as_slice()
+    }
+}
+
+impl PskSecret {
+    /// Returns an all-zero secret of `kdf_extract_size()` bytes for `provider`.
+    ///
+    /// This is the starting value of the chain in [`Self::calculate`], which
+    /// returns it unchanged for an empty input.
+    pub(crate) fn new<P: CipherSuiteProvider>(provider: &P) -> PskSecret {
+        PskSecret(Zeroizing::new(vec![0u8; provider.kdf_extract_size()]))
+    }
+
+    /// Combines the resolved PSKs in `input` into a single PSK secret.
+    ///
+    /// Starting from [`Self::new`], for each input `i` (in order, `n = input.len()`):
+    ///
+    /// * `psk_extracted = kdf_extract(zero_salt, psk_i)`
+    /// * `psk_input = kdf_expand_with_label(psk_extracted, "derived psk", PSKLabel { id_i, i, n })`
+    /// * `psk_secret = kdf_extract(psk_input, psk_secret)`
+    ///
+    /// # Errors
+    ///
+    /// * `MlsError::TooManyPskIds` if `input` has more than `u16::MAX` entries.
+    /// * `MlsError::CryptoProviderError` if a `kdf_extract` call fails.
+    /// * Errors from encoding the label or from `kdf_expand_with_label`.
+    #[cfg(feature = "psk")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn calculate<P: CipherSuiteProvider>(
+        input: &[PskSecretInput],
+        cipher_suite_provider: &P,
+    ) -> Result<PskSecret, MlsError> {
+        // PSKLabel encodes the count as a u16, so longer inputs are rejected up front.
+        let count = u16::try_from(input.len()).map_err(|_| MlsError::TooManyPskIds)?;
+
+        // The all-zero extraction salt is loop-invariant: allocate it once.
+        let zero_salt = vec![0u8; cipher_suite_provider.kdf_extract_size()];
+        let mut psk_secret = PskSecret::new(cipher_suite_provider);
+
+        // `0..count` has exactly `input.len()` elements, so indices stay u16 without a cast.
+        for (index, psk_secret_input) in (0..count).zip(input) {
+            let label = PSKLabel {
+                id: &psk_secret_input.id,
+                index,
+                count,
+            };
+
+            let psk_extracted = cipher_suite_provider
+                .kdf_extract(&zero_salt, &psk_secret_input.psk)
+                .await
+                .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+            let psk_input = kdf_expand_with_label(
+                cipher_suite_provider,
+                &psk_extracted,
+                b"derived psk",
+                &label.mls_encode_to_vec()?,
+                None,
+            )
+            .await?;
+
+            psk_secret = cipher_suite_provider
+                .kdf_extract(&psk_input, &psk_secret)
+                .await
+                .map(PskSecret)
+                .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+        }
+
+        Ok(psk_secret)
+    }
+}
+
+#[cfg(feature = "psk")]
+#[cfg(test)]
+mod tests {
+    use alloc::vec::Vec;
+    #[cfg(not(mls_build_async))]
+    use core::iter;
+    use serde::{Deserialize, Serialize};
+
+    use crate::{
+        crypto::test_utils::try_test_cipher_suite_provider,
+        psk::ExternalPskId,
+        psk::{JustPreSharedKeyID, PreSharedKeyID, PskNonce},
+        CipherSuiteProvider,
+    };
+
+    #[cfg(not(mls_build_async))]
+    use crate::{
+        crypto::test_utils::test_cipher_suite_provider, psk::test_utils::make_external_psk_id,
+        CipherSuite,
+    };
+
+    use super::{PskSecret, PskSecretInput};
+
+    // One external PSK in a test vector; all byte fields are hex in JSON.
+    #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
+    struct PskInfo {
+        #[serde(with = "hex::serde")]
+        id: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        psk: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        nonce: Vec<u8>,
+    }
+
+    // Every test-vector PSK is treated as an external PSK.
+    impl From<PskInfo> for PskSecretInput {
+        fn from(info: PskInfo) -> Self {
+            let id = PreSharedKeyID {
+                key_id: JustPreSharedKeyID::External(ExternalPskId::new(info.id)),
+                psk_nonce: PskNonce(info.nonce),
+            };
+
+            PskSecretInput {
+                id,
+                psk: info.psk.into(),
+            }
+        }
+    }
+
+    // A test vector: a list of PSKs for one cipher suite and the expected
+    // combined secret.
+    #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
+    struct TestScenario {
+        cipher_suite: u16,
+        psks: Vec<PskInfo>,
+        #[serde(with = "hex::serde")]
+        psk_secret: Vec<u8>,
+    }
+
+    impl TestScenario {
+        // Generates `n` random external PSKs for `cs`.
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        #[cfg(not(mls_build_async))]
+        fn make_psk_list<CS: CipherSuiteProvider>(cs: &CS, n: usize) -> Vec<PskInfo> {
+            iter::repeat_with(
+                #[cfg_attr(coverage_nightly, coverage(off))]
+                || PskInfo {
+                    id: make_external_psk_id(cs).to_vec(),
+                    psk: cs.random_bytes_vec(cs.kdf_extract_size()).unwrap(),
+                    nonce: crate::psk::test_utils::make_nonce(cs.cipher_suite()).0,
+                },
+            )
+            .take(n)
+            .collect::<Vec<_>>()
+        }
+
+        // Builds fresh vectors: 1 through 10 PSKs for every cipher suite.
+        #[cfg(not(mls_build_async))]
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        fn generate() -> Vec<TestScenario> {
+            CipherSuite::all()
+                .flat_map(
+                    #[cfg_attr(coverage_nightly, coverage(off))]
+                    |cs| (1..=10).map(move |n| (cs, n)),
+                )
+                .map(
+                    #[cfg_attr(coverage_nightly, coverage(off))]
+                    |(cs, n)| {
+                        let provider = test_cipher_suite_provider(cs);
+                        let psks = Self::make_psk_list(&provider, n);
+                        let psk_secret = Self::compute_psk_secret(&provider, psks.clone());
+                        TestScenario {
+                            cipher_suite: cs.into(),
+                            psks: psks.to_vec(),
+                            psk_secret: psk_secret.to_vec(),
+                        }
+                    },
+                )
+                .collect()
+        }
+
+        // Generation relies on the sync helpers above, so async builds cannot
+        // produce new vectors.
+        #[cfg(mls_build_async)]
+        fn generate() -> Vec<TestScenario> {
+            panic!("Tests cannot be generated in async mode");
+        }
+
+        // Runs `PskSecret::calculate` over `psks`, panicking on error.
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        async fn compute_psk_secret<P: CipherSuiteProvider>(
+            provider: &P,
+            psks: Vec<PskInfo>,
+        ) -> PskSecret {
+            let input = psks
+                .into_iter()
+                .map(PskSecretInput::from)
+                .collect::<Vec<_>>();
+
+            PskSecret::calculate(&input, provider).await.unwrap()
+        }
+    }
+
+    // Loads the `psk_secret` vectors (fallback/generation semantics are those
+    // of `load_test_case_json!`) and recomputes each secret. Scenarios whose
+    // cipher suite has no test provider are skipped.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn expected_psk_secret_is_produced() {
+        let scenarios: Vec<TestScenario> =
+            load_test_case_json!(psk_secret, TestScenario::generate());
+
+        for scenario in scenarios {
+            if let Some(provider) = try_test_cipher_suite_provider(scenario.cipher_suite) {
+                let computed =
+                    TestScenario::compute_psk_secret(&provider, scenario.psks.clone()).await;
+
+                assert_eq!(scenario.psk_secret, computed.to_vec());
+            }
+        }
+    }
+}
diff --git a/src/signer.rs b/src/signer.rs
new file mode 100644
index 0000000..12970ec
--- /dev/null
+++ b/src/signer.rs
@@ -0,0 +1,357 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+
+use crate::client::MlsError;
+use crate::crypto::{CipherSuiteProvider, SignaturePublicKey, SignatureSecretKey};
+
+/// The structure that is encoded and signed: `"MLS 1.0 "` followed by the
+/// signing label, plus the caller's content, both encoded with
+/// `mls_rs_codec::byte_vec`.
+#[derive(Clone, MlsSize, MlsEncode)]
+struct SignContent {
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    label: Vec<u8>,
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    content: Vec<u8>,
+}
+
+// Shows both byte fields as pretty bytes.
+impl Debug for SignContent {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let label = mls_rs_core::debug::pretty_bytes(&self.label);
+        let content = mls_rs_core::debug::pretty_bytes(&self.content);
+
+        f.debug_struct("SignContent")
+            .field("label", &label)
+            .field("content", &content)
+            .finish()
+    }
+}
+
+impl SignContent {
+    /// Prefixes `label` with `"MLS 1.0 "` and pairs it with `content`.
+    pub fn new(label: &str, content: Vec<u8>) -> Self {
+        const LABEL_PREFIX: &[u8] = b"MLS 1.0 ";
+
+        let mut full_label = Vec::with_capacity(LABEL_PREFIX.len() + label.len());
+        full_label.extend_from_slice(LABEL_PREFIX);
+        full_label.extend_from_slice(label.as_bytes());
+
+        Self {
+            label: full_label,
+            content,
+        }
+    }
+}
+
+/// A value that carries its own signature over
+/// `SignContent::new(SIGN_LABEL, signable_content(context))`.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
+#[cfg_attr(
+    all(not(target_arch = "wasm32"), mls_build_async),
+    maybe_async::must_be_async
+)]
+pub(crate) trait Signable<'a> {
+    /// Label (without the `"MLS 1.0 "` prefix) bound into every signature.
+    const SIGN_LABEL: &'static str;
+
+    /// Extra data combined with `self` to build the signed content.
+    type SigningContext: Send + Sync;
+
+    /// The currently stored signature bytes.
+    fn signature(&self) -> &[u8];
+
+    /// Encodes the part of `self` (plus `context`) covered by the signature.
+    fn signable_content(
+        &self,
+        context: &Self::SigningContext,
+    ) -> Result<Vec<u8>, mls_rs_codec::Error>;
+
+    /// Stores a freshly computed signature.
+    fn write_signature(&mut self, signature: Vec<u8>);
+
+    /// Signs `self` with `signer` and stores the result via `write_signature`.
+    ///
+    /// # Errors
+    ///
+    /// Encoding errors are propagated; provider failures become
+    /// `MlsError::CryptoProviderError`.
+    async fn sign<P: CipherSuiteProvider>(
+        &mut self,
+        signature_provider: &P,
+        signer: &SignatureSecretKey,
+        context: &Self::SigningContext,
+    ) -> Result<(), MlsError> {
+        let content = self.signable_content(context)?;
+        let to_be_signed = SignContent::new(Self::SIGN_LABEL, content).mls_encode_to_vec()?;
+
+        let signature = signature_provider
+            .sign(signer, &to_be_signed)
+            .await
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+        self.write_signature(signature);
+        Ok(())
+    }
+
+    /// Checks the stored signature against `public_key`.
+    ///
+    /// # Errors
+    ///
+    /// Encoding errors are propagated; any provider verification failure is
+    /// reported as `MlsError::InvalidSignature`.
+    async fn verify<P: CipherSuiteProvider>(
+        &self,
+        signature_provider: &P,
+        public_key: &SignaturePublicKey,
+        context: &Self::SigningContext,
+    ) -> Result<(), MlsError> {
+        let content = self.signable_content(context)?;
+        let signed_bytes = SignContent::new(Self::SIGN_LABEL, content).mls_encode_to_vec()?;
+
+        signature_provider
+            .verify(public_key, self.signature(), &signed_bytes)
+            .await
+            .map_err(|_| MlsError::InvalidSignature)
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use alloc::vec;
+    use alloc::{string::String, vec::Vec};
+    use mls_rs_core::crypto::CipherSuiteProvider;
+
+    use crate::crypto::test_utils::try_test_cipher_suite_provider;
+
+    use super::Signable;
+
+    // One `sign_with_label` interop vector. The `priv`/`pub` JSON keys are
+    // renamed because they are Rust keywords.
+    #[derive(Debug, serde::Serialize, serde::Deserialize)]
+    pub struct SignatureInteropTestCase {
+        #[serde(with = "hex::serde", rename = "priv")]
+        secret: Vec<u8>,
+        #[serde(with = "hex::serde", rename = "pub")]
+        public: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        content: Vec<u8>,
+        label: String,
+        #[serde(with = "hex::serde")]
+        signature: Vec<u8>,
+    }
+
+    // Pairs a cipher suite id with its `sign_with_label` vector.
+    #[derive(Debug, serde::Serialize, serde::Deserialize)]
+    pub struct InteropTestCase {
+        cipher_suite: u16,
+        sign_with_label: SignatureInteropTestCase,
+    }
+
+    // Verifies every `basic_crypto` interop vector whose cipher suite has a
+    // test provider. The fallback passed to `load_test_case_json!` is an
+    // empty list, so presumably these vectors are never generated locally.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_basic_crypto_test_vectors() {
+        let test_cases: Vec<InteropTestCase> =
+            load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());
+
+        for test_case in test_cases {
+            if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) {
+                test_case.sign_with_label.verify(&cs).await;
+            }
+        }
+    }
+
+    // Minimal `Signable` whose signed content is `context || content`.
+    pub struct TestSignable {
+        pub content: Vec<u8>,
+        pub signature: Vec<u8>,
+    }
+
+    impl<'a> Signable<'a> for TestSignable {
+        const SIGN_LABEL: &'static str = "SignWithLabel";
+
+        type SigningContext = Vec<u8>;
+
+        fn signature(&self) -> &[u8] {
+            &self.signature
+        }
+
+        fn signable_content(
+            &self,
+            context: &Self::SigningContext,
+        ) -> Result<Vec<u8>, mls_rs_codec::Error> {
+            Ok([context.as_slice(), self.content.as_slice()].concat())
+        }
+
+        fn write_signature(&mut self, signature: Vec<u8>) {
+            self.signature = signature
+        }
+    }
+
+    impl SignatureInteropTestCase {
+        // Verifies the vector's signature over its content with an empty
+        // context, panicking on failure.
+        // NOTE(review): the vector's `label` (and `secret`) fields are not
+        // used; this relies on every vector's label being
+        // `TestSignable::SIGN_LABEL` ("SignWithLabel") — confirm.
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
+            let public = self.public.clone().into();
+
+            let signable = TestSignable {
+                content: self.content.clone(),
+                signature: self.signature.clone(),
+            };
+
+            signable.verify(cs, &public, &vec![]).await.unwrap();
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{test_utils::TestSignable, *};
+    use crate::{
+        client::test_utils::TEST_CIPHER_SUITE,
+        crypto::test_utils::{
+            test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
+        },
+        group::test_utils::random_bytes,
+    };
+    use alloc::vec;
+    use assert_matches::assert_matches;
+
+    // A signature vector; every byte field is hex in JSON.
+    #[derive(Debug, serde::Serialize, serde::Deserialize)]
+    struct TestCase {
+        cipher_suite: u16,
+        #[serde(with = "hex::serde")]
+        content: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        context: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        signature: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        signer: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        public: Vec<u8>,
+    }
+
+    // Produces one vector per supported cipher suite: a fresh key pair,
+    // random 32-byte content and context, and the resulting signature.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    async fn generate_test_cases() -> Vec<TestCase> {
+        let mut test_cases = Vec::new();
+
+        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+            let provider = test_cipher_suite_provider(cipher_suite);
+
+            let (signer, public) = provider.signature_key_generate().await.unwrap();
+
+            let content = random_bytes(32);
+            let context = random_bytes(32);
+
+            let mut test_signable = TestSignable {
+                content: content.clone(),
+                signature: Vec::new(),
+            };
+
+            test_signable
+                .sign(&provider, &signer, &context)
+                .await
+                .unwrap();
+
+            test_cases.push(TestCase {
+                cipher_suite: cipher_suite.into(),
+                content,
+                context,
+                signature: test_signable.signature,
+                signer: signer.to_vec(),
+                public: public.to_vec(),
+            });
+        }
+
+        test_cases
+    }
+
+    // Loads the `signatures` vectors with `generate_test_cases` as the
+    // fallback (exact semantics are those of `load_test_case_json!`). The two
+    // variants differ only in whether the generator is awaited.
+    #[cfg(mls_build_async)]
+    async fn load_test_cases() -> Vec<TestCase> {
+        load_test_case_json!(signatures, generate_test_cases().await)
+    }
+
+    #[cfg(not(mls_build_async))]
+    fn load_test_cases() -> Vec<TestCase> {
+        load_test_case_json!(signatures, generate_test_cases())
+    }
+
+    // For each vector whose suite has a test provider: re-sign with the stored
+    // secret key and verify (skipped on wasm), then verify the stored signature.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_signatures() {
+        let cases = load_test_cases().await;
+
+        for one_case in cases {
+            let Some(cipher_suite_provider) = try_test_cipher_suite_provider(one_case.cipher_suite)
+            else {
+                continue;
+            };
+
+            let public_key = SignaturePublicKey::from(one_case.public);
+
+            // Wasm uses incompatible signature secret key format
+            #[cfg(not(target_arch = "wasm32"))]
+            {
+                // Test signature generation
+                let mut test_signable = TestSignable {
+                    content: one_case.content.clone(),
+                    signature: Vec::new(),
+                };
+
+                let signature_key = SignatureSecretKey::from(one_case.signer);
+
+                test_signable
+                    .sign(&cipher_suite_provider, &signature_key, &one_case.context)
+                    .await
+                    .unwrap();
+
+                test_signable
+                    .verify(&cipher_suite_provider, &public_key, &one_case.context)
+                    .await
+                    .unwrap();
+            }
+
+            // Test verifying an existing signature
+            let test_signable = TestSignable {
+                content: one_case.content,
+                signature: one_case.signature,
+            };
+
+            test_signable
+                .verify(&cipher_suite_provider, &public_key, &one_case.context)
+                .await
+                .unwrap();
+        }
+    }
+
+    // A signature made with one key pair's secret must not verify under
+    // another key pair's public key.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_invalid_signature() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let (correct_secret, _) = cipher_suite_provider
+            .signature_key_generate()
+            .await
+            .unwrap();
+        let (_, incorrect_public) = cipher_suite_provider
+            .signature_key_generate()
+            .await
+            .unwrap();
+
+        let mut test_signable = TestSignable {
+            content: random_bytes(32),
+            signature: vec![],
+        };
+
+        test_signable
+            .sign(&cipher_suite_provider, &correct_secret, &vec![])
+            .await
+            .unwrap();
+
+        let res = test_signable
+            .verify(&cipher_suite_provider, &incorrect_public, &vec![])
+            .await;
+
+        assert_matches!(res, Err(MlsError::InvalidSignature));
+    }
+
+    // The signing context is bound into the signature: verifying with a
+    // different context must fail.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_invalid_context() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let (secret, public) = cipher_suite_provider
+            .signature_key_generate()
+            .await
+            .unwrap();
+
+        let correct_context = random_bytes(32);
+        let incorrect_context = random_bytes(32);
+
+        let mut test_signable = TestSignable {
+            content: random_bytes(32),
+            signature: vec![],
+        };
+
+        test_signable
+            .sign(&cipher_suite_provider, &secret, &correct_context)
+            .await
+            .unwrap();
+
+        let res = test_signable
+            .verify(&cipher_suite_provider, &public, &incorrect_context)
+            .await;
+
+        assert_matches!(res, Err(MlsError::InvalidSignature));
+    }
+}
diff --git a/src/storage_provider.rs b/src/storage_provider.rs
new file mode 100644
index 0000000..ffe8cd9
--- /dev/null
+++ b/src/storage_provider.rs
@@ -0,0 +1,14 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+/// Storage providers that operate completely in memory.
+pub mod in_memory;
+pub(crate) mod key_package;
+
+pub use key_package::*;
+
+#[cfg(feature = "sqlite")]
+#[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))]
+/// SQLite based storage providers.
+pub mod sqlite;
diff --git a/src/storage_provider/group_state.rs b/src/storage_provider/group_state.rs
new file mode 100644
index 0000000..b6c854d
--- /dev/null
+++ b/src/storage_provider/group_state.rs
@@ -0,0 +1,43 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use mls_rs_codec::MlsEncode;
+pub use mls_rs_core::group::{EpochRecord, GroupState};
+
+use crate::group::snapshot::Snapshot;
+
+// NOTE(review): `storage_provider.rs` in this patch declares only `in_memory`,
+// `key_package` and (with `sqlite`) `sqlite` — not `group_state` — so this
+// file appears not to be compiled. It also implements `EpochRecord` and
+// `GroupState` as traits, whereas `in_memory/group_state_storage.rs` builds
+// them as structs (`EpochRecord::new(..)`, `GroupState { id, data }`).
+// Confirm whether this file is stale and can be dropped from the import.
+#[cfg(feature = "prior_epoch")]
+impl EpochRecord for PriorEpoch {
+    fn id(&self) -> u64 {
+        self.epoch_id()
+    }
+}
+
+// A group snapshot is keyed by its group id.
+impl GroupState for Snapshot {
+    fn id(&self) -> Vec<u8> {
+        self.group_id().to_vec()
+    }
+}
+
+/// An encoded epoch record together with its epoch id.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) struct EpochData {
+    pub(crate) id: u64,
+    pub(crate) data: Vec<u8>,
+}
+
+impl EpochData {
+    /// Captures `value.id()` and the MLS encoding of `value`.
+    ///
+    /// # Errors
+    ///
+    /// Returns the codec error if encoding `value` fails.
+    pub(crate) fn new<T>(value: T) -> Result<Self, mls_rs_codec::Error>
+    where
+        T: MlsEncode + EpochRecord,
+    {
+        Ok(Self {
+            id: value.id(),
+            data: value.mls_encode_to_vec()?,
+        })
+    }
+}
diff --git a/src/storage_provider/in_memory.rs b/src/storage_provider/in_memory.rs
new file mode 100644
index 0000000..cb8f5d7
--- /dev/null
+++ b/src/storage_provider/in_memory.rs
@@ -0,0 +1,11 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+mod group_state_storage;
+mod key_package_storage;
+mod psk_storage;
+
+pub use group_state_storage::*;
+pub use key_package_storage::*;
+pub use psk_storage::*;
diff --git a/src/storage_provider/in_memory/group_state_storage.rs b/src/storage_provider/in_memory/group_state_storage.rs
new file mode 100644
index 0000000..5999ed0
--- /dev/null
+++ b/src/storage_provider/in_memory/group_state_storage.rs
@@ -0,0 +1,354 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::collections::VecDeque;
+
+#[cfg(target_has_atomic = "ptr")]
+use alloc::sync::Arc;
+
+#[cfg(mls_build_async)]
+use alloc::boxed::Box;
+use alloc::vec::Vec;
+use core::{
+ convert::Infallible,
+ fmt::{self, Debug},
+};
+use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage};
+#[cfg(not(target_has_atomic = "ptr"))]
+use portable_atomic_util::Arc;
+
+use crate::client::MlsError;
+
+#[cfg(feature = "std")]
+use std::collections::{hash_map::Entry, HashMap};
+
+#[cfg(not(feature = "std"))]
+use alloc::collections::{btree_map::Entry, BTreeMap};
+
+#[cfg(feature = "std")]
+use std::sync::Mutex;
+
+#[cfg(not(feature = "std"))]
+use spin::Mutex;
+
+pub(crate) const DEFAULT_EPOCH_RETENTION_LIMIT: usize = 3;
+
+/// Stored data for one group: the latest serialized group state plus a
+/// window of epoch records kept in insertion order (oldest at the front).
+///
+/// Lookups assume the stored epoch ids are consecutive and ascending.
+#[derive(Clone)]
+pub(crate) struct InMemoryGroupData {
+    pub(crate) state_data: Vec<u8>,
+    pub(crate) epoch_data: VecDeque<EpochRecord>,
+}
+
+impl Debug for InMemoryGroupData {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("InMemoryGroupData")
+            .field(
+                "state_data",
+                &mls_rs_core::debug::pretty_bytes(&self.state_data),
+            )
+            .field("epoch_data", &self.epoch_data)
+            .finish()
+    }
+}
+
+impl InMemoryGroupData {
+    /// Creates group data holding `state_data` and no epoch records.
+    pub fn new(state_data: Vec<u8>) -> InMemoryGroupData {
+        InMemoryGroupData {
+            state_data,
+            epoch_data: Default::default(),
+        }
+    }
+
+    /// Offset of `epoch_id` from the oldest stored epoch. Returns `None` if no
+    /// epochs are stored or `epoch_id` is older than the oldest one.
+    fn get_epoch_data_index(&self, epoch_id: u64) -> Option<u64> {
+        self.epoch_data
+            .front()
+            .and_then(|e| epoch_id.checked_sub(e.id))
+    }
+
+    /// Position of `epoch_id` in `epoch_data` as a `usize`.
+    ///
+    /// Uses `usize::try_from` rather than `as`, so a large `u64` offset cannot
+    /// wrap to a small in-bounds index on 32-bit targets.
+    fn epoch_position(&self, epoch_id: u64) -> Option<usize> {
+        self.get_epoch_data_index(epoch_id)
+            .and_then(|offset| usize::try_from(offset).ok())
+    }
+
+    /// Returns the stored record for `epoch_id`, if any.
+    ///
+    /// The record at the computed position must carry `epoch_id`; otherwise
+    /// `None` is returned instead of a different epoch's record.
+    pub fn get_epoch(&self, epoch_id: u64) -> Option<&EpochRecord> {
+        self.epoch_position(epoch_id)
+            .and_then(|i| self.epoch_data.get(i))
+            .filter(|epoch| epoch.id == epoch_id)
+    }
+
+    /// Mutable counterpart of [`Self::get_epoch`], with the same id check.
+    pub fn get_mut_epoch(&mut self, epoch_id: u64) -> Option<&mut EpochRecord> {
+        self.epoch_position(epoch_id)
+            .and_then(|i| self.epoch_data.get_mut(i))
+            .filter(|epoch| epoch.id == epoch_id)
+    }
+
+    /// Appends `epoch` as the newest stored epoch.
+    pub fn insert_epoch(&mut self, epoch: EpochRecord) {
+        self.epoch_data.push_back(epoch)
+    }
+
+    // This function does not fail if an update can't be made. If the epoch
+    // is not in the store, then it can no longer be accessed by future
+    // get_epoch calls and is no longer relevant.
+    pub fn update_epoch(&mut self, epoch: EpochRecord) {
+        if let Some(existing_epoch) = self.get_mut_epoch(epoch.id) {
+            *existing_epoch = epoch
+        }
+    }
+
+    /// Drops the oldest epochs until at most `max_epoch_retention` remain.
+    pub fn trim_epochs(&mut self, max_epoch_retention: usize) {
+        while self.epoch_data.len() > max_epoch_retention {
+            self.epoch_data.pop_front();
+        }
+    }
+}
+
+/// In memory group state storage backed by a HashMap.
+///
+/// All clones of an instance of this type share the same underlying HashMap.
+/// Without the `std` feature the map is a `BTreeMap` behind a `spin::Mutex`.
+#[derive(Clone)]
+pub struct InMemoryGroupStateStorage {
+    #[cfg(feature = "std")]
+    pub(crate) inner: Arc<Mutex<HashMap<Vec<u8>, InMemoryGroupData>>>,
+    #[cfg(not(feature = "std"))]
+    pub(crate) inner: Arc<Mutex<BTreeMap<Vec<u8>, InMemoryGroupData>>>,
+    pub(crate) max_epoch_retention: usize,
+}
+
+impl Debug for InMemoryGroupStateStorage {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("InMemoryGroupStateStorage")
+            .field(
+                "inner",
+                &mls_rs_core::debug::pretty_with(|f| {
+                    f.debug_map()
+                        .entries(
+                            self.lock()
+                                .iter()
+                                .map(|(k, v)| (mls_rs_core::debug::pretty_bytes(k), v)),
+                        )
+                        .finish()
+                }),
+            )
+            .field("max_epoch_retention", &self.max_epoch_retention)
+            .finish()
+    }
+}
+
+impl InMemoryGroupStateStorage {
+    /// Create an empty group state storage that retains
+    /// `DEFAULT_EPOCH_RETENTION_LIMIT` epochs per group.
+    pub fn new() -> Self {
+        Self {
+            inner: Default::default(),
+            max_epoch_retention: DEFAULT_EPOCH_RETENTION_LIMIT,
+        }
+    }
+
+    /// Sets how many epochs are retained per group. The limit is applied
+    /// each time group state is written.
+    ///
+    /// The returned storage shares the same underlying map as `self`.
+    ///
+    /// # Errors
+    ///
+    /// Returns `MlsError::NonZeroRetentionRequired` if `max_epoch_retention`
+    /// is 0.
+    pub fn with_max_epoch_retention(self, max_epoch_retention: usize) -> Result<Self, MlsError> {
+        if max_epoch_retention == 0 {
+            return Err(MlsError::NonZeroRetentionRequired);
+        }
+
+        Ok(Self {
+            max_epoch_retention,
+            ..self
+        })
+    }
+
+    /// Get the set of unique group ids that have data stored.
+    pub fn stored_groups(&self) -> Vec<Vec<u8>> {
+        self.lock().keys().cloned().collect()
+    }
+
+    /// Delete all data corresponding to `group_id`.
+    pub fn delete_group(&self, group_id: &[u8]) {
+        self.lock().remove(group_id);
+    }
+
+    /// Locks the shared map.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the mutex is poisoned (another thread panicked while
+    /// holding it).
+    #[cfg(feature = "std")]
+    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<Vec<u8>, InMemoryGroupData>> {
+        self.inner.lock().unwrap()
+    }
+
+    /// Locks the shared map (spin lock; no poisoning).
+    #[cfg(not(feature = "std"))]
+    fn lock(&self) -> spin::mutex::MutexGuard<'_, BTreeMap<Vec<u8>, InMemoryGroupData>> {
+        self.inner.lock()
+    }
+}
+
+impl Default for InMemoryGroupStateStorage {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+// In-memory implementation: every call takes the shared lock for its duration
+// and never fails (`Error = Infallible`).
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+impl GroupStateStorage for InMemoryGroupStateStorage {
+    type Error = Infallible;
+
+    async fn max_epoch_id(&self, group_id: &[u8]) -> Result<Option<u64>, Self::Error> {
+        let groups = self.lock();
+        let newest = groups
+            .get(group_id)
+            .and_then(|group_data| group_data.epoch_data.back())
+            .map(|epoch| epoch.id);
+
+        Ok(newest)
+    }
+
+    async fn state(&self, group_id: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
+        let groups = self.lock();
+        let state = groups.get(group_id).map(|data| data.state_data.clone());
+
+        Ok(state)
+    }
+
+    async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result<Option<Vec<u8>>, Self::Error> {
+        let groups = self.lock();
+        let epoch_bytes = groups
+            .get(group_id)
+            .and_then(|data| data.get_epoch(epoch_id))
+            .map(|epoch| epoch.data.clone());
+
+        Ok(epoch_bytes)
+    }
+
+    // Replaces (or creates) the group's state, appends `epoch_inserts`, applies
+    // `epoch_updates`, then trims to `max_epoch_retention` epochs.
+    async fn write(
+        &mut self,
+        state: GroupState,
+        epoch_inserts: Vec<EpochRecord>,
+        epoch_updates: Vec<EpochRecord>,
+    ) -> Result<(), Self::Error> {
+        let retention = self.max_epoch_retention;
+        let mut groups = self.lock();
+
+        let group_data = match groups.entry(state.id) {
+            Entry::Vacant(vacant) => vacant.insert(InMemoryGroupData::new(state.data)),
+            Entry::Occupied(occupied) => {
+                let existing = occupied.into_mut();
+                existing.state_data = state.data;
+                existing
+            }
+        };
+
+        for epoch in epoch_inserts {
+            group_data.insert_epoch(epoch);
+        }
+
+        for epoch in epoch_updates {
+            group_data.update_epoch(epoch);
+        }
+
+        group_data.trim_epochs(retention);
+
+        Ok(())
+    }
+}
+
+#[cfg(all(test, feature = "prior_epoch"))]
+mod tests {
+ use alloc::{format, vec, vec::Vec};
+ use assert_matches::assert_matches;
+
+ use super::{InMemoryGroupData, InMemoryGroupStateStorage};
+ use crate::{client::MlsError, group::test_utils::TEST_GROUP};
+
+ use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage};
+
+ impl InMemoryGroupStateStorage {
+ // Clone of the data stored for `TEST_GROUP`; panics if nothing is stored.
+ fn test_data(&self) -> InMemoryGroupData {
+ self.lock().get(TEST_GROUP).unwrap().clone()
+ }
+ }
+
+ // Storage with the given epoch retention limit (errors on `0`).
+ fn test_storage(retention_limit: usize) -> Result<InMemoryGroupStateStorage, MlsError> {
+ InMemoryGroupStateStorage::new().with_max_epoch_retention(retention_limit)
+ }
+
+ // Epoch record `epoch_id` whose payload is the text "epoch {epoch_id}".
+ fn test_epoch(epoch_id: u64) -> EpochRecord {
+ EpochRecord::new(epoch_id, format!("epoch {epoch_id}").as_bytes().to_vec())
+ }
+
+ // Snapshot of `TEST_GROUP` whose payload is the text "snapshot {epoch_id}".
+ fn test_snapshot(epoch_id: u64) -> GroupState {
+ GroupState {
+ id: TEST_GROUP.into(),
+ data: format!("snapshot {epoch_id}").as_bytes().to_vec(),
+ }
+ }
+
+ // A retention limit of zero is rejected when configuring the storage.
+ #[test]
+ fn test_zero_max_retention() {
+ assert_matches!(test_storage(0), Err(MlsError::NonZeroRetentionRequired))
+ }
+
+ // Raising the limit on storage that already holds data lets later writes keep
+ // more epochs than the original limit allowed.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn existing_storage_can_have_larger_epoch_count() {
+ let mut storage = test_storage(2).unwrap();
+
+ let epoch_inserts = vec![test_epoch(0), test_epoch(1)];
+
+ storage
+ .write(test_snapshot(0), epoch_inserts, Vec::new())
+ .await
+ .unwrap();
+
+ assert_eq!(storage.test_data().epoch_data.len(), 2);
+
+ storage.max_epoch_retention = 4;
+
+ // NOTE(review): epoch 2 is skipped here; the test only checks the record count.
+ let epoch_inserts = vec![test_epoch(3), test_epoch(4)];
+
+ storage
+ .write(test_snapshot(1), epoch_inserts, Vec::new())
+ .await
+ .unwrap();
+
+ assert_eq!(storage.test_data().epoch_data.len(), 4);
+ }
+
+ // Lowering the limit on storage that already holds data trims the stored epochs
+ // on the next write.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn existing_storage_can_have_smaller_epoch_count() {
+ let mut storage = test_storage(4).unwrap();
+
+ let epoch_inserts = vec![test_epoch(0), test_epoch(1), test_epoch(3), test_epoch(4)];
+
+ storage
+ .write(test_snapshot(1), epoch_inserts, Vec::new())
+ .await
+ .unwrap();
+
+ assert_eq!(storage.test_data().epoch_data.len(), 4);
+
+ storage.max_epoch_retention = 2;
+
+ let epoch_inserts = vec![test_epoch(5)];
+
+ storage
+ .write(test_snapshot(1), epoch_inserts, Vec::new())
+ .await
+ .unwrap();
+
+ assert_eq!(storage.test_data().epoch_data.len(), 2);
+ }
+
+ // Inserting more epochs than the limit in one write (no updates).
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn epoch_insert_over_limit() {
+ test_epoch_insert_over_limit(false).await
+ }
+
+ // Same as above, with an update to epoch 0 in the same write.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn epoch_insert_over_limit_with_update() {
+ test_epoch_insert_over_limit(true).await
+ }
+
+ // With a limit of 1, writing epochs 0 and 1 must keep only the last inserted
+ // record (epoch 1) alongside the new snapshot.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_epoch_insert_over_limit(with_update: bool) {
+ let mut storage = test_storage(1).unwrap();
+
+ let mut epoch_inserts = vec![test_epoch(0), test_epoch(1)];
+ let updates = with_update
+ .then_some(vec![test_epoch(0)])
+ .unwrap_or_default();
+ let snapshot = test_snapshot(1);
+
+ storage
+ .write(snapshot.clone(), epoch_inserts.clone(), updates)
+ .await
+ .unwrap();
+
+ let stored = storage.test_data();
+
+ assert_eq!(stored.state_data, snapshot.data);
+ assert_eq!(stored.epoch_data.len(), 1);
+
+ let expected = epoch_inserts.pop().unwrap();
+ assert_eq!(stored.epoch_data[0], expected);
+ }
+}
diff --git a/src/storage_provider/in_memory/key_package_storage.rs b/src/storage_provider/in_memory/key_package_storage.rs
new file mode 100644
index 0000000..427a8a4
--- /dev/null
+++ b/src/storage_provider/in_memory/key_package_storage.rs
@@ -0,0 +1,120 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+#[cfg(target_has_atomic = "ptr")]
+use alloc::sync::Arc;
+
+#[cfg(not(target_has_atomic = "ptr"))]
+use portable_atomic_util::Arc;
+
+use core::{
+ convert::Infallible,
+ fmt::{self, Debug},
+};
+
+#[cfg(feature = "std")]
+use std::collections::HashMap;
+
+#[cfg(not(feature = "std"))]
+use alloc::collections::BTreeMap;
+use alloc::vec::Vec;
+use mls_rs_core::key_package::{KeyPackageData, KeyPackageStorage};
+
+#[cfg(feature = "std")]
+use std::sync::Mutex;
+
+#[cfg(mls_build_async)]
+use alloc::boxed::Box;
+#[cfg(not(feature = "std"))]
+use spin::Mutex;
+
+#[derive(Clone, Default)]
+/// In memory key package storage backed by a `HashMap` (a `BTreeMap` when the
+/// `std` feature is disabled).
+///
+/// All clones of an instance of this type share the same underlying map, guarded
+/// by a mutex (`std::sync::Mutex` with `std`, `spin::Mutex` without).
+pub struct InMemoryKeyPackageStorage {
+ // Key package data keyed by the `id` passed to `insert`.
+ #[cfg(feature = "std")]
+ inner: Arc<Mutex<HashMap<Vec<u8>, KeyPackageData>>>,
+ #[cfg(not(feature = "std"))]
+ inner: Arc<Mutex<BTreeMap<Vec<u8>, KeyPackageData>>>,
+}
+
+impl Debug for InMemoryKeyPackageStorage {
+    /// Formats the stored entries as a map from pretty-printed ids to key package data.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let entries = mls_rs_core::debug::pretty_with(|map_f| {
+            let mut map = map_f.debug_map();
+
+            for (id, data) in self.lock().iter() {
+                map.entry(&mls_rs_core::debug::pretty_bytes(id), data);
+            }
+
+            map.finish()
+        });
+
+        f.debug_struct("InMemoryKeyPackageStorage")
+            .field("inner", &entries)
+            .finish()
+    }
+}
+
+impl InMemoryKeyPackageStorage {
+    /// Create an empty key package storage.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Insert key package data under `id`, replacing any data already stored there.
+    pub fn insert(&self, id: Vec<u8>, pkg: KeyPackageData) {
+        let mut packages = self.lock();
+        packages.insert(id, pkg);
+    }
+
+    /// Get a copy of the key package data stored under `id`, if any.
+    pub fn get(&self, id: &[u8]) -> Option<KeyPackageData> {
+        let packages = self.lock();
+        packages.get(id).cloned()
+    }
+
+    /// Delete the key package data stored under `id`; a missing id is a no-op.
+    pub fn delete(&self, id: &[u8]) {
+        let mut packages = self.lock();
+        packages.remove(id);
+    }
+
+    /// Get copies of all key packages currently stored, as `(id, data)` pairs.
+    pub fn key_packages(&self) -> Vec<(Vec<u8>, KeyPackageData)> {
+        let packages = self.lock();
+
+        packages
+            .iter()
+            .map(|(id, data)| (id.to_vec(), data.clone()))
+            .collect()
+    }
+
+    /// Lock the shared map (`std` build); panics if the mutex is poisoned.
+    #[cfg(feature = "std")]
+    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<Vec<u8>, KeyPackageData>> {
+        self.inner.lock().unwrap()
+    }
+
+    /// Lock the shared map (`no_std` build, spin mutex).
+    #[cfg(not(feature = "std"))]
+    fn lock(&self) -> spin::mutex::MutexGuard<'_, BTreeMap<Vec<u8>, KeyPackageData>> {
+        self.inner.lock()
+    }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+impl KeyPackageStorage for InMemoryKeyPackageStorage {
+    // In-memory operations cannot fail.
+    type Error = Infallible;
+
+    // Each method forwards to the inherent method of the same name. The fully
+    // qualified paths make it explicit that the inherent method (which takes
+    // `&self`), not this trait method, is invoked.
+
+    async fn delete(&mut self, id: &[u8]) -> Result<(), Self::Error> {
+        InMemoryKeyPackageStorage::delete(self, id);
+        Ok(())
+    }
+
+    async fn insert(&mut self, id: Vec<u8>, pkg: KeyPackageData) -> Result<(), Self::Error> {
+        InMemoryKeyPackageStorage::insert(self, id, pkg);
+        Ok(())
+    }
+
+    async fn get(&self, id: &[u8]) -> Result<Option<KeyPackageData>, Self::Error> {
+        let found = InMemoryKeyPackageStorage::get(self, id);
+        Ok(found)
+    }
+}
diff --git a/src/storage_provider/in_memory/psk_storage.rs b/src/storage_provider/in_memory/psk_storage.rs
new file mode 100644
index 0000000..e1b0b75
--- /dev/null
+++ b/src/storage_provider/in_memory/psk_storage.rs
@@ -0,0 +1,83 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+#[cfg(target_has_atomic = "ptr")]
+use alloc::sync::Arc;
+
+#[cfg(not(target_has_atomic = "ptr"))]
+use portable_atomic_util::Arc;
+
+use core::convert::Infallible;
+
+#[cfg(feature = "std")]
+use std::collections::HashMap;
+
+#[cfg(not(feature = "std"))]
+use alloc::collections::BTreeMap;
+
+use mls_rs_core::psk::{ExternalPskId, PreSharedKey, PreSharedKeyStorage};
+
+#[cfg(mls_build_async)]
+use alloc::boxed::Box;
+#[cfg(feature = "std")]
+use std::sync::Mutex;
+
+#[cfg(not(feature = "std"))]
+use spin::Mutex;
+
+#[derive(Clone, Debug, Default)]
+/// In memory pre-shared key storage backed by a `HashMap` (a `BTreeMap` when the
+/// `std` feature is disabled).
+///
+/// All clones of an instance of this type share the same underlying map, guarded
+/// by a mutex (`std::sync::Mutex` with `std`, `spin::Mutex` without).
+pub struct InMemoryPreSharedKeyStorage {
+ // Pre-shared keys keyed by their external PSK id.
+ #[cfg(feature = "std")]
+ inner: Arc<Mutex<HashMap<ExternalPskId, PreSharedKey>>>,
+ #[cfg(not(feature = "std"))]
+ inner: Arc<Mutex<BTreeMap<ExternalPskId, PreSharedKey>>>,
+}
+
+impl InMemoryPreSharedKeyStorage {
+    /// Insert a pre-shared key into storage.
+    ///
+    /// A key already stored under the same `id` is replaced. The change is visible
+    /// through every clone of this storage, since clones share the map.
+    pub fn insert(&mut self, id: ExternalPskId, psk: PreSharedKey) {
+        self.lock().insert(id, psk);
+    }
+
+    /// Get a copy of the pre-shared key stored under `id`, if any.
+    pub fn get(&self, id: &ExternalPskId) -> Option<PreSharedKey> {
+        self.lock().get(id).cloned()
+    }
+
+    /// Delete a pre-shared key from storage; deleting a missing `id` is a no-op.
+    pub fn delete(&mut self, id: &ExternalPskId) {
+        self.lock().remove(id);
+    }
+
+    /// Lock the shared map (`std` build).
+    ///
+    /// Panics if the mutex is poisoned, i.e. a previous holder panicked.
+    #[cfg(feature = "std")]
+    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<ExternalPskId, PreSharedKey>> {
+        self.inner.lock().unwrap()
+    }
+
+    /// Lock the shared map (`no_std` build, spin mutex).
+    #[cfg(not(feature = "std"))]
+    fn lock(&self) -> spin::mutex::MutexGuard<'_, BTreeMap<ExternalPskId, PreSharedKey>> {
+        self.inner.lock()
+    }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+impl PreSharedKeyStorage for InMemoryPreSharedKeyStorage {
+    // Reading from the in-memory map cannot fail.
+    type Error = Infallible;
+
+    async fn get(&self, id: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error> {
+        // Fully qualified so the inherent `get`, not this trait method, is called.
+        let found = InMemoryPreSharedKeyStorage::get(self, id);
+        Ok(found)
+    }
+}
diff --git a/src/storage_provider/key_package.rs b/src/storage_provider/key_package.rs
new file mode 100644
index 0000000..1e209fb
--- /dev/null
+++ b/src/storage_provider/key_package.rs
@@ -0,0 +1,5 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+pub use mls_rs_core::key_package::KeyPackageData;
diff --git a/src/storage_provider/sqlite.rs b/src/storage_provider/sqlite.rs
new file mode 100644
index 0000000..f4e4f1f
--- /dev/null
+++ b/src/storage_provider/sqlite.rs
@@ -0,0 +1,5 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+pub use mls_rs_provider_sqlite::*;
diff --git a/src/test_utils/benchmarks.rs b/src/test_utils/benchmarks.rs
new file mode 100644
index 0000000..93d8964
--- /dev/null
+++ b/src/test_utils/benchmarks.rs
@@ -0,0 +1,140 @@
+use mls_rs_codec::MlsEncode;
+use mls_rs_core::protocol_version::ProtocolVersion;
+
+use crate::{
+ cipher_suite::CipherSuite,
+ client_builder::{BaseConfig, MlsConfig, WithCryptoProvider, WithIdentityProvider},
+ group::{framing::MlsMessage, Group},
+ identity::basic::BasicIdentityProvider,
+ test_utils::{generate_basic_client, get_test_groups},
+};
+
+pub use mls_rs_crypto_openssl::OpensslCryptoProvider as MlsCryptoProvider;
+
+/// Client configuration pairing `BasicIdentityProvider` with the OpenSSL-backed
+/// `MlsCryptoProvider`.
+pub type TestClientConfig =
+ WithIdentityProvider<BasicIdentityProvider, WithCryptoProvider<MlsCryptoProvider, BaseConfig>>;
+
+// Load an MLS-encoded test case from `test_data/<name>.mls` and decode it.
+//
+// * On `std` builds outside wasm32, the file is generated from `$generate` and
+//   written to disk if it does not exist yet, then read back and decoded.
+// * On wasm32 or `no_std` builds the file is embedded with `include_bytes!`, so it
+//   must already exist at compile time.
+//
+// Any I/O, encoding or decoding failure panics. `$to_json` is accepted for
+// call-site compatibility but is not used by this macro.
+macro_rules! load_test_case_mls {
+ ($name:ident, $generate:expr) => {
+ load_test_case_mls!($name, $generate, to_vec_pretty)
+ };
+ ($name:ident, $generate:expr, $to_json:ident) => {{
+ #[cfg(any(target_arch = "wasm32", not(feature = "std")))]
+ {
+ // Do not remove `async`! (The goal of this line is to remove warnings
+ // about `$generate` not being used. Actually calling it will make tests fail.)
+ let _ = async { $generate };
+
+ mls_rs_codec::MlsDecode::mls_decode(&mut &include_bytes!(concat!(
+ env!("CARGO_MANIFEST_DIR"),
+ "/test_data/",
+ stringify!($name),
+ ".mls"
+ )))
+ .unwrap()
+ }
+
+ #[cfg(all(not(target_arch = "wasm32"), feature = "std"))]
+ {
+ let path = concat!(
+ env!("CARGO_MANIFEST_DIR"),
+ "/test_data/",
+ stringify!($name),
+ ".mls"
+ );
+
+ // Generate and cache the test case only on first use.
+ if !std::path::Path::new(path).exists() {
+ std::fs::write(path, $generate.mls_encode_to_vec().unwrap()).unwrap();
+ }
+
+ mls_rs_codec::MlsDecode::mls_decode(&mut std::fs::read(path).unwrap().as_slice())
+ .unwrap()
+ }
+ }};
+}
+
+/// Produce one group-info message (allowing external commits) for each of the
+/// group sizes 16, 64 and 128, taken from the last member of a fresh test group.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn generate_test_cases(cs: CipherSuite) -> Vec<MlsMessage> {
+    let mut group_infos = Vec::new();
+
+    for member_count in [16, 64, 128] {
+        let mut groups = get_test_groups(
+            ProtocolVersion::MLS_10,
+            cs,
+            member_count,
+            None,
+            false,
+            &MlsCryptoProvider::new(),
+        )
+        .await;
+
+        let last_member = groups.pop().unwrap();
+
+        let info = last_member
+            .group_info_message_allowing_ext_commit(true)
+            .await
+            .unwrap();
+
+        group_infos.push(info);
+    }
+
+    group_infos
+}
+
+/// Two members' views of the same group, as built by `join_group`: `sender` joined
+/// first and `receiver` joined second.
+#[derive(Clone)]
+pub struct GroupStates<C: MlsConfig> {
+ pub sender: Group<C>,
+ pub receiver: Group<C>,
+}
+
+// NOTE(review): this variant is never compiled. `test_utils/mod.rs` only declares
+// the `benchmarks` module when `mls_build_async` is off. As written it would also
+// fail to type-check: `block_on` is not imported in this file, and `join_group` is
+// async and returns a single `GroupStates`, not a `Vec`. Mirror the sync variant
+// below if async benchmarks are ever enabled.
+#[cfg(mls_build_async)]
+pub fn load_group_states(cs: CipherSuite) -> Vec<GroupStates<impl MlsConfig>> {
+ let group_info = load_test_case_mls!(group_state, block_on(generate_test_cases(cs)), to_vec);
+ join_group(cs, group_info)
+}
+
+/// Load (or generate and cache) the benchmark group infos and join each of them
+/// with two fresh clients, returning one `GroupStates` per cached group info.
+#[cfg(not(mls_build_async))]
+pub fn load_group_states(cs: CipherSuite) -> Vec<GroupStates<impl MlsConfig>> {
+    let cached_infos: Vec<MlsMessage> =
+        load_test_case_mls!(group_state, generate_test_cases(cs), to_vec);
+
+    let mut states = Vec::new();
+    for group_info in cached_infos {
+        states.push(join_group(cs, group_info));
+    }
+
+    states
+}
+
+/// Join the group described by `group_info` with two fresh clients.
+///
+/// The first client (identity `99999999999`) joins by external commit and becomes
+/// `sender`. The second client (identity `99999999998`) then joins by external
+/// commit using a group info exported by `sender`, and `sender` processes that
+/// commit. Panics on any failure.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub async fn join_group(cs: CipherSuite, group_info: MlsMessage) -> GroupStates<impl MlsConfig> {
+    let first_joiner = generate_basic_client(
+        cs,
+        ProtocolVersion::MLS_10,
+        99999999999,
+        None,
+        false,
+        &MlsCryptoProvider::new(),
+        None,
+    );
+
+    let (mut sender, _) = first_joiner.commit_external(group_info).await.unwrap();
+
+    let second_joiner = generate_basic_client(
+        cs,
+        ProtocolVersion::MLS_10,
+        99999999998,
+        None,
+        false,
+        &MlsCryptoProvider::new(),
+        None,
+    );
+
+    let sender_info = sender
+        .group_info_message_allowing_ext_commit(true)
+        .await
+        .unwrap();
+
+    let (receiver, external_commit) = second_joiner.commit_external(sender_info).await.unwrap();
+    sender.process_incoming_message(external_commit).await.unwrap();
+
+    GroupStates { sender, receiver }
+}
diff --git a/src/test_utils/fuzz_tests.rs b/src/test_utils/fuzz_tests.rs
new file mode 100644
index 0000000..9ec143e
--- /dev/null
+++ b/src/test_utils/fuzz_tests.rs
@@ -0,0 +1,109 @@
+use std::sync::Mutex;
+
+use mls_rs_core::{
+ crypto::{CipherSuiteProvider, CryptoProvider, SignatureSecretKey},
+ identity::BasicCredential,
+};
+
+use once_cell::sync::Lazy;
+
+use crate::{
+ cipher_suite::CipherSuite,
+ client::MlsError,
+ client_builder::{BaseConfig, WithCryptoProvider, WithIdentityProvider},
+ group::{
+ framing::{Content, MlsMessage, Sender, WireFormat},
+ message_processor::MessageProcessor,
+ message_signature::AuthenticatedContent,
+ Commit, Group,
+ },
+ identity::{basic::BasicIdentityProvider, SigningIdentity},
+ Client, ExtensionList,
+};
+
+#[cfg(awslc)]
+pub use mls_rs_crypto_awslc::AwsLcCryptoProvider as MlsCryptoProvider;
+#[cfg(not(any(awslc, rustcrypto)))]
+pub use mls_rs_crypto_openssl::OpensslCryptoProvider as MlsCryptoProvider;
+#[cfg(rustcrypto)]
+pub use mls_rs_crypto_rustcrypto::RustCryptoProvider as MlsCryptoProvider;
+
+/// Client configuration for the fuzz helpers: `BasicIdentityProvider` over the
+/// `MlsCryptoProvider` selected at compile time.
+pub type TestClientConfig =
+ WithIdentityProvider<BasicIdentityProvider, WithCryptoProvider<MlsCryptoProvider, BaseConfig>>;
+
+/// Shared group, built by [`create_group`] on first access and guarded by a mutex;
+/// [`create_fuzz_commit_message`] locks it for every message it builds.
+pub static GROUP: Lazy<Mutex<Group<TestClientConfig>>> = Lazy::new(|| Mutex::new(create_group()));
+
+/// Build the fuzzing group: `alice` creates a `CURVE25519_AES128` group and adds
+/// `bob` in a single commit, which alice then applies.
+///
+/// Only alice's view of the group is returned. Panics on any failure.
+pub fn create_group() -> Group<TestClientConfig> {
+    let suite = CipherSuite::CURVE25519_AES128;
+    let alice_client = make_client(suite, "alice");
+    let bob_client = make_client(suite, "bob");
+
+    let mut group = alice_client.create_group(ExtensionList::new()).unwrap();
+    let bob_key_package = bob_client.generate_key_package_message().unwrap();
+
+    // The commit output (including bob's welcome) is discarded: only alice's
+    // state is used afterwards.
+    group
+        .commit_builder()
+        .add_member(bob_key_package)
+        .unwrap()
+        .build()
+        .unwrap();
+
+    group.apply_pending_commit().unwrap();
+    group
+}
+
+/// Build a wire-formatted empty commit (no proposals, no path) from the shared
+/// [`GROUP`], but carrying the fuzzed `group_id`, `epoch` and `authenticated_data`.
+///
+/// The content is sent as `Sender::Member(0)` and signed with the group's signer.
+/// It uses the private-message wire format when the `private_message` feature is
+/// enabled, and the public-message format otherwise. Panics if the `GROUP` mutex
+/// is poisoned.
+///
+/// # Errors
+///
+/// Propagates signing and wire-formatting errors.
+pub fn create_fuzz_commit_message(
+    group_id: Vec<u8>,
+    epoch: u64,
+    authenticated_data: Vec<u8>,
+) -> Result<MlsMessage, MlsError> {
+    let mut group = GROUP.lock().unwrap();
+
+    // Start from the real group context and overwrite the fuzzed fields.
+    let context = {
+        let mut forged = group.context().clone();
+        forged.group_id = group_id;
+        forged.epoch = epoch;
+        forged
+    };
+
+    #[cfg(feature = "private_message")]
+    let wire_format = WireFormat::PrivateMessage;
+    #[cfg(not(feature = "private_message"))]
+    let wire_format = WireFormat::PublicMessage;
+
+    let empty_commit = Commit {
+        proposals: Vec::new(),
+        path: None,
+    };
+
+    let auth_content = AuthenticatedContent::new_signed(
+        group.cipher_suite_provider(),
+        &context,
+        Sender::Member(0),
+        Content::Commit(alloc::boxed::Box::new(empty_commit)),
+        &group.signer,
+        wire_format,
+        authenticated_data,
+    )?;
+
+    group.format_for_wire(auth_content)
+}
+
+/// Build a fuzzing client named `name` with a fresh signing key for `cipher_suite`.
+fn make_client(cipher_suite: CipherSuite, name: &str) -> Client<TestClientConfig> {
+    let (signing_key, signing_identity) = make_identity(cipher_suite, name);
+
+    // TODO : consider fuzzing on encrypted controls (doesn't seem very useful)
+    let builder = Client::builder()
+        .identity_provider(BasicIdentityProvider)
+        .crypto_provider(MlsCryptoProvider::default());
+
+    builder
+        .signing_identity(signing_identity, signing_key, cipher_suite)
+        .build()
+}
+
+/// Generate a signature key pair for `cipher_suite` plus a basic credential holding
+/// the UTF-8 bytes of `name`. Panics if the suite is unsupported.
+fn make_identity(cipher_suite: CipherSuite, name: &str) -> (SignatureSecretKey, SigningIdentity) {
+    let provider = MlsCryptoProvider::new()
+        .cipher_suite_provider(cipher_suite)
+        .unwrap();
+
+    let (secret_key, public_key) = provider.signature_key_generate().unwrap();
+    let credential = BasicCredential::new(name.as_bytes().to_vec()).into_credential();
+
+    (secret_key, SigningIdentity::new(credential, public_key))
+}
diff --git a/src/test_utils/mod.rs b/src/test_utils/mod.rs
new file mode 100644
index 0000000..d7c238b
--- /dev/null
+++ b/src/test_utils/mod.rs
@@ -0,0 +1,184 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+#[cfg(all(feature = "benchmark_util", not(mls_build_async)))]
+pub mod benchmarks;
+
+#[cfg(all(feature = "fuzz_util", not(mls_build_async)))]
+pub mod fuzz_tests;
+
+use mls_rs_core::{
+ crypto::{CipherSuite, CipherSuiteProvider, CryptoProvider},
+ identity::{BasicCredential, Credential, SigningIdentity},
+ protocol_version::ProtocolVersion,
+ psk::ExternalPskId,
+};
+
+use crate::{
+ client_builder::{ClientBuilder, MlsConfig},
+ identity::basic::BasicIdentityProvider,
+ mls_rules::{CommitOptions, DefaultMlsRules},
+ tree_kem::Lifetime,
+ Client, Group, MlsMessage,
+};
+
+#[cfg(feature = "private_message")]
+use crate::group::{mls_rules::EncryptionOptions, padding::PaddingMode};
+
+use alloc::{vec, vec::Vec};
+
+/// Wrap the raw `identity` bytes in a basic credential.
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub fn get_test_basic_credential(identity: Vec<u8>) -> Credential {
+    let basic = BasicCredential::new(identity);
+    basic.into_credential()
+}
+
+/// Identifier under which test clients register the external test PSK.
+pub const TEST_EXT_PSK_ID: &[u8] = b"external psk";
+
+/// Secret bytes of the external test PSK registered under [`TEST_EXT_PSK_ID`].
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub fn make_test_ext_psk() -> Vec<u8> {
+    Vec::from(&b"secret psk key"[..])
+}
+
+/// Whether the numeric cipher suite id `cs` is one of the Curve25519 or Curve448
+/// suites (AES or ChaCha variants).
+pub fn is_edwards(cs: u16) -> bool {
+    let suite: CipherSuite = cs.into();
+
+    let edwards_suites = [
+        CipherSuite::CURVE25519_AES128,
+        CipherSuite::CURVE25519_CHACHA,
+        CipherSuite::CURVE448_AES256,
+        CipherSuite::CURVE448_CHACHA,
+    ];
+
+    edwards_suites.iter().any(|candidate| *candidate == suite)
+}
+
+/// Build a test client for `cipher_suite` whose basic credential is the decimal
+/// string of `id`.
+///
+/// The client uses `BasicIdentityProvider`, `DefaultMlsRules` with `commit_options`
+/// (or the defaults), and the external test PSK registered under `TEST_EXT_PSK_ID`.
+/// With the `private_message` feature, `encrypt_controls` adds
+/// `EncryptionOptions::new(true, PaddingMode::None)`. When `lifetime` is given,
+/// generated key packages use its `not_before`/`not_after` window.
+///
+/// Panics if `crypto` does not support `cipher_suite` or key generation fails.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn generate_basic_client<C: CryptoProvider + Clone>(
+    cipher_suite: CipherSuite,
+    protocol_version: ProtocolVersion,
+    id: usize,
+    commit_options: Option<CommitOptions>,
+    #[cfg(feature = "private_message")] encrypt_controls: bool,
+    #[cfg(not(feature = "private_message"))] _encrypt_controls: bool,
+    crypto: &C,
+    lifetime: Option<Lifetime>,
+) -> Client<impl MlsConfig> {
+    let suite_provider = crypto.cipher_suite_provider(cipher_suite).unwrap();
+    let (secret_key, public_key) = suite_provider.signature_key_generate().await.unwrap();
+
+    let identity_bytes = alloc::format!("{id}").into_bytes();
+    let signing_identity =
+        SigningIdentity::new(get_test_basic_credential(identity_bytes), public_key);
+
+    let rules = DefaultMlsRules::default().with_commit_options(commit_options.unwrap_or_default());
+
+    #[cfg(feature = "private_message")]
+    let rules = if encrypt_controls {
+        rules.with_encryption_options(EncryptionOptions::new(true, PaddingMode::None))
+    } else {
+        rules
+    };
+
+    let builder = ClientBuilder::new()
+        .crypto_provider(crypto.clone())
+        .identity_provider(BasicIdentityProvider::new())
+        .mls_rules(rules)
+        .psk(
+            ExternalPskId::new(TEST_EXT_PSK_ID.to_vec()),
+            make_test_ext_psk().into(),
+        )
+        .used_protocol_version(protocol_version)
+        .signing_identity(signing_identity, secret_key, cipher_suite);
+
+    let builder = match lifetime {
+        Some(window) => builder
+            .key_package_lifetime(window.not_after - window.not_before)
+            .key_package_not_before(window.not_before),
+        None => builder,
+    };
+
+    builder.build()
+}
+
+/// Build one MLS group with `num_participants` members and return every member's
+/// `Group`.
+///
+/// The creator (client id `0`) adds clients `1..num_participants` in a single
+/// commit and applies it. Each added client then joins from the first welcome
+/// message, using the creator's exported ratchet tree. The result holds the
+/// creator's group first, followed by the joiners in id order. Panics on any
+/// failure.
+///
+/// NOTE(review): every joiner uses `welcome[0]`. This assumes a single welcome
+/// covers all new members; `commit_options` that produce one welcome per member
+/// would break that.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn get_test_groups<C: CryptoProvider + Clone>(
+    version: ProtocolVersion,
+    cipher_suite: CipherSuite,
+    num_participants: usize,
+    commit_options: Option<CommitOptions>,
+    encrypt_controls: bool,
+    crypto: &C,
+) -> Vec<Group<impl MlsConfig>> {
+    // Create the group with Alice (client id 0) as the group initiator.
+    let creator = generate_basic_client(
+        cipher_suite,
+        version,
+        0,
+        commit_options,
+        encrypt_controls,
+        crypto,
+        None,
+    )
+    .await;
+
+    let mut creator_group = creator.create_group(Default::default()).await.unwrap();
+
+    let mut receiver_clients = Vec::new();
+    let mut commit_builder = creator_group.commit_builder();
+
+    for i in 1..num_participants {
+        let client = generate_basic_client(
+            cipher_suite,
+            version,
+            i,
+            commit_options,
+            encrypt_controls,
+            crypto,
+            None,
+        )
+        .await;
+        let kp = client.generate_key_package_message().await.unwrap();
+
+        receiver_clients.push(client);
+        // `kp` is not used again, so it is moved into the add proposal rather than cloned.
+        commit_builder = commit_builder.add_member(kp).unwrap();
+    }
+
+    let welcome = commit_builder.build().await.unwrap().welcome_messages;
+
+    creator_group.apply_pending_commit().await.unwrap();
+
+    let tree_data = creator_group.export_tree().into_owned();
+
+    let mut groups = vec![creator_group];
+
+    for client in &receiver_clients {
+        let (test_client, _info) = client
+            .join_group(Some(tree_data.clone()), &welcome[0])
+            .await
+            .unwrap();
+
+        groups.push(test_client);
+    }
+
+    groups
+}
+
+/// Deliver `message`, sent by the member at index `sender`, to all `groups`.
+///
+/// Every group whose own member index differs from `sender` processes a clone of
+/// `message`. The sender's own group instead applies its pending commit when
+/// `is_commit` is set, and is otherwise left unchanged. Panics on any failure.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn all_process_message<C: MlsConfig>(
+    groups: &mut [Group<C>],
+    message: &MlsMessage,
+    sender: usize,
+    is_commit: bool,
+) {
+    for group in groups.iter_mut() {
+        let is_sender = group.current_member_index() as usize == sender;
+
+        if is_sender {
+            if is_commit {
+                group.apply_pending_commit().await.unwrap();
+            }
+            continue;
+        }
+
+        group
+            .process_incoming_message(message.clone())
+            .await
+            .unwrap();
+    }
+}
diff --git a/src/tree_kem/capabilities.rs b/src/tree_kem/capabilities.rs
new file mode 100644
index 0000000..6fc498d
--- /dev/null
+++ b/src/tree_kem/capabilities.rs
@@ -0,0 +1,5 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+pub use mls_rs_core::group::Capabilities;
diff --git a/src/tree_kem/hpke_encryption.rs b/src/tree_kem/hpke_encryption.rs
new file mode 100644
index 0000000..77a598a
--- /dev/null
+++ b/src/tree_kem/hpke_encryption.rs
@@ -0,0 +1,172 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsEncode, MlsSize};
+use mls_rs_core::{
+ crypto::{CipherSuiteProvider, HpkeCiphertext, HpkePublicKey, HpkeSecretKey},
+ error::IntoAnyError,
+};
+use zeroize::Zeroizing;
+
+use crate::client::MlsError;
+
+/// Encoded context for HPKE "encrypt with label". `label` is stored with the
+/// `"MLS 1.0 "` prefix (see `EncryptContext::new`), and `context` borrows the
+/// caller's bytes. Both fields are encoded with `mls_rs_codec::byte_vec`.
+#[derive(Clone, MlsSize, MlsEncode)]
+struct EncryptContext<'a> {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ label: Vec<u8>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ context: &'a [u8],
+}
+
+impl Debug for EncryptContext<'_> {
+    /// Formats both byte fields with `pretty_bytes` instead of raw byte lists.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let label = mls_rs_core::debug::pretty_bytes(&self.label);
+        let context = mls_rs_core::debug::pretty_bytes(self.context);
+
+        f.debug_struct("EncryptContext")
+            .field("label", &label)
+            .field("context", &context)
+            .finish()
+    }
+}
+
+impl<'a> EncryptContext<'a> {
+    /// Build the context for `label`, stored as `"MLS 1.0 "` followed by the label
+    /// bytes, borrowing the caller's `context` bytes.
+    pub fn new(label: &str, context: &'a [u8]) -> Self {
+        let mut full_label = b"MLS 1.0 ".to_vec();
+        full_label.extend_from_slice(label.as_bytes());
+
+        Self {
+            label: full_label,
+            context,
+        }
+    }
+}
+
+/// A value that can be HPKE-encrypted to a public key using MLS "encrypt with label".
+///
+/// Implementors supply `ENCRYPT_LABEL` and the byte conversions; `encrypt` and
+/// `decrypt` are provided.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
+#[cfg_attr(
+ all(not(target_arch = "wasm32"), mls_build_async),
+ maybe_async::must_be_async
+)]
+
+pub(crate) trait HpkeEncryptable: Sized {
+ /// Label (without the "MLS 1.0 " prefix) used to build the encoded
+ /// `EncryptContext` passed to `hpke_seal` / `hpke_open`.
+ const ENCRYPT_LABEL: &'static str;
+
+ /// Encrypt `self.get_bytes()` to `public_key`.
+ ///
+ /// The encoded `EncryptContext` (from `Self::ENCRYPT_LABEL` and `context`) is
+ /// passed to `hpke_seal` with no AAD-style extra argument (`None`). Both the
+ /// context and the plaintext buffers are wrapped in `Zeroizing`.
+ ///
+ /// # Errors
+ ///
+ /// Propagates encoding and `get_bytes` errors; HPKE failures become
+ /// `MlsError::CryptoProviderError`.
+ async fn encrypt<P: CipherSuiteProvider>(
+ &self,
+ cipher_suite_provider: &P,
+ public_key: &HpkePublicKey,
+ context: &[u8],
+ ) -> Result<HpkeCiphertext, MlsError> {
+ let context = EncryptContext::new(Self::ENCRYPT_LABEL, context)
+ .mls_encode_to_vec()
+ .map(Zeroizing::new)?;
+
+ let content = self.get_bytes().map(Zeroizing::new)?;
+
+ cipher_suite_provider
+ .hpke_seal(public_key, &context, None, &content)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+
+ /// Decrypt `ciphertext` with `secret_key`/`public_key` and rebuild the value
+ /// with `Self::from_bytes`, using the same encoded context as `encrypt`.
+ ///
+ /// # Errors
+ ///
+ /// Propagates encoding and `from_bytes` errors; HPKE failures become
+ /// `MlsError::CryptoProviderError`.
+ async fn decrypt<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ secret_key: &HpkeSecretKey,
+ public_key: &HpkePublicKey,
+ context: &[u8],
+ ciphertext: &HpkeCiphertext,
+ ) -> Result<Self, MlsError> {
+ let context = EncryptContext::new(Self::ENCRYPT_LABEL, context).mls_encode_to_vec()?;
+
+ let plaintext = cipher_suite_provider
+ .hpke_open(ciphertext, secret_key, public_key, &context, None)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ // NOTE(review): `to_vec()` copies the decrypted bytes into a plain `Vec` that is
+ // not zeroized on drop — confirm whether implementors of `from_bytes` rely on that.
+ Self::from_bytes(plaintext.to_vec())
+ }
+
+ /// Rebuild a value from its decrypted byte representation.
+ fn from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError>;
+ /// Byte representation that `encrypt` seals.
+ fn get_bytes(&self) -> Result<Vec<u8>, MlsError>;
+}
+
+// Test helpers and interop vectors for `HpkeEncryptable`.
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use alloc::{string::String, vec::Vec};
+ use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+ use mls_rs_core::crypto::{CipherSuiteProvider, HpkeCiphertext};
+
+ use crate::{client::MlsError, crypto::test_utils::try_test_cipher_suite_provider};
+
+ use super::HpkeEncryptable;
+
+ /// One "encrypt with label" vector; all byte fields are hex-encoded in JSON.
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ pub struct HpkeInteropTestCase {
+ #[serde(with = "hex::serde", rename = "priv")]
+ secret: Vec<u8>,
+ #[serde(with = "hex::serde", rename = "pub")]
+ public: Vec<u8>,
+ label: String,
+ #[serde(with = "hex::serde")]
+ context: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ plaintext: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ kem_output: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ ciphertext: Vec<u8>,
+ }
+
+ /// Per-cipher-suite entry of the crypto-basics interop vectors.
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ pub struct InteropTestCase {
+ cipher_suite: u16,
+ encrypt_with_label: HpkeInteropTestCase,
+ }
+
+ // Decrypts each supported cipher suite's vector; unsupported suites are skipped.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_basic_crypto_test_vectors() {
+ // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/crypto-basics.json
+ let test_cases: Vec<InteropTestCase> =
+ load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());
+
+ for test_case in test_cases {
+ if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) {
+ test_case.encrypt_with_label.verify(&cs).await
+ }
+ }
+ }
+
+ // Raw-bytes wrapper used to run the vectors through `HpkeEncryptable`.
+ #[derive(Clone, Debug, MlsSize, MlsEncode, MlsDecode)]
+ struct TestEncryptable(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>);
+
+ impl HpkeEncryptable for TestEncryptable {
+ // NOTE(review): the vector's own `label` field is not read; the label is fixed here.
+ const ENCRYPT_LABEL: &'static str = "EncryptWithLabel";
+
+ fn from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError> {
+ Ok(Self(bytes))
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn get_bytes(&self) -> Result<Vec<u8>, MlsError> {
+ Ok(self.0.clone())
+ }
+ }
+
+ impl HpkeInteropTestCase {
+ /// Decrypt the vector's ciphertext with `cs` and assert it equals `plaintext`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
+ let secret = self.secret.clone().into();
+ let public = self.public.clone().into();
+
+ let ciphertext = HpkeCiphertext {
+ kem_output: self.kem_output.clone(),
+ ciphertext: self.ciphertext.clone(),
+ };
+
+ let computed_plaintext =
+ TestEncryptable::decrypt(cs, &secret, &public, &self.context, &ciphertext)
+ .await
+ .unwrap();
+
+ assert_eq!(&computed_plaintext.0, &self.plaintext)
+ }
+ }
+}
diff --git a/src/tree_kem/interop_test_vectors.rs b/src/tree_kem/interop_test_vectors.rs
new file mode 100644
index 0000000..50e0077
--- /dev/null
+++ b/src/tree_kem/interop_test_vectors.rs
@@ -0,0 +1,199 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode};
+use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};
+
+use itertools::Itertools;
+
+use crate::{
+ crypto::test_utils::try_test_cipher_suite_provider, identity::basic::BasicIdentityProvider,
+};
+
+use super::{
+ node::NodeVec, test_utils::TreeWithSigners, tree_validator::TreeValidator, TreeKemPublic,
+};
+
+/// One tree-validation interop test case; byte fields are hex-encoded in JSON.
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct ValidationTestCase {
+ /// Numeric cipher suite id.
+ pub cipher_suite: u16,
+
+ /// MLS-encoded ratchet tree nodes (`NodeVec`).
+ #[serde(with = "hex::serde")]
+ pub tree: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub group_id: Vec<u8>,
+ /// Expected tree hash of every node, by node index.
+ pub tree_hashes: Vec<TreeHash>,
+ /// Expected resolution (list of node indices) of every node, by node index.
+ pub resolutions: Vec<Vec<u32>>,
+}
+
+/// A node's tree hash bytes, hex-encoded when serialized.
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct TreeHash(#[serde(with = "hex::serde")] pub Vec<u8>);
+
+impl From<crate::tree_kem::tree_hash::TreeHash> for TreeHash {
+    /// Copy a computed tree hash into its serializable form.
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn from(value: crate::tree_kem::tree_hash::TreeHash) -> Self {
+        let bytes = value.to_vec();
+        Self(bytes)
+    }
+}
+
+impl ValidationTestCase {
+    /// Build a test case from `tree`, whose tree hashes must already be computed.
+    ///
+    /// Records the encoded nodes, every node's tree hash and every node's resolution
+    /// (by node index), tagged with `cs`'s cipher suite and `group_id`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the number of cached tree hashes differs from the node count, or
+    /// if encoding or a resolution lookup fails.
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn new<P: CipherSuiteProvider>(tree: TreeKemPublic, group_id: &[u8], cs: &P) -> Self {
+        // A tree with n leaves has 2n - 1 nodes.
+        let node_count = tree.total_leaf_count() * 2 - 1;
+
+        assert!(
+            tree.tree_hashes.current.len() == node_count as usize,
+            "hashes not initialized"
+        );
+
+        let mut resolutions = Vec::new();
+        for node in 0..node_count {
+            resolutions.push(tree.nodes.get_resolution_index(node).unwrap());
+        }
+
+        let encoded_tree = tree.nodes.mls_encode_to_vec().unwrap();
+
+        let tree_hashes = tree
+            .tree_hashes
+            .current
+            .into_iter()
+            .map(TreeHash::from)
+            .collect();
+
+        Self {
+            cipher_suite: cs.cipher_suite().into(),
+            tree: encoded_tree,
+            group_id: group_id.to_vec(),
+            tree_hashes,
+            resolutions,
+        }
+    }
+}
+
+// Tree-validation interop test. The vectors are loaded with `load_test_case_json!`,
+// which is given `generate_validation_test_vector` as the generator. For each
+// supported cipher suite it decodes the tree, compares every node's tree hash and
+// resolution with the expected values, then runs `TreeValidator` against a group
+// context carrying the computed tree hash and the vector's group id.
+#[cfg(feature = "rfc_compliant")]
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn validation() {
+ use crate::group::test_utils::get_test_group_context;
+
+ #[cfg(mls_build_async)]
+ let test_cases: Vec<ValidationTestCase> = load_test_case_json!(
+ interop_tree_validation,
+ generate_validation_test_vector().await
+ );
+
+ #[cfg(not(mls_build_async))]
+ let test_cases: Vec<ValidationTestCase> =
+ load_test_case_json!(interop_tree_validation, generate_validation_test_vector());
+
+ for test_case in test_cases.into_iter() {
+ // Skip cipher suites the test crypto provider does not support.
+ let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
+ continue;
+ };
+
+ let mut tree = TreeKemPublic::import_node_data(
+ NodeVec::mls_decode(&mut &*test_case.tree).unwrap(),
+ &BasicIdentityProvider,
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ let tree_hash = tree.tree_hash(&cs).await.unwrap();
+
+ // `zip_eq` panics if the number of computed and expected hashes differ.
+ tree.tree_hashes
+ .current
+ .iter()
+ .zip_eq(test_case.tree_hashes.iter())
+ .for_each(|(l, r)| assert_eq!(**l, *r.0));
+
+ test_case
+ .resolutions
+ .iter()
+ .enumerate()
+ .for_each(|(i, res)| {
+ assert_eq!(&tree.nodes.get_resolution_index(i as u32).unwrap(), res)
+ });
+
+ let mut context = get_test_group_context(1, test_case.cipher_suite.into()).await;
+ context.tree_hash = tree_hash;
+ context.group_id = test_case.group_id;
+
+ TreeValidator::new(&cs, &context, &BasicIdentityProvider)
+ .validate(&mut tree)
+ .await
+ .unwrap();
+ }
+}
+
+// Generate tree-validation test cases for every cipher suite supported by the test
+// crypto provider. Each suite gets the same sequence of tree shapes (full trees,
+// internal and trailing blanks with and without skipping, and unmerged leaves); the
+// push order below is the order of the emitted cases.
+#[cfg(feature = "rfc_compliant")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn generate_validation_test_vector() -> Vec<ValidationTestCase> {
+ let mut test_cases = vec![];
+
+ for cs in CipherSuite::all() {
+ let Some(cs) = try_test_cipher_suite_provider(*cs) else {
+ continue;
+ };
+
+ let mut trees = vec![];
+
+ // Generate trees with increasing complexity. Start: full complete trees
+ for n_leaves in [2, 4, 8, 32] {
+ trees.push(TreeWithSigners::make_full_tree(n_leaves, &cs).await);
+ }
+
+ // Internal blanks, no skipping: 8 leaves; leaf 0 commits, removing leaves 2 and 3
+ // and adding a new member
+ let mut tree = TreeWithSigners::make_full_tree(8, &cs).await;
+ tree.remove_member(2);
+ tree.remove_member(3);
+ tree.add_member("Bob", &cs).await;
+ tree.update_committer_path(0, &cs).await;
+ trees.push(tree);
+
+ // Blanks at the end, no skipping
+ for n_leaves in [3, 5, 7, 33] {
+ trees.push(TreeWithSigners::make_full_tree(n_leaves, &cs).await);
+ }
+
+ // Internal blanks, with skipping: 8 leaves; leaf 0 commits, removing leaves 1, 2 and 3
+ let mut tree = TreeWithSigners::make_full_tree(8, &cs).await;
+ [1, 2, 3].into_iter().for_each(
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ |i| tree.remove_member(i),
+ );
+ tree.update_committer_path(0, &cs).await;
+ trees.push(tree);
+
+ // Blanks at the end, with skipping
+ for n_leaves in [6, 34] {
+ trees.push(TreeWithSigners::make_full_tree(n_leaves, &cs).await);
+ }
+
+ // Unmerged leaves, no skipping: 7 leaves; leaf 0 commits, adding a member
+ let mut tree = TreeWithSigners::make_full_tree(7, &cs).await;
+ tree.add_member("Bob", &cs).await;
+ tree.update_committer_path(0, &cs).await;
+ trees.push(tree);
+
+ // Unmerged leaves, with skipping : figure 20 in the RFC
+ let mut tree = TreeWithSigners::make_full_tree(7, &cs).await;
+ tree.remove_member(5);
+ tree.update_committer_path(0, &cs).await;
+ tree.update_committer_path(4, &cs).await;
+ tree.add_member("Bob", &cs).await;
+ // Clear the cached hashes and recompute them for the final tree, so that
+ // `ValidationTestCase::new` sees one hash per node (presumably required because
+ // no committer path was applied after `add_member` — confirm).
+ tree.tree.tree_hashes.current = vec![];
+ tree.tree.tree_hash(&cs).await.unwrap();
+ trees.push(tree);
+
+ // Generate tests
+ trees.into_iter().for_each(
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ |tree| test_cases.push(ValidationTestCase::new(tree.tree, &tree.group_id, &cs)),
+ );
+ }
+
+ test_cases
+}
diff --git a/src/tree_kem/kem.rs b/src/tree_kem/kem.rs
new file mode 100644
index 0000000..cedeb0e
--- /dev/null
+++ b/src/tree_kem/kem.rs
@@ -0,0 +1,699 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::crypto::{CipherSuiteProvider, SignatureSecretKey};
+use crate::group::GroupContext;
+use crate::identity::SigningIdentity;
+use crate::iter::wrap_iter;
+use crate::tree_kem::math as tree_math;
+use alloc::vec;
+use alloc::vec::Vec;
+use itertools::Itertools;
+use mls_rs_codec::MlsEncode;
+use tree_math::{CopathNode, TreeIndex};
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
+
+#[cfg(mls_build_async)]
+use futures::{StreamExt, TryStreamExt};
+
+#[cfg(feature = "std")]
+use std::collections::HashSet;
+
+use super::hpke_encryption::HpkeEncryptable;
+use super::leaf_node::ConfigProperties;
+use super::node::NodeTypeResolver;
+use super::{
+ node::{LeafIndex, NodeIndex},
+ path_secret::{PathSecret, PathSecretGenerator},
+ TreeKemPrivate, TreeKemPublic, UpdatePath, UpdatePathNode, ValidatedUpdatePath,
+};
+
+#[cfg(test)]
+use crate::{group::CommitModifiers, signer::Signable};
+
+/// Short-lived view over a group's public ratchet tree and the local member's
+/// private tree keys, used to run exactly one `encap` or `decap` (both consume `self`).
+pub struct TreeKem<'a> {
+ /// Public ratchet tree; mutated in place by `encap` (new path keys, hashes).
+ tree_kem_public: &'a mut TreeKemPublic,
+ /// Local member's private keys; `secret_keys[0]` is the leaf key and
+ /// `secret_keys[i + 1]` belongs to the i-th node of the direct path.
+ private_key: &'a mut TreeKemPrivate,
+}
+
+/// Output of [`TreeKem::encap`].
+pub struct EncapGeneration {
+ /// UpdatePath to send: the re-signed own leaf plus one encrypted node per
+ /// unfiltered direct-path node.
+ pub update_path: UpdatePath,
+ /// One entry per direct-path node, in path order; `None` for filtered nodes.
+ pub path_secrets: Vec<Option<PathSecret>>,
+ /// Secret drawn from the path-secret generator after the last path secret.
+ pub commit_secret: PathSecret,
+}
+
+impl<'a> TreeKem<'a> {
+ /// Wraps the public tree and the local member's private tree keys for a single
+ /// `encap` or `decap` call.
+ pub fn new(
+ tree_kem_public: &'a mut TreeKemPublic,
+ private_key: &'a mut TreeKemPrivate,
+ ) -> Self {
+ TreeKem {
+ tree_kem_public,
+ private_key,
+ }
+ }
+
+ /// Generates a new UpdatePath for the local member (`private_key.self_index`).
+ ///
+ /// For every direct-path node that is not filtered, a new path secret is drawn from
+ /// one `PathSecretGenerator`. The derived HPKE public key replaces the node's key in
+ /// the public tree and the secret key is stored in `private_key.secret_keys[i + 1]`.
+ /// Filtered nodes get no secret. Then, in this order:
+ /// - parent hashes are recomputed;
+ /// - the own leaf is re-keyed and re-signed via `LeafNode::commit`;
+ /// - tree hashes are updated and written to `context.tree_hash`;
+ /// - each path secret is HPKE-encrypted to its copath node's resolution (minus the
+ ///   `excluding` leaves), using the encoding of the updated `context`.
+ ///
+ /// Returns the UpdatePath, the per-node path secrets, and the next generator secret
+ /// as the commit secret.
+ ///
+ /// # Errors
+ /// Propagates `MlsError` from tree access, key derivation, hashing, signing,
+ /// encoding and encryption.
+ #[allow(clippy::too_many_arguments)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn encap<P>(
+ self,
+ context: &mut GroupContext,
+ excluding: &[LeafIndex],
+ signer: &SignatureSecretKey,
+ update_leaf_properties: ConfigProperties,
+ signing_identity: Option<SigningIdentity>,
+ cipher_suite_provider: &P,
+ #[cfg(test)] commit_modifiers: &CommitModifiers,
+ ) -> Result<EncapGeneration, MlsError>
+ where
+ P: CipherSuiteProvider + Send + Sync,
+ {
+ let self_index = self.private_key.self_index;
+ let path = self.tree_kem_public.nodes.direct_copath(self_index);
+ let filtered = self.tree_kem_public.nodes.filtered(self_index)?;
+
+ // `path` excludes the leaf: slot 0 holds the leaf key, slots 1..=path.len() the path keys.
+ self.private_key.secret_keys.resize(path.len() + 1, None);
+
+ let mut secret_generator = PathSecretGenerator::new(cipher_suite_provider);
+ let mut path_secrets = vec![];
+
+ // Filtered nodes receive no new key: their private slot is cleared and no secret is recorded.
+ for (i, (node, f)) in path.iter().zip(&filtered).enumerate() {
+ if !f {
+ let secret = secret_generator.next_secret().await?;
+
+ let (secret_key, public_key) =
+ secret.to_hpke_key_pair(cipher_suite_provider).await?;
+
+ self.private_key.secret_keys[i + 1] = Some(secret_key);
+ self.tree_kem_public.update_node(public_key, node.path)?;
+ path_secrets.push(Some(secret));
+ } else {
+ self.private_key.secret_keys[i + 1] = None;
+ path_secrets.push(None);
+ }
+ }
+
+ // Test-only hook: lets tests tamper with the tree before parent hashes are computed.
+ #[cfg(test)]
+ (commit_modifiers.modify_tree)(self.tree_kem_public);
+
+ self.tree_kem_public
+ .update_parent_hashes(self_index, false, cipher_suite_provider)
+ .await?;
+
+ let update_path_leaf = {
+ let own_leaf = self.tree_kem_public.nodes.borrow_as_leaf_mut(self_index)?;
+
+ // Re-key and re-sign the own leaf; the new leaf HPKE secret goes into slot 0.
+ self.private_key.secret_keys[0] = Some(
+ own_leaf
+ .commit(
+ cipher_suite_provider,
+ &context.group_id,
+ *self_index,
+ update_leaf_properties,
+ signing_identity,
+ signer,
+ )
+ .await?,
+ );
+
+ // Test-only hook: if the modifier returns a signer, re-sign the (modified) leaf with it.
+ #[cfg(test)]
+ if let Some(signer) = (commit_modifiers.modify_leaf)(own_leaf, signer) {
+ let context = &(context.group_id.as_slice(), *self_index).into();
+
+ own_leaf
+ .sign(cipher_suite_provider, &signer, context)
+ .await
+ .unwrap();
+ }
+
+ own_leaf.clone()
+ };
+
+ // Tree modifications are all done so we can update the tree hash and encrypt with the new context
+ self.tree_kem_public
+ .update_hashes(&[self_index], cipher_suite_provider)
+ .await?;
+
+ context.tree_hash = self
+ .tree_kem_public
+ .tree_hash(cipher_suite_provider)
+ .await?;
+
+ let context_bytes = context.mls_encode_to_vec()?;
+
+ let node_updates = self
+ .encrypt_path_secrets(
+ path,
+ &path_secrets,
+ &context_bytes,
+ cipher_suite_provider,
+ excluding,
+ )
+ .await?;
+
+ // Test-only hook: lets tests rewrite the encrypted path nodes.
+ #[cfg(test)]
+ let node_updates = (commit_modifiers.modify_path)(node_updates);
+
+ // Create an update path with the new node and parent node updates
+ let update_path = UpdatePath {
+ leaf_node: update_path_leaf,
+ nodes: node_updates,
+ };
+
+ Ok(EncapGeneration {
+ update_path,
+ path_secrets,
+ // Next secret after those drawn for the unfiltered path nodes.
+ commit_secret: secret_generator.next_secret().await?,
+ })
+ }
+
+ /// Encrypts each `Some` path secret to the resolution of the matching copath node,
+ /// one node at a time. `None` (filtered) entries produce no `UpdatePathNode`, and
+ /// leaves listed in `excluding` receive no ciphertext.
+ ///
+ /// The excluded leaves are collected into a `HashSet` with `std` and into a `Vec`
+ /// without it, since `HashSet` is only imported under the `std` feature.
+ #[cfg(any(mls_build_async, not(feature = "rayon")))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn encrypt_path_secrets<P: CipherSuiteProvider>(
+ &self,
+ path: Vec<CopathNode<NodeIndex>>,
+ path_secrets: &[Option<PathSecret>],
+ context_bytes: &[u8],
+ cipher_suite: &P,
+ excluding: &[LeafIndex],
+ ) -> Result<Vec<UpdatePathNode>, MlsError> {
+ let excluding = excluding.iter().copied().map(NodeIndex::from);
+
+ #[cfg(feature = "std")]
+ let excluding = excluding.collect::<HashSet<NodeIndex>>();
+ #[cfg(not(feature = "std"))]
+ let excluding = excluding.collect::<Vec<NodeIndex>>();
+
+ let mut node_updates = Vec::new();
+
+ for (index, path_secret) in path.into_iter().zip(path_secrets.iter()) {
+ if let Some(path_secret) = path_secret {
+ node_updates.push(
+ self.encrypt_copath_node_resolution(
+ cipher_suite,
+ path_secret,
+ index.copath,
+ context_bytes,
+ &excluding,
+ )
+ .await?,
+ );
+ }
+ }
+
+ Ok(node_updates)
+ }
+
+ /// Rayon variant of `encrypt_path_secrets` (sync builds with the `rayon` feature):
+ /// encrypts the copath nodes in parallel with the same filtering and exclusion rules.
+ /// NOTE(review): result ordering relies on the parallel `collect` preserving input
+ /// order. Confirm against `crate::iter::ParallelIteratorExt`/rayon semantics.
+ #[cfg(all(not(mls_build_async), feature = "rayon"))]
+ fn encrypt_path_secrets<P: CipherSuiteProvider>(
+ &self,
+ path: Vec<CopathNode<NodeIndex>>,
+ path_secrets: &[Option<PathSecret>],
+ context_bytes: &[u8],
+ cipher_suite: &P,
+ excluding: &[LeafIndex],
+ ) -> Result<Vec<UpdatePathNode>, MlsError> {
+ let excluding = excluding.iter().copied().map(NodeIndex::from);
+
+ #[cfg(feature = "std")]
+ let excluding = excluding.collect::<HashSet<NodeIndex>>();
+ #[cfg(not(feature = "std"))]
+ let excluding = excluding.collect::<Vec<NodeIndex>>();
+
+ path.into_par_iter()
+ .zip(path_secrets.par_iter())
+ .filter_map(|(node, path_secret)| {
+ path_secret.as_ref().map(|path_secret| {
+ self.encrypt_copath_node_resolution(
+ cipher_suite,
+ path_secret,
+ node.copath,
+ context_bytes,
+ &excluding,
+ )
+ })
+ })
+ .collect()
+ }
+
+ /// Processes an UpdatePath sent by `sender_index`, from the local member's side.
+ ///
+ /// Steps:
+ /// 1. Pick the node `path[lca_index]` on the local member's path (with the own leaf
+ ///    prepended). This is presumably the child of the sender/receiver lowest common
+ ///    ancestor, depending on `leaf_lca_level` semantics.
+ /// 2. Walk down to the first non-blank node for which a secret key is held, falling
+ ///    back to the own leaf.
+ /// 3. Decrypt the matching ciphertext of the LCA update node.
+ /// 4. Derive the remaining path secrets. Each derived public key must equal the one
+ ///    in `update_path`, otherwise `MlsError::PubKeyMismatch`; the private keys are
+ ///    stored in `private_key.secret_keys`.
+ ///
+ /// `added_leaves` are skipped when locating the ciphertext, mirroring `encap`'s
+ /// `excluding`. `context_bytes` must equal the encoded context the sender encrypted
+ /// under. Returns the secret following the last derived path secret (the commit secret).
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn decap<CP>(
+ self,
+ sender_index: LeafIndex,
+ update_path: &ValidatedUpdatePath,
+ added_leaves: &[LeafIndex],
+ context_bytes: &[u8],
+ cipher_suite_provider: &CP,
+ ) -> Result<PathSecret, MlsError>
+ where
+ CP: CipherSuiteProvider,
+ {
+ let self_index = self.private_key.self_index;
+
+ // NOTE(review): assumes leaf_lca_level returns >= 2 for distinct leaves; a smaller
+ // value would underflow this subtraction.
+ let lca_index =
+ tree_math::leaf_lca_level(self_index.into(), sender_index.into()) as usize - 2;
+
+ // Prepend the own leaf so that path[k] lines up with private_key.secret_keys[k].
+ let mut path = self.tree_kem_public.nodes.direct_copath(self_index);
+ let leaf = CopathNode::new(self_index.into(), 0);
+ path.insert(0, leaf);
+ let resolved_pos = self.find_resolved_pos(&path, lca_index)?;
+
+ let ct_pos =
+ self.find_ciphertext_pos(path[lca_index].path, path[resolved_pos].path, added_leaves)?;
+
+ let lca_node = update_path.nodes[lca_index]
+ .as_ref()
+ .ok_or(MlsError::LcaNotFoundInDirectPath)?;
+
+ let ct = lca_node
+ .encrypted_path_secret
+ .get(ct_pos)
+ .ok_or(MlsError::LcaNotFoundInDirectPath)?;
+
+ let secret = self.private_key.secret_keys[resolved_pos]
+ .as_ref()
+ .ok_or(MlsError::UpdateErrorNoSecretKey)?;
+
+ let public = self
+ .tree_kem_public
+ .nodes
+ .borrow_node(path[resolved_pos].path)?
+ .as_ref()
+ .ok_or(MlsError::UpdateErrorNoSecretKey)?
+ .public_key();
+
+ let lca_path_secret =
+ PathSecret::decrypt(cipher_suite_provider, secret, public, context_bytes, ct).await?;
+
+ // Derive the rest of the secrets for the tree and assign to the proper nodes
+ let mut node_secret_gen =
+ PathSecretGenerator::starting_with(cipher_suite_provider, lca_path_secret);
+
+ // Update secrets based on the decrypted path secret in the update
+ // NOTE(review): `path` already includes the own leaf here, so this sizes secret_keys
+ // one slot larger than `encap` does (direct-path length + 1). The extra trailing slot
+ // is presumably never written by the loop below. Confirm whether `path.len()` was intended.
+ self.private_key.secret_keys.resize(path.len() + 1, None);
+
+ // Nodes below the LCA keep their keys; derivation resumes at index lca_index.
+ for (i, update) in update_path.nodes.iter().enumerate().skip(lca_index) {
+ if let Some(update) = update {
+ let secret = node_secret_gen.next_secret().await?;
+
+ // Verify the private key we calculated properly matches the public key we inserted into the tree. This guarantees
+ // that we will be able to decrypt later.
+ let (hpke_private, hpke_public) =
+ secret.to_hpke_key_pair(cipher_suite_provider).await?;
+
+ if hpke_public != update.public_key {
+ return Err(MlsError::PubKeyMismatch);
+ }
+
+ self.private_key.secret_keys[i + 1] = Some(hpke_private);
+ } else {
+ self.private_key.secret_keys[i + 1] = None;
+ }
+ }
+
+ node_secret_gen.next_secret().await
+ }
+
+ /// Encrypts `path_secret` to every node in the resolution of `copath_index`, except
+ /// those in `excluding`. Returns the ciphertexts together with the current public key
+ /// of `copath_index`'s parent, i.e. the direct-path node this secret belongs to.
+ ///
+ /// # Errors
+ /// `MlsError::ExpectedNode` if `copath_index` has no parent; otherwise tree-access or
+ /// encryption errors.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn encrypt_copath_node_resolution<P: CipherSuiteProvider>(
+ &self,
+ cipher_suite_provider: &P,
+ path_secret: &PathSecret,
+ copath_index: NodeIndex,
+ context: &[u8],
+ #[cfg(feature = "std")] excluding: &HashSet<NodeIndex>,
+ #[cfg(not(feature = "std"))] excluding: &[NodeIndex],
+ ) -> Result<UpdatePathNode, MlsError> {
+ let reso = self
+ .tree_kem_public
+ .nodes
+ .get_resolution_index(copath_index)?;
+
+ let make_ctxt = |idx| async move {
+ let node = self
+ .tree_kem_public
+ .nodes
+ .borrow_node(idx)?
+ .as_non_empty()?;
+
+ path_secret
+ .encrypt(cipher_suite_provider, node.public_key(), context)
+ .await
+ };
+
+ let ctxts = wrap_iter(reso).filter(|&idx| async move { !excluding.contains(&idx) });
+
+ // Sync builds map directly. Async builds await each encryption in turn via
+ // `StreamExt::then`.
+ #[cfg(not(mls_build_async))]
+ let ctxts = ctxts.map(make_ctxt);
+
+ #[cfg(mls_build_async)]
+ let ctxts = ctxts.then(make_ctxt);
+
+ let ctxts = ctxts.try_collect().await?;
+
+ let path_index = copath_index
+ .parent_sibling(&self.tree_kem_public.total_leaf_count())
+ .ok_or(MlsError::ExpectedNode)?
+ .parent;
+
+ Ok(UpdatePathNode {
+ public_key: self
+ .tree_kem_public
+ .nodes
+ .borrow_as_parent(path_index)?
+ .public_key
+ .clone(),
+ encrypted_path_secret: ctxts,
+ })
+ }
+
+ /// Walks down `path` from `lca_index` toward the own leaf (index 0) until a non-blank
+ /// node is found and returns its index. Returns 0 instead if no secret key is held
+ /// for that node.
+ /// NOTE(review): relies on path[0] (own leaf) being non-blank; otherwise
+ /// `lca_index -= 1` underflows.
+ #[inline]
+ fn find_resolved_pos(
+ &self,
+ path: &[CopathNode<NodeIndex>],
+ mut lca_index: usize,
+ ) -> Result<usize, MlsError> {
+ while self.tree_kem_public.nodes.is_blank(path[lca_index].path)? {
+ lca_index -= 1;
+ }
+
+ // If we don't have the key, we should be an unmerged leaf at the resolved node. (If
+ // we're not, an error will be thrown later.)
+ if self.private_key.secret_keys[lca_index].is_none() {
+ lca_index = 0;
+ }
+
+ Ok(lca_index)
+ }
+
+ /// Returns the position of `resolved` in the resolution of `lca`. Leaves listed in
+ /// `excluding` are removed first, matching the filtering `encap` applies, so the
+ /// position indexes the sender's ciphertext list. Odd indices are parent nodes and are
+ /// always kept; even index `idx` is leaf `idx / 2`.
+ ///
+ /// # Errors
+ /// `MlsError::UpdateErrorNoSecretKey` if `resolved` is not in the filtered resolution.
+ #[inline]
+ fn find_ciphertext_pos(
+ &self,
+ lca: NodeIndex,
+ resolved: NodeIndex,
+ excluding: &[LeafIndex],
+ ) -> Result<usize, MlsError> {
+ let reso = self.tree_kem_public.nodes.get_resolution_index(lca)?;
+
+ let (ct_pos, _) = reso
+ .iter()
+ .filter(|idx| **idx % 2 == 1 || !excluding.contains(&LeafIndex(**idx / 2)))
+ .find_position(|idx| idx == &&resolved)
+ .ok_or(MlsError::UpdateErrorNoSecretKey)?;
+
+ Ok(ct_pos)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::{tree_math, TreeKem};
+ use crate::{
+ cipher_suite::CipherSuite,
+ client::test_utils::TEST_CIPHER_SUITE,
+ crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
+ extension::test_utils::TestExtension,
+ group::test_utils::{get_test_group_context, random_bytes},
+ identity::basic::BasicIdentityProvider,
+ tree_kem::{
+ leaf_node::{
+ test_utils::{get_basic_test_node_sig_key, get_test_capabilities},
+ ConfigProperties,
+ },
+ node::LeafIndex,
+ Capabilities, TreeKemPrivate, TreeKemPublic, UpdatePath, ValidatedUpdatePath,
+ },
+ ExtensionList,
+ };
+ use alloc::{format, vec, vec::Vec};
+ use mls_rs_codec::MlsEncode;
+ use mls_rs_core::crypto::CipherSuiteProvider;
+ use tree_math::TreeIndex;
+
+ // Verify that the tree is in the correct state after generating an update path.
+ // NOTE: indexes update_path.nodes by direct-path position, which assumes no direct-path
+ // node was filtered (encap only emits nodes for unfiltered positions).
+ fn verify_tree_update_path(
+ tree: &TreeKemPublic,
+ update_path: &UpdatePath,
+ index: LeafIndex,
+ capabilities: Option<Capabilities>,
+ extensions: Option<ExtensionList>,
+ ) {
+ // Make sure the update path is based on the direct path of the sender
+ let direct_path = tree.nodes.direct_copath(index);
+
+ for (i, n) in direct_path.iter().enumerate() {
+ assert_eq!(
+ *tree
+ .nodes
+ .borrow_node(n.path)
+ .unwrap()
+ .as_ref()
+ .unwrap()
+ .public_key(),
+ update_path.nodes[i].public_key
+ );
+ }
+
+ // Verify that the leaf from the update path has been installed
+ assert_eq!(
+ tree.nodes.borrow_as_leaf(index).unwrap(),
+ &update_path.leaf_node
+ );
+
+ // Verify that updated capabilities were installed
+ if let Some(capabilities) = capabilities {
+ assert_eq!(update_path.leaf_node.capabilities, capabilities);
+ }
+
+ // Verify that update extensions were installed
+ if let Some(extensions) = extensions {
+ assert_eq!(update_path.leaf_node.extensions, extensions);
+ }
+
+ // Verify that the root node is non-blank
+ let root = tree.total_leaf_count().root();
+ assert!(tree.nodes.borrow_node(root).unwrap().is_some());
+ }
+
+ // Checks that every direct-path secret key held by `private_tree` matches the public
+ // key in `public_tree`, using an HPKE seal/open round trip per node.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn verify_tree_private_path(
+ cipher_suite: &CipherSuite,
+ public_tree: &TreeKemPublic,
+ private_tree: &TreeKemPrivate,
+ index: LeafIndex,
+ ) {
+ let provider = test_cipher_suite_provider(*cipher_suite);
+
+ assert_eq!(private_tree.self_index, index);
+
+ // Make sure we have private values along the direct path, and the public keys match
+ let path_iter = public_tree
+ .nodes
+ .direct_copath(index)
+ .into_iter()
+ .enumerate();
+
+ for (i, n) in path_iter {
+ let secret_key = private_tree.secret_keys[i + 1].as_ref().unwrap();
+
+ let public_key = public_tree
+ .nodes
+ .borrow_node(n.path)
+ .unwrap()
+ .as_ref()
+ .unwrap()
+ .public_key();
+
+ let test_data = random_bytes(32);
+
+ let sealed = provider
+ .hpke_seal(public_key, &[], None, &test_data)
+ .await
+ .unwrap();
+
+ let opened = provider
+ .hpke_open(&sealed, secret_key, public_key, &[], None)
+ .await
+ .unwrap();
+
+ assert_eq!(test_data, opened);
+ }
+ }
+
+ // Builds a `size`-leaf tree (leaf 0 is the encapsulating member), runs `encap` from
+ // leaf 0, then applies the update path and `decap` on a copy of the original tree for
+ // each other member. Asserts every receiver ends with the encapsulator's public tree.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn encap_decap(
+ cipher_suite: CipherSuite,
+ size: usize,
+ capabilities: Option<Capabilities>,
+ extensions: Option<ExtensionList>,
+ ) {
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+ // Generate signing keys and key package generations, and private keys for multiple
+ // participants in order to set up state
+
+ let mut leaf_nodes = Vec::new();
+ let mut private_keys = Vec::new();
+
+ for index in 1..size {
+ let (leaf_node, hpke_secret, _) =
+ get_basic_test_node_sig_key(cipher_suite, &format!("{index}")).await;
+
+ let private_key = TreeKemPrivate::new_self_leaf(LeafIndex(index as u32), hpke_secret);
+
+ leaf_nodes.push(leaf_node);
+ private_keys.push(private_key);
+ }
+
+ let (encap_node, encap_hpke_secret, encap_signer) =
+ get_basic_test_node_sig_key(cipher_suite, "encap").await;
+
+ // Build a test tree we can clone for all leaf nodes
+ let (mut test_tree, mut encap_private_key) = TreeKemPublic::derive(
+ encap_node,
+ encap_hpke_secret,
+ &BasicIdentityProvider,
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ test_tree
+ .add_leaves(leaf_nodes, &BasicIdentityProvider, &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ // Clone the tree for the encapsulating member at leaf 0; `test_tree` itself stays
+ // untouched so it can be cloned for the receivers below
+ let mut encap_tree = test_tree.clone();
+
+ let update_leaf_properties = ConfigProperties {
+ capabilities: capabilities.clone().unwrap_or_else(get_test_capabilities),
+ extensions: extensions.clone().unwrap_or_default(),
+ };
+
+ // Perform the encap function
+ let encap_gen = TreeKem::new(&mut encap_tree, &mut encap_private_key)
+ .encap(
+ &mut get_test_group_context(42, cipher_suite).await,
+ &[],
+ &encap_signer,
+ update_leaf_properties,
+ None,
+ &cipher_suite_provider,
+ #[cfg(test)]
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ // Verify that the state of the tree matches the produced update path
+ verify_tree_update_path(
+ &encap_tree,
+ &encap_gen.update_path,
+ LeafIndex(0),
+ capabilities,
+ extensions,
+ );
+
+ // Verify that the private key matches the data in the public key
+ verify_tree_private_path(&cipher_suite, &encap_tree, &encap_private_key, LeafIndex(0))
+ .await;
+
+ // Expand the compact update path (one node per unfiltered direct-path node) into one
+ // `Option` slot per direct-path node, as `ValidatedUpdatePath` expects
+ let filtered = test_tree.nodes.filtered(LeafIndex(0)).unwrap();
+ let mut unfiltered_nodes = vec![None; filtered.len()];
+ filtered
+ .into_iter()
+ .enumerate()
+ .filter(|(_, f)| !*f)
+ .zip(encap_gen.update_path.nodes.iter())
+ .for_each(|((i, _), node)| {
+ unfiltered_nodes[i] = Some(node.clone());
+ });
+
+ // Apply the update path to the rest of the leaf nodes using the decap function
+ let validated_update_path = ValidatedUpdatePath {
+ leaf_node: encap_gen.update_path.leaf_node,
+ nodes: unfiltered_nodes,
+ };
+
+ encap_tree
+ .update_hashes(&[LeafIndex(0)], &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ let mut receiver_trees: Vec<TreeKemPublic> = (1..size).map(|_| test_tree.clone()).collect();
+
+ for (i, tree) in receiver_trees.iter_mut().enumerate() {
+ tree.apply_update_path(
+ LeafIndex(0),
+ &validated_update_path,
+ &Default::default(),
+ BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ // Receivers rebuild the same context (tree hash after applying the path) the
+ // sender encrypted under
+ let mut context = get_test_group_context(42, cipher_suite).await;
+ context.tree_hash = tree.tree_hash(&cipher_suite_provider).await.unwrap();
+
+ TreeKem::new(tree, &mut private_keys[i])
+ .decap(
+ LeafIndex(0),
+ &validated_update_path,
+ &[],
+ &context.mls_encode_to_vec().unwrap(),
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ tree.update_hashes(&[LeafIndex(0)], &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ assert_eq!(tree, &encap_tree);
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_encap_decap() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ encap_decap(cipher_suite, 10, None, None).await;
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_encap_capabilities() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let mut capabilities = get_test_capabilities();
+ capabilities.extensions.push(42.into());
+
+ encap_decap(cipher_suite, 10, Some(capabilities.clone()), None).await;
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_encap_extensions() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let mut extensions = ExtensionList::default();
+ extensions.set_from(TestExtension { foo: 10 }).unwrap();
+
+ encap_decap(cipher_suite, 10, None, Some(extensions)).await;
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_encap_capabilities_extensions() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let mut capabilities = get_test_capabilities();
+ capabilities.extensions.push(42.into());
+
+ let mut extensions = ExtensionList::default();
+ extensions.set_from(TestExtension { foo: 10 }).unwrap();
+
+ encap_decap(cipher_suite, 10, Some(capabilities), Some(extensions)).await;
+ }
+}
diff --git a/src/tree_kem/leaf_node.rs b/src/tree_kem/leaf_node.rs
new file mode 100644
index 0000000..c59ed78
--- /dev/null
+++ b/src/tree_kem/leaf_node.rs
@@ -0,0 +1,688 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::{parent_hash::ParentHash, Capabilities, Lifetime};
+use crate::client::MlsError;
+use crate::crypto::{CipherSuiteProvider, HpkePublicKey, HpkeSecretKey, SignatureSecretKey};
+use crate::{identity::SigningIdentity, signer::Signable, ExtensionList};
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+
+/// Origin of a leaf node. It is part of the signed content (`LeafNodeTBS`).
+///
+/// The explicit `u8` discriminants are presumably the wire tags used by the derived
+/// MLS codec. TODO confirm against `mls_rs_codec`'s enum derive before changing them.
+#[derive(Debug, Clone, MlsSize, MlsEncode, MlsDecode, PartialEq, Eq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+pub enum LeafNodeSource {
+ /// Leaf created by `LeafNode::generate`, carrying its validity lifetime.
+ KeyPackage(Lifetime) = 1u8,
+ /// Leaf re-keyed by `LeafNode::update`.
+ Update = 2u8,
+ /// Leaf carrying a parent hash. `LeafNode::commit` does not set this variant itself.
+ Commit(ParentHash) = 3u8,
+}
+
+/// A member's leaf in the ratchet tree.
+///
+/// Field order matters: the derived `MlsEncode`/`MlsDecode` presumably follow
+/// declaration order (TODO confirm). The signature covers every other field plus the
+/// optional signing context (see `LeafNodeTBS`). `Debug` is implemented manually to
+/// format the signature via `pretty_bytes`.
+#[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[non_exhaustive]
+pub struct LeafNode {
+ /// HPKE public key of the leaf; replaced by `generate`, `update` and `commit`.
+ pub public_key: HpkePublicKey,
+ pub signing_identity: SigningIdentity,
+ pub capabilities: Capabilities,
+ pub leaf_node_source: LeafNodeSource,
+ pub extensions: ExtensionList,
+ /// Signature over the `LeafNodeTBS` encoding, written by `Signable::write_signature`.
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub signature: Vec<u8>,
+}
+
+impl Debug for LeafNode {
+    /// Debug-formats every field of the leaf node. The signature is rendered through
+    /// `mls_rs_core::debug::pretty_bytes` instead of the default `Vec<u8>` output.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let pretty_signature = mls_rs_core::debug::pretty_bytes(&self.signature);
+
+        let mut builder = f.debug_struct("LeafNode");
+
+        builder
+            .field("public_key", &self.public_key)
+            .field("signing_identity", &self.signing_identity)
+            .field("capabilities", &self.capabilities)
+            .field("leaf_node_source", &self.leaf_node_source)
+            .field("extensions", &self.extensions);
+
+        builder.field("signature", &pretty_signature).finish()
+    }
+}
+
+/// Capabilities and extensions installed into a leaf by `LeafNode::generate`,
+/// `LeafNode::update` and `LeafNode::commit`.
+#[derive(Clone, Debug)]
+pub struct ConfigProperties {
+ pub capabilities: Capabilities,
+ pub extensions: ExtensionList,
+}
+
+impl LeafNode {
+    /// Creates a brand-new key-package leaf node.
+    ///
+    /// The leaf is built as follows:
+    /// - a fresh HPKE key pair is generated with `cipher_suite_provider`;
+    /// - the capabilities and extensions from `properties` are installed;
+    /// - the source is set to `LeafNodeSource::KeyPackage(lifetime)`;
+    /// - the node is greased and then signed with `signer` using the default (empty)
+    ///   signing context.
+    ///
+    /// Returns the signed node and the HPKE secret key for its `public_key`.
+    ///
+    /// # Errors
+    /// `MlsError::CryptoProviderError` if key generation fails; otherwise any error
+    /// returned by greasing or signing.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn generate<CSP>(
+        cipher_suite_provider: &CSP,
+        properties: ConfigProperties,
+        signing_identity: SigningIdentity,
+        signer: &SignatureSecretKey,
+        lifetime: Lifetime,
+    ) -> Result<(Self, HpkeSecretKey), MlsError>
+    where
+        CSP: CipherSuiteProvider,
+    {
+        let (hpke_secret, hpke_public) = cipher_suite_provider
+            .kem_generate()
+            .await
+            .map_err(|err| MlsError::CryptoProviderError(err.into_any_error()))?;
+
+        let ConfigProperties {
+            capabilities,
+            extensions,
+        } = properties;
+
+        let mut node = Self {
+            public_key: hpke_public,
+            signing_identity,
+            capabilities,
+            leaf_node_source: LeafNodeSource::KeyPackage(lifetime),
+            extensions,
+            signature: Vec::new(),
+        };
+
+        node.grease(cipher_suite_provider)?;
+
+        // Key-package leaves are not bound to any group or leaf position.
+        let key_package_context = LeafNodeSigningContext::default();
+
+        node.sign(cipher_suite_provider, signer, &key_package_context)
+            .await?;
+
+        Ok((node, hpke_secret))
+    }
+
+    /// Re-keys this leaf and marks it as `LeafNodeSource::Update`.
+    ///
+    /// The leaf is changed as follows:
+    /// - a new HPKE key pair is generated;
+    /// - the capabilities and extensions from `new_properties` are installed;
+    /// - the node is greased;
+    /// - the signing identity is replaced when one is given;
+    /// - the leaf is re-signed, bound to `(group_id, leaf_index)`.
+    ///
+    /// Returns the new HPKE secret key.
+    ///
+    /// # Errors
+    /// `MlsError::CryptoProviderError` if key generation fails; otherwise any error
+    /// returned by greasing or signing.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn update<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite_provider: &P,
+        group_id: &[u8],
+        leaf_index: u32,
+        new_properties: ConfigProperties,
+        signing_identity: Option<SigningIdentity>,
+        signer: &SignatureSecretKey,
+    ) -> Result<HpkeSecretKey, MlsError> {
+        let (hpke_secret, hpke_public) = cipher_suite_provider
+            .kem_generate()
+            .await
+            .map_err(|err| MlsError::CryptoProviderError(err.into_any_error()))?;
+
+        let ConfigProperties {
+            capabilities,
+            extensions,
+        } = new_properties;
+
+        self.public_key = hpke_public;
+        self.capabilities = capabilities;
+        self.extensions = extensions;
+        self.leaf_node_source = LeafNodeSource::Update;
+
+        self.grease(cipher_suite_provider)?;
+
+        if let Some(replacement_identity) = signing_identity {
+            self.signing_identity = replacement_identity;
+        }
+
+        let group_context = LeafNodeSigningContext::from((group_id, leaf_index));
+
+        self.sign(cipher_suite_provider, signer, &group_context)
+            .await?;
+
+        Ok(hpke_secret)
+    }
+
+    /// Re-keys this leaf for a commit's UpdatePath; `TreeKem::encap` calls it on the
+    /// committer's own leaf.
+    ///
+    /// It installs a new HPKE key pair and the given capabilities/extensions, and
+    /// optionally a new signing identity. It then re-signs bound to
+    /// `(group_id, leaf_index)`. Unlike `update`, it leaves `leaf_node_source` untouched
+    /// and does not grease the node.
+    /// NOTE(review): the missing `grease` call differs from `generate`/`update`. Confirm
+    /// this is intentional.
+    ///
+    /// Returns the new HPKE secret key.
+    ///
+    /// # Errors
+    /// `MlsError::CryptoProviderError` if key generation fails; otherwise any error
+    /// returned by signing.
+    #[allow(clippy::too_many_arguments)]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn commit<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite_provider: &P,
+        group_id: &[u8],
+        leaf_index: u32,
+        new_properties: ConfigProperties,
+        new_signing_identity: Option<SigningIdentity>,
+        signer: &SignatureSecretKey,
+    ) -> Result<HpkeSecretKey, MlsError> {
+        let (hpke_secret, hpke_public) = cipher_suite_provider
+            .kem_generate()
+            .await
+            .map_err(|err| MlsError::CryptoProviderError(err.into_any_error()))?;
+
+        let ConfigProperties {
+            capabilities,
+            extensions,
+        } = new_properties;
+
+        self.public_key = hpke_public;
+        self.capabilities = capabilities;
+        self.extensions = extensions;
+
+        if let Some(replacement_identity) = new_signing_identity {
+            self.signing_identity = replacement_identity;
+        }
+
+        let group_context = LeafNodeSigningContext::from((group_id, leaf_index));
+
+        self.sign(cipher_suite_provider, signer, &group_context)
+            .await?;
+
+        Ok(hpke_secret)
+    }
+}
+
+/// Borrowed view of the content a `LeafNode` signature covers ("to be signed").
+///
+/// `group_id` and `leaf_index` come from the `LeafNodeSigningContext`. They are
+/// serialized only when present, as implemented by the manual `MlsSize`/`MlsEncode`
+/// impls for this type.
+#[derive(Debug)]
+struct LeafNodeTBS<'a> {
+ public_key: &'a HpkePublicKey,
+ signing_identity: &'a SigningIdentity,
+ capabilities: &'a Capabilities,
+ leaf_node_source: &'a LeafNodeSource,
+ extensions: &'a ExtensionList,
+ group_id: Option<&'a [u8]>,
+ leaf_index: Option<u32>,
+}
+
+impl<'a> MlsSize for LeafNodeTBS<'a> {
+    /// Number of bytes the `MlsEncode` implementation writes for this value.
+    ///
+    /// Absent `group_id` / `leaf_index` contribute zero bytes, because the encoder
+    /// writes them only when they are `Some`.
+    fn mls_encoded_len(&self) -> usize {
+        let leaf_content_len = self.public_key.mls_encoded_len()
+            + self.signing_identity.mls_encoded_len()
+            + self.capabilities.mls_encoded_len()
+            + self.leaf_node_source.mls_encoded_len()
+            + self.extensions.mls_encoded_len();
+
+        let group_id_len = self
+            .group_id
+            .as_ref()
+            .map_or(0, mls_rs_codec::byte_vec::mls_encoded_len);
+
+        let leaf_index_len = match self.leaf_index {
+            Some(index) => index.mls_encoded_len(),
+            None => 0,
+        };
+
+        leaf_content_len + group_id_len + leaf_index_len
+    }
+}
+
+impl<'a> MlsEncode for LeafNodeTBS<'a> {
+    /// Appends the to-be-signed leaf content to `writer`.
+    ///
+    /// The leaf fields are always written, in declaration order. `group_id` (via
+    /// `mls_rs_codec::byte_vec`) and `leaf_index` follow only when present. In this file
+    /// they are set for `update`/`commit` signatures and absent for key-package leaves.
+    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+        self.public_key.mls_encode(writer)?;
+        self.signing_identity.mls_encode(writer)?;
+        self.capabilities.mls_encode(writer)?;
+        self.leaf_node_source.mls_encode(writer)?;
+        self.extensions.mls_encode(writer)?;
+
+        if let Some(group_id) = &self.group_id {
+            mls_rs_codec::byte_vec::mls_encode(group_id, writer)?;
+        }
+
+        match self.leaf_index {
+            Some(leaf_index) => leaf_index.mls_encode(writer),
+            None => Ok(()),
+        }
+    }
+}
+
+/// Extra data bound into a leaf node signature.
+///
+/// `LeafNode::generate` signs with `Default` (both `None`). `update` and `commit` bind
+/// the signature to a group id and leaf index via `From<(&[u8], u32)>`.
+#[derive(Clone, Debug, Default)]
+pub(crate) struct LeafNodeSigningContext<'a> {
+ pub group_id: Option<&'a [u8]>,
+ pub leaf_index: Option<u32>,
+}
+
+impl<'a> From<(&'a [u8], u32)> for LeafNodeSigningContext<'a> {
+    /// Builds a signing context bound to `(group_id, leaf_index)`.
+    fn from(value: (&'a [u8], u32)) -> Self {
+        let (group_id, leaf_index) = value;
+
+        LeafNodeSigningContext {
+            leaf_index: Some(leaf_index),
+            group_id: Some(group_id),
+        }
+    }
+}
+
+impl<'a> Signable<'a> for LeafNode {
+    const SIGN_LABEL: &'static str = "LeafNodeTBS";
+
+    type SigningContext = LeafNodeSigningContext<'a>;
+
+    /// Current signature bytes of the leaf.
+    fn signature(&self) -> &[u8] {
+        self.signature.as_slice()
+    }
+
+    /// Encodes every leaf field except the signature, followed by the optional
+    /// group id and leaf index from `context`.
+    fn signable_content(
+        &self,
+        context: &Self::SigningContext,
+    ) -> Result<Vec<u8>, mls_rs_codec::Error> {
+        let to_be_signed = LeafNodeTBS {
+            public_key: &self.public_key,
+            signing_identity: &self.signing_identity,
+            capabilities: &self.capabilities,
+            leaf_node_source: &self.leaf_node_source,
+            extensions: &self.extensions,
+            group_id: context.group_id,
+            leaf_index: context.leaf_index,
+        };
+
+        to_be_signed.mls_encode_to_vec()
+    }
+
+    /// Stores a freshly computed signature on the leaf.
+    fn write_signature(&mut self, signature: Vec<u8>) {
+        self.signature = signature;
+    }
+}
+
+/// Test helpers for building leaf nodes, capabilities and extensions.
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use alloc::vec;
+    use mls_rs_core::identity::{BasicCredential, CredentialType};
+
+    use crate::{
+        cipher_suite::CipherSuite,
+        crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
+        extension::ApplicationIdExt,
+        identity::test_utils::{get_test_signing_identity, BasicWithCustomProvider},
+    };
+
+    use super::*;
+
+    /// Generates a signed key-package leaf for `signing_identity` with a one-year lifetime.
+    ///
+    /// `None` capabilities fall back to `get_test_capabilities()`. `None` extensions fall
+    /// back to `ExtensionList::default()`. Panics if generation fails.
+    // `allow(unused)`: not every test configuration calls this helper.
+    #[allow(unused)]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn get_test_node(
+        cipher_suite: CipherSuite,
+        signing_identity: SigningIdentity,
+        secret: &SignatureSecretKey,
+        capabilities: Option<Capabilities>,
+        extensions: Option<ExtensionList>,
+    ) -> (LeafNode, HpkeSecretKey) {
+        let capabilities = capabilities.unwrap_or_else(get_test_capabilities);
+        let extensions = extensions.unwrap_or_default();
+        let lifetime = Lifetime::years(1).unwrap();
+
+        get_test_node_with_lifetime(
+            cipher_suite,
+            signing_identity,
+            secret,
+            capabilities,
+            extensions,
+            lifetime,
+        )
+        .await
+    }
+
+    /// Generates a signed key-package leaf with explicit capabilities, extensions and
+    /// lifetime. Panics if generation fails.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn get_test_node_with_lifetime(
+        cipher_suite: CipherSuite,
+        signing_identity: SigningIdentity,
+        secret: &SignatureSecretKey,
+        capabilities: Capabilities,
+        extensions: ExtensionList,
+        lifetime: Lifetime,
+    ) -> (LeafNode, HpkeSecretKey) {
+        let properties = ConfigProperties {
+            capabilities,
+            extensions,
+        };
+
+        let provider = test_cipher_suite_provider(cipher_suite);
+
+        LeafNode::generate(&provider, properties, signing_identity, secret, lifetime)
+            .await
+            .unwrap()
+    }
+
+    /// Like `get_basic_test_node_sig_key`, but returns only the leaf node.
+    #[allow(unused)]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn get_basic_test_node(cipher_suite: CipherSuite, id: &str) -> LeafNode {
+        let (leaf_node, _, _) = get_basic_test_node_sig_key(cipher_suite, id).await;
+
+        leaf_node
+    }
+
+    /// `ConfigProperties` with the test capabilities and default extensions.
+    #[allow(unused)]
+    pub fn default_properties() -> ConfigProperties {
+        let capabilities = get_test_capabilities();
+        let extensions = ExtensionList::default();
+
+        ConfigProperties {
+            capabilities,
+            extensions,
+        }
+    }
+
+    /// Creates a signing identity from `id` and a one-year key-package leaf with the given
+    /// capabilities. Returns the leaf, its HPKE secret key and the signature secret key.
+    /// Panics if generation fails.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn get_basic_test_node_capabilities(
+        cipher_suite: CipherSuite,
+        id: &str,
+        capabilities: Capabilities,
+    ) -> (LeafNode, HpkeSecretKey, SignatureSecretKey) {
+        let (signing_identity, signature_key) =
+            get_test_signing_identity(cipher_suite, id.as_bytes()).await;
+
+        let properties = ConfigProperties {
+            capabilities,
+            extensions: ExtensionList::default(),
+        };
+
+        let provider = test_cipher_suite_provider(cipher_suite);
+        let lifetime = Lifetime::years(1).unwrap();
+
+        let (leaf, hpke_secret_key) = LeafNode::generate(
+            &provider,
+            properties,
+            signing_identity,
+            &signature_key,
+            lifetime,
+        )
+        .await
+        .unwrap();
+
+        (leaf, hpke_secret_key, signature_key)
+    }
+
+    /// `get_basic_test_node_capabilities` using `get_test_capabilities()`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn get_basic_test_node_sig_key(
+        cipher_suite: CipherSuite,
+        id: &str,
+    ) -> (LeafNode, HpkeSecretKey, SignatureSecretKey) {
+        let capabilities = get_test_capabilities();
+
+        get_basic_test_node_capabilities(cipher_suite, id, capabilities).await
+    }
+
+    /// Extension list containing a single `ApplicationIdExt` with identifier `b"identifier"`.
+    #[allow(unused)]
+    pub fn get_test_extensions() -> ExtensionList {
+        let application_id = ApplicationIdExt {
+            identifier: b"identifier".to_vec(),
+        };
+
+        let mut extensions = ExtensionList::new();
+        extensions.set_from(application_id).unwrap();
+
+        extensions
+    }
+
+    /// Capabilities advertising the basic and custom test credential types and every
+    /// cipher suite supported by the test crypto provider; other fields use `Default`.
+    pub fn get_test_capabilities() -> Capabilities {
+        let credentials = vec![
+            BasicCredential::credential_type(),
+            CredentialType::from(BasicWithCustomProvider::CUSTOM_CREDENTIAL_TYPE),
+        ];
+
+        Capabilities {
+            credentials,
+            cipher_suites: TestCryptoProvider::all_supported_cipher_suites(),
+            ..Default::default()
+        }
+    }
+
+    /// MLS encoding of the leaf's credential, used as a client identity in tests.
+    #[allow(unused)]
+    pub fn get_test_client_identity(leaf: &LeafNode) -> Vec<u8> {
+        let credential = &leaf.signing_identity.credential;
+
+        credential.mls_encode_to_vec().unwrap()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::test_utils::*;
+ use super::*;
+
+ use crate::client::test_utils::TEST_CIPHER_SUITE;
+ use crate::crypto::test_utils::test_cipher_suite_provider;
+ use crate::crypto::test_utils::TestCryptoProvider;
+ use crate::group::test_utils::random_bytes;
+ use crate::identity::test_utils::get_test_signing_identity;
+ use assert_matches::assert_matches;
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_node_generation() {
+ let capabilities = get_test_capabilities();
+ let extensions = get_test_extensions();
+ let lifetime = Lifetime::years(1).unwrap();
+
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let (signing_identity, secret) = get_test_signing_identity(cipher_suite, b"foo").await;
+
+ let (leaf_node, secret_key) = get_test_node_with_lifetime(
+ cipher_suite,
+ signing_identity.clone(),
+ &secret,
+ capabilities.clone(),
+ extensions.clone(),
+ lifetime.clone(),
+ )
+ .await;
+
+ assert_eq!(leaf_node.ungreased_capabilities(), capabilities);
+ assert_eq!(leaf_node.ungreased_extensions(), extensions);
+ assert_eq!(leaf_node.signing_identity, signing_identity);
+
+ assert_matches!(
+ &leaf_node.leaf_node_source,
+ LeafNodeSource::KeyPackage(lt) if lt == &lifetime,
+ "Expected {:?}, got {:?}", LeafNodeSource::KeyPackage(lifetime),
+ leaf_node.leaf_node_source
+ );
+
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ // Verify that the hpke key pair generated will work
+ let test_data = random_bytes(32);
+
+ let sealed = provider
+ .hpke_seal(&leaf_node.public_key, &[], None, &test_data)
+ .await
+ .unwrap();
+
+ let opened = provider
+ .hpke_open(&sealed, &secret_key, &leaf_node.public_key, &[], None)
+ .await
+ .unwrap();
+
+ assert_eq!(opened, test_data);
+
+ leaf_node
+ .verify(
+ &test_cipher_suite_provider(cipher_suite),
+ &signing_identity.signature_key,
+ &LeafNodeSigningContext::default(),
+ )
+ .await
+ .unwrap();
+ }
+ }
+
+    /// Repeatedly generating a leaf for the same identity must never reuse
+    /// HPKE key material.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_node_generation_randomness() {
+        let (identity, signer) = get_test_signing_identity(TEST_CIPHER_SUITE, b"foo").await;
+
+        let (baseline_leaf, baseline_secret) =
+            get_test_node(TEST_CIPHER_SUITE, identity.clone(), &signer, None, None).await;
+
+        for _attempt in 0..100 {
+            let (candidate_leaf, candidate_secret) =
+                get_test_node(TEST_CIPHER_SUITE, identity.clone(), &signer, None, None).await;
+
+            assert_ne!(baseline_secret, candidate_secret);
+            assert_ne!(baseline_leaf.public_key, candidate_leaf.public_key);
+        }
+    }
+
+    /// Updating with default properties and no new identity rotates the HPKE
+    /// key, marks the leaf as an update, and leaves the other metadata intact.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_node_update_no_meta_changes() {
+        let group_id: &[u8] = b"group";
+
+        for suite in TestCryptoProvider::all_supported_cipher_suites() {
+            let cs_provider = test_cipher_suite_provider(suite);
+            let (identity, signer) = get_test_signing_identity(suite, b"foo").await;
+
+            let (mut leaf, initial_secret) =
+                get_test_node(suite, identity.clone(), &signer, None, None).await;
+
+            let before = leaf.clone();
+
+            let rotated_secret = leaf
+                .update(&cs_provider, group_id, 0, default_properties(), None, &signer)
+                .await
+                .unwrap();
+
+            // Fresh encryption key material...
+            assert_ne!(rotated_secret, initial_secret);
+            assert_ne!(before.public_key, leaf.public_key);
+
+            // ...but unchanged metadata.
+            assert_eq!(leaf.ungreased_capabilities(), before.ungreased_capabilities());
+            assert_eq!(leaf.ungreased_extensions(), before.ungreased_extensions());
+            assert_eq!(leaf.signing_identity, before.signing_identity);
+            assert_matches!(&leaf.leaf_node_source, LeafNodeSource::Update);
+
+            // The update is signed for (group_id, leaf index 0).
+            leaf.verify(&cs_provider, &identity.signature_key, &(group_id, 0).into())
+                .await
+                .unwrap();
+        }
+    }
+
+    /// Updating with new properties replaces the leaf's capabilities and
+    /// extensions.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_node_update_meta_changes() {
+        let (identity, signer) = get_test_signing_identity(TEST_CIPHER_SUITE, b"foo").await;
+
+        let (mut leaf, _) = get_test_node(TEST_CIPHER_SUITE, identity, &signer, None, None).await;
+
+        let replacement = ConfigProperties {
+            capabilities: get_test_capabilities(),
+            extensions: get_test_extensions(),
+        };
+
+        let cs_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        leaf.update(&cs_provider, b"group", 0, replacement.clone(), None, &signer)
+            .await
+            .unwrap();
+
+        assert_eq!(leaf.ungreased_capabilities(), replacement.capabilities);
+        assert_eq!(leaf.ungreased_extensions(), replacement.extensions);
+    }
+
+    /// Committing with default properties and no new identity rotates the HPKE
+    /// key while keeping capabilities, extensions and identity unchanged.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_node_commit_no_meta_changes() {
+        let group_id: &[u8] = b"group";
+
+        for suite in TestCryptoProvider::all_supported_cipher_suites() {
+            let cs_provider = test_cipher_suite_provider(suite);
+            let (identity, signer) = get_test_signing_identity(suite, b"foo").await;
+
+            let (mut leaf, initial_secret) =
+                get_test_node(suite, identity.clone(), &signer, None, None).await;
+
+            let before = leaf.clone();
+
+            let rotated_secret = leaf
+                .commit(&cs_provider, group_id, 0, default_properties(), None, &signer)
+                .await
+                .unwrap();
+
+            // Fresh encryption key material...
+            assert_ne!(rotated_secret, initial_secret);
+            assert_ne!(before.public_key, leaf.public_key);
+
+            // ...but unchanged metadata.
+            assert_eq!(leaf.ungreased_capabilities(), before.ungreased_capabilities());
+            assert_eq!(leaf.ungreased_extensions(), before.ungreased_extensions());
+            assert_eq!(leaf.signing_identity, before.signing_identity);
+
+            // The commit leaf is signed for (group_id, leaf index 0).
+            leaf.verify(&cs_provider, &identity.signature_key, &(group_id, 0).into())
+                .await
+                .unwrap();
+        }
+    }
+
+    /// Committing with new properties and a new signing identity installs all
+    /// three on the leaf.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_node_commit_meta_changes() {
+        let (identity, signer) = get_test_signing_identity(TEST_CIPHER_SUITE, b"foo").await;
+
+        let (mut leaf, _) = get_test_node(TEST_CIPHER_SUITE, identity, &signer, None, None).await;
+
+        let replacement = ConfigProperties {
+            capabilities: get_test_capabilities(),
+            extensions: get_test_extensions(),
+        };
+
+        // A second call yields an identity with a fresh public key.
+        let (rotated_identity, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"foo").await;
+
+        let cs_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        leaf.commit(
+            &cs_provider,
+            b"group",
+            0,
+            replacement.clone(),
+            Some(rotated_identity.clone()),
+            &signer,
+        )
+        .await
+        .unwrap();
+
+        assert_eq!(leaf.capabilities, replacement.capabilities);
+        assert_eq!(leaf.extensions, replacement.extensions);
+        assert_eq!(leaf.signing_identity, rotated_identity);
+    }
+
+    /// A signature made for (group id, leaf index) must not verify when either
+    /// component of the signing context differs.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn context_is_signed() {
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let (identity, signer) = get_test_signing_identity(TEST_CIPHER_SUITE, b"foo").await;
+
+        let (mut leaf, _) =
+            get_test_node(TEST_CIPHER_SUITE, identity.clone(), &signer, None, None).await;
+
+        leaf.sign(&provider, &signer, &(b"foo".as_slice(), 0).into())
+            .await
+            .unwrap();
+
+        // Same group with another index, then same index in another group.
+        let mismatches = [(b"foo".as_slice(), 1u32), (b"bar".as_slice(), 0u32)];
+
+        for (group_id, leaf_index) in mismatches {
+            let res = leaf
+                .verify(&provider, &identity.signature_key, &(group_id, leaf_index).into())
+                .await;
+
+            assert_matches!(res, Err(MlsError::InvalidSignature));
+        }
+    }
+}
diff --git a/src/tree_kem/leaf_node_validator.rs b/src/tree_kem/leaf_node_validator.rs
new file mode 100644
index 0000000..17742ec
--- /dev/null
+++ b/src/tree_kem/leaf_node_validator.rs
@@ -0,0 +1,708 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::leaf_node::{LeafNode, LeafNodeSigningContext, LeafNodeSource};
+use crate::client::MlsError;
+use crate::CipherSuiteProvider;
+use crate::{signer::Signable, time::MlsTime};
+use mls_rs_core::{error::IntoAnyError, extension::ExtensionList, identity::IdentityProvider};
+
+use crate::extension::RequiredCapabilitiesExt;
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::extension::ExternalSendersExt;
+
+/// The situation a leaf node is validated in.
+///
+/// Every variant carries an optional time handed to the identity provider
+/// (and, for `Add`, used for the key package lifetime check). `Update` and
+/// `Commit` also carry the group id and leaf index the node was signed for.
+pub enum ValidationContext<'a> {
+    /// Leaf node that must come from a key package.
+    Add(Option<MlsTime>),
+    /// Leaf node from an update: `(group_id, leaf_index, time)`.
+    Update((&'a [u8], u32, Option<MlsTime>)),
+    /// Leaf node produced by a commit: `(group_id, leaf_index, time)`.
+    Commit((&'a [u8], u32, Option<MlsTime>)),
+}
+
+impl ValidationContext<'_> {
+    /// Context the leaf signature is verified against: the default context
+    /// for `Add`, otherwise `(group_id, leaf_index)`.
+    fn signing_context(&self) -> LeafNodeSigningContext {
+        match *self {
+            Self::Add(_) => LeafNodeSigningContext::default(),
+            Self::Update((group_id, leaf_index, _)) | Self::Commit((group_id, leaf_index, _)) => {
+                (group_id, leaf_index).into()
+            }
+        }
+    }
+
+    /// The optional time carried by whichever variant this is.
+    fn generation_time(&self) -> Option<MlsTime> {
+        match *self {
+            Self::Add(time) | Self::Update((_, _, time)) | Self::Commit((_, _, time)) => time,
+        }
+    }
+}
+
+/// Validates leaf nodes using a cipher suite provider, an identity provider
+/// and, optionally, the group context extensions of the target group.
+#[derive(Clone, Debug)]
+pub struct LeafNodeValidator<'a, C: IdentityProvider, CP: CipherSuiteProvider> {
+    cipher_suite_provider: &'a CP,
+    identity_provider: &'a C,
+    /// When `None`, the required-capabilities, group-extension-support and
+    /// external-sender checks have nothing to check against and pass.
+    group_context_extensions: Option<&'a ExtensionList>,
+}
+
+impl<'a, C: IdentityProvider, CP: CipherSuiteProvider> LeafNodeValidator<'a, C, CP> {
+    /// Creates a validator.
+    ///
+    /// When `group_context_extensions` is `None`, the checks that depend on the
+    /// group context (required capabilities, support for the group's
+    /// extensions, external sender credential types) pass trivially.
+    pub fn new(
+        cipher_suite_provider: &'a CP,
+        identity_provider: &'a C,
+        group_context_extensions: Option<&'a ExtensionList>,
+    ) -> Self {
+        Self {
+            cipher_suite_provider,
+            identity_provider,
+            group_context_extensions,
+        }
+    }
+
+    /// Checks that the leaf node's source matches the validation context.
+    ///
+    /// * `Add` requires a `KeyPackage` source; its lifetime is checked only
+    ///   when the context carries a time.
+    /// * `Update` requires an `Update` source.
+    /// * `Commit` requires a `Commit` source.
+    ///
+    /// # Errors
+    /// `InvalidLeafNodeSource` on a source/context mismatch, `InvalidLifetime`
+    /// if the key package is not valid at the supplied time.
+    fn check_context(
+        &self,
+        leaf_node: &LeafNode,
+        context: &ValidationContext,
+    ) -> Result<(), MlsError> {
+        match context {
+            ValidationContext::Add(time) => {
+                // A leaf node being added must come from a key package.
+                let LeafNodeSource::KeyPackage(lifetime) = &leaf_node.leaf_node_source else {
+                    return Err(MlsError::InvalidLeafNodeSource);
+                };
+
+                // The lifetime is only enforced when a reference time was supplied.
+                if let Some(current_time) = time {
+                    if !lifetime.within_lifetime(*current_time) {
+                        return Err(MlsError::InvalidLifetime);
+                    }
+                }
+            }
+            ValidationContext::Update(_) => {
+                if !matches!(leaf_node.leaf_node_source, LeafNodeSource::Update) {
+                    return Err(MlsError::InvalidLeafNodeSource);
+                }
+            }
+            ValidationContext::Commit(_) => {
+                if !matches!(leaf_node.leaf_node_source, LeafNodeSource::Commit(_)) {
+                    return Err(MlsError::InvalidLeafNodeSource);
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Re-validates a leaf node that is already part of a group.
+    ///
+    /// The context is derived from the leaf node's own source, and no time is
+    /// supplied, so key package lifetimes are not re-checked.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn revalidate(
+        &self,
+        leaf_node: &LeafNode,
+        group_id: &[u8],
+        leaf_index: u32,
+    ) -> Result<(), MlsError> {
+        let context = match leaf_node.leaf_node_source {
+            LeafNodeSource::KeyPackage(_) => ValidationContext::Add(None),
+            LeafNodeSource::Update => ValidationContext::Update((group_id, leaf_index, None)),
+            LeafNodeSource::Commit(_) => ValidationContext::Commit((group_id, leaf_index, None)),
+        };
+
+        self.check_if_valid(leaf_node, context).await
+    }
+
+    /// Checks that the leaf supports every extension, proposal and credential
+    /// type listed in the group's `RequiredCapabilitiesExt`, if there is one.
+    ///
+    /// # Errors
+    /// Propagates a failure to decode the required capabilities extension, or
+    /// returns `RequiredExtensionNotFound` / `RequiredProposalNotFound` /
+    /// `RequiredCredentialNotFound` for the first missing entry.
+    pub fn validate_required_capabilities(&self, leaf_node: &LeafNode) -> Result<(), MlsError> {
+        let Some(required_capabilities) = self
+            .group_context_extensions
+            .and_then(|exts| exts.get_as::<RequiredCapabilitiesExt>().transpose())
+            .transpose()?
+        else {
+            return Ok(());
+        };
+
+        for extension in &required_capabilities.extensions {
+            if !leaf_node.capabilities.extensions.contains(extension) {
+                return Err(MlsError::RequiredExtensionNotFound(*extension));
+            }
+        }
+
+        for proposal in &required_capabilities.proposals {
+            if !leaf_node.capabilities.proposals.contains(proposal) {
+                return Err(MlsError::RequiredProposalNotFound(*proposal));
+            }
+        }
+
+        for credential in &required_capabilities.credentials {
+            if !leaf_node.capabilities.credentials.contains(credential) {
+                return Err(MlsError::RequiredCredentialNotFound(*credential));
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Checks that the leaf supports the credential type of every sender in
+    /// the group's `ExternalSendersExt`, if there is one.
+    ///
+    /// # Errors
+    /// Propagates a failure to decode the extension, or returns
+    /// `RequiredCredentialNotFound` for the first unsupported credential type.
+    #[cfg(feature = "by_ref_proposal")]
+    pub fn validate_external_senders_ext_credentials(
+        &self,
+        leaf_node: &LeafNode,
+    ) -> Result<(), MlsError> {
+        let Some(ext) = self
+            .group_context_extensions
+            .and_then(|exts| exts.get_as::<ExternalSendersExt>().transpose())
+            .transpose()?
+        else {
+            return Ok(());
+        };
+
+        ext.allowed_senders.iter().try_for_each(|sender| {
+            let cred_type = sender.credential.credential_type();
+            leaf_node
+                .capabilities
+                .credentials
+                .contains(&cred_type)
+                .then_some(())
+                .ok_or(MlsError::RequiredCredentialNotFound(cred_type))
+        })
+    }
+
+    /// Runs the full set of leaf node checks for `context`, returning the
+    /// first failure. In order:
+    ///
+    /// 1. the leaf source matches the context (plus the lifetime for adds);
+    /// 2. the identity provider accepts the signing identity;
+    /// 3. the leaf signature verifies in the context's signing context;
+    /// 4. the group's required capabilities are met;
+    /// 5. every leaf extension is listed in the leaf's capabilities;
+    /// 6. every non-default group context extension is supported by the leaf;
+    /// 7. (`by_ref_proposal`) external sender credential types are supported.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn check_if_valid(
+        &self,
+        leaf_node: &LeafNode,
+        context: ValidationContext<'_>,
+    ) -> Result<(), MlsError> {
+        self.check_context(leaf_node, &context)?;
+
+        self.identity_provider
+            .validate_member(
+                &leaf_node.signing_identity,
+                context.generation_time(),
+                self.group_context_extensions,
+            )
+            .await
+            .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+
+        // The leaf must be signed by its own credential's key.
+        leaf_node
+            .verify(
+                self.cipher_suite_provider,
+                &leaf_node.signing_identity.signature_key,
+                &context.signing_context(),
+            )
+            .await?;
+
+        self.validate_required_capabilities(leaf_node)?;
+
+        // Every extension carried by the leaf must be advertised in its capabilities.
+        for ext in &*leaf_node.extensions {
+            if !leaf_node.capabilities.extensions.contains(&ext.extension_type) {
+                return Err(MlsError::ExtensionNotInCapabilities(ext.extension_type));
+            }
+        }
+
+        // Every non-default group context extension must be supported by the leaf.
+        if let Some(group_extensions) = self.group_context_extensions {
+            for ext in &**group_extensions {
+                let ext_type = ext.extension_type;
+
+                if !ext_type.is_default() && !leaf_node.capabilities.extensions.contains(&ext_type)
+                {
+                    return Err(MlsError::UnsupportedGroupExtension(ext_type));
+                }
+            }
+        }
+
+        #[cfg(feature = "by_ref_proposal")]
+        self.validate_external_senders_ext_credentials(leaf_node)?;
+
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use alloc::vec;
+    use assert_matches::assert_matches;
+    #[cfg(feature = "std")]
+    use core::time::Duration;
+    use mls_rs_core::crypto::CipherSuite;
+    use mls_rs_core::group::ProposalType;
+
+    use crate::client::test_utils::TEST_CIPHER_SUITE;
+    use crate::crypto::test_utils::{
+        test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
+    };
+    use crate::crypto::SignatureSecretKey;
+    use crate::extension::test_utils::TestExtension;
+    use crate::extension::MlsExtension;
+    use crate::group::test_utils::random_bytes;
+    use crate::identity::basic::{BasicCredential, BasicIdentityProvider};
+    use crate::identity::test_utils::get_test_signing_identity;
+    use crate::tree_kem::leaf_node::test_utils::*;
+    use crate::tree_kem::leaf_node_validator::test_utils::FailureIdentityProvider;
+    use crate::tree_kem::Capabilities;
+    use crate::ExtensionList;
+
+    /// Builds a key-package leaf node for `TEST_CIPHER_SUITE` and returns it
+    /// together with the signature secret key that signed it.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn get_test_add_node() -> (LeafNode, SignatureSecretKey) {
+        let (identity, signer) = get_test_signing_identity(TEST_CIPHER_SUITE, b"foo").await;
+
+        let (leaf, _) = get_test_node(TEST_CIPHER_SUITE, identity, &signer, None, None).await;
+
+        (leaf, signer)
+    }
+
+    /// Validates a fresh key-package leaf (in an `Add` context) against a
+    /// group whose only context extension is `required_capabilities`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn validate_with_required_capabilities(
+        required_capabilities: RequiredCapabilitiesExt,
+    ) -> Result<(), MlsError> {
+        let (leaf, _) = get_test_add_node().await;
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let group_context_extensions =
+            core::iter::once(required_capabilities.into_extension().unwrap()).collect();
+
+        let validator = LeafNodeValidator::new(
+            &provider,
+            &BasicIdentityProvider,
+            Some(&group_context_extensions),
+        );
+
+        validator
+            .check_if_valid(&leaf, ValidationContext::Add(None))
+            .await
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_basic_add_validation() {
+        let (leaf, _) = get_test_add_node().await;
+
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let validator = LeafNodeValidator::new(&provider, &BasicIdentityProvider, None);
+
+        let res = validator
+            .check_if_valid(&leaf, ValidationContext::Add(None))
+            .await;
+
+        assert_matches!(res, Ok(_));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_failed_validation() {
+        let (leaf, _) = get_test_add_node().await;
+
+        // The identity provider rejects every member.
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let validator = LeafNodeValidator::new(&provider, &FailureIdentityProvider, None);
+
+        let res = validator
+            .check_if_valid(&leaf, ValidationContext::Add(None))
+            .await;
+
+        assert_matches!(res, Err(MlsError::IdentityProviderError(_)));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_basic_update_validation() {
+        let group_id = b"group_id";
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let (mut leaf, signer) = get_test_add_node().await;
+
+        // TODO remove identity from input
+        leaf.update(&provider, group_id, 0, default_properties(), None, &signer)
+            .await
+            .unwrap();
+
+        let validator = LeafNodeValidator::new(&provider, &BasicIdentityProvider, None);
+
+        let res = validator
+            .check_if_valid(&leaf, ValidationContext::Update((group_id, 0, None)))
+            .await;
+
+        assert_matches!(res, Ok(_));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_basic_commit_validation() {
+        let group_id = b"group_id";
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let (mut leaf, signer) = get_test_add_node().await;
+        leaf.leaf_node_source = LeafNodeSource::Commit(hex!("f00d").into());
+
+        leaf.commit(&provider, group_id, 0, default_properties(), None, &signer)
+            .await
+            .unwrap();
+
+        let validator = LeafNodeValidator::new(&provider, &BasicIdentityProvider, None);
+
+        let res = validator
+            .check_if_valid(&leaf, ValidationContext::Commit((group_id, 0, None)))
+            .await;
+
+        assert_matches!(res, Ok(_));
+    }
+
+    /// Each leaf source is rejected in every context but its own; the same
+    /// leaf is moved through key-package, update and commit sources in turn.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_incorrect_context() {
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let validator = LeafNodeValidator::new(&provider, &BasicIdentityProvider, None);
+
+        let (mut leaf, signer) = get_test_add_node().await;
+
+        // Key-package leaf.
+        for wrong_context in [
+            ValidationContext::Update((b"foo", 0, None)),
+            ValidationContext::Commit((b"foo", 0, None)),
+        ] {
+            let res = validator.check_if_valid(&leaf, wrong_context).await;
+            assert_matches!(res, Err(MlsError::InvalidLeafNodeSource));
+        }
+
+        // Update leaf.
+        leaf.update(&provider, b"foo", 0, default_properties(), None, &signer)
+            .await
+            .unwrap();
+
+        for wrong_context in [
+            ValidationContext::Add(None),
+            ValidationContext::Commit((b"foo", 0, None)),
+        ] {
+            let res = validator.check_if_valid(&leaf, wrong_context).await;
+            assert_matches!(res, Err(MlsError::InvalidLeafNodeSource));
+        }
+
+        // Commit leaf.
+        leaf.leaf_node_source = LeafNodeSource::Commit(hex!("f00d").into());
+
+        leaf.commit(&provider, b"foo", 0, default_properties(), None, &signer)
+            .await
+            .unwrap();
+
+        for wrong_context in [
+            ValidationContext::Add(None),
+            ValidationContext::Update((b"foo", 0, None)),
+        ] {
+            let res = validator.check_if_valid(&leaf, wrong_context).await;
+            assert_matches!(res, Err(MlsError::InvalidLeafNodeSource));
+        }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_bad_signature() {
+        for suite in TestCryptoProvider::all_supported_cipher_suites() {
+            let provider = test_cipher_suite_provider(suite);
+            let (identity, signer) = get_test_signing_identity(suite, b"foo").await;
+
+            let (mut leaf, _) = get_test_node(suite, identity, &signer, None, None).await;
+
+            // Replace the signature with random bytes of the same length.
+            let forged_signature = random_bytes(leaf.signature.len());
+            leaf.signature = forged_signature;
+
+            let validator = LeafNodeValidator::new(&provider, &BasicIdentityProvider, None);
+
+            let res = validator
+                .check_if_valid(&leaf, ValidationContext::Add(None))
+                .await;
+
+            assert_matches!(res, Err(MlsError::InvalidSignature));
+        }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_capabilities_mismatch() {
+        let (identity, signer) = get_test_signing_identity(TEST_CIPHER_SUITE, b"foo").await;
+
+        // The leaf carries a TestExtension that its capabilities do not list.
+        let mut extensions = ExtensionList::new();
+        extensions.set_from(TestExtension::from(0)).unwrap();
+
+        let capabilities = Capabilities {
+            credentials: vec![BasicCredential::credential_type()],
+            ..Default::default()
+        };
+
+        let (leaf, _) = get_test_node(
+            TEST_CIPHER_SUITE,
+            identity,
+            &signer,
+            Some(capabilities),
+            Some(extensions),
+        )
+        .await;
+
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let validator = LeafNodeValidator::new(&provider, &BasicIdentityProvider, None);
+
+        let res = validator
+            .check_if_valid(&leaf, ValidationContext::Add(None))
+            .await;
+
+        // 42 is presumably TestExtension's extension type.
+        assert_matches!(res, Err(MlsError::ExtensionNotInCapabilities(ext)) if ext == 42.into());
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_cipher_suite_mismatch() {
+        let other_suites = CipherSuite::all().filter(|cs| cs != &TEST_CIPHER_SUITE);
+
+        for suite in other_suites {
+            // Skip suites the test crypto provider does not implement.
+            let Some(provider) = try_test_cipher_suite_provider(*suite) else {
+                continue;
+            };
+
+            let (leaf, _) = get_test_add_node().await;
+            let validator = LeafNodeValidator::new(&provider, &BasicIdentityProvider, None);
+
+            let res = validator
+                .check_if_valid(&leaf, ValidationContext::Add(None))
+                .await;
+
+            assert_matches!(res, Err(MlsError::InvalidSignature));
+        }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_required_extension() {
+        let required = RequiredCapabilitiesExt {
+            extensions: vec![43.into()],
+            ..Default::default()
+        };
+
+        let res = validate_with_required_capabilities(required).await;
+
+        assert_matches!(
+            res,
+            Err(MlsError::RequiredExtensionNotFound(v)) if v == 43.into()
+        );
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_required_proposal() {
+        let required = RequiredCapabilitiesExt {
+            proposals: vec![42.into()],
+            ..Default::default()
+        };
+
+        let res = validate_with_required_capabilities(required).await;
+
+        assert_matches!(
+            res,
+            Err(MlsError::RequiredProposalNotFound(p)) if p == ProposalType::new(42)
+        );
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_required_credential() {
+        let required = RequiredCapabilitiesExt {
+            credentials: vec![0.into()],
+            ..Default::default()
+        };
+
+        let res = validate_with_required_capabilities(required).await;
+
+        assert_matches!(
+            res,
+            Err(MlsError::RequiredCredentialNotFound(ext)) if ext == 0.into()
+        );
+    }
+
+    #[cfg(feature = "std")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_add_lifetime() {
+        let (leaf, _) = get_test_add_node().await;
+
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let validator = LeafNodeValidator::new(&provider, &BasicIdentityProvider, None);
+
+        let now = MlsTime::now();
+
+        // 366 days from now; presumably past the test key package's lifetime.
+        let past_expiry = MlsTime::from_duration_since_epoch(Duration::from_secs(
+            now.seconds_since_epoch() + (86400 * 366),
+        ));
+
+        let res = validator
+            .check_if_valid(&leaf, ValidationContext::Add(Some(now)))
+            .await;
+
+        assert_matches!(res, Ok(()));
+
+        let res = validator
+            .check_if_valid(&leaf, ValidationContext::Add(Some(past_expiry)))
+            .await;
+
+        assert_matches!(res, Err(MlsError::InvalidLifetime));
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use alloc::vec;
+    use alloc::{boxed::Box, vec::Vec};
+    use mls_rs_codec::MlsEncode;
+    use mls_rs_core::{
+        error::IntoAnyError,
+        extension::ExtensionList,
+        identity::{BasicCredential, IdentityProvider},
+    };
+
+    use crate::{identity::SigningIdentity, time::MlsTime};
+
+    /// Identity provider whose `validate_member`, `validate_external_sender`
+    /// and `valid_successor` always fail with [`TestFailureError`], for
+    /// exercising identity-provider error paths.
+    #[derive(Clone, Debug, Default)]
+    pub struct FailureIdentityProvider;
+
+    #[cfg(feature = "by_ref_proposal")]
+    impl FailureIdentityProvider {
+        /// Same as `FailureIdentityProvider::default()`.
+        pub fn new() -> Self {
+            Self::default()
+        }
+    }
+
+    /// Error produced by every failing [`FailureIdentityProvider`] method.
+    #[derive(Debug)]
+    #[cfg_attr(feature = "std", derive(thiserror::Error))]
+    #[cfg_attr(feature = "std", error("test error"))]
+    pub struct TestFailureError;
+
+    impl IntoAnyError for TestFailureError {
+        #[cfg(feature = "std")]
+        fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
+            Ok(Box::new(self))
+        }
+    }
+
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+    impl IdentityProvider for FailureIdentityProvider {
+        type Error = TestFailureError;
+
+        /// Always rejects the member.
+        async fn validate_member(
+            &self,
+            _identity: &SigningIdentity,
+            _time: Option<MlsTime>,
+            _group_extensions: Option<&ExtensionList>,
+        ) -> Result<(), Self::Error> {
+            Err(TestFailureError)
+        }
+
+        /// Always rejects the external sender.
+        async fn validate_external_sender(
+            &self,
+            _identity: &SigningIdentity,
+            _time: Option<MlsTime>,
+            _group_extensions: Option<&ExtensionList>,
+        ) -> Result<(), Self::Error> {
+            Err(TestFailureError)
+        }
+
+        /// Unlike the validation methods, succeeds: returns the MLS encoding of
+        /// the credential.
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        async fn identity(
+            &self,
+            signing_id: &SigningIdentity,
+            _extensions: &ExtensionList,
+        ) -> Result<Vec<u8>, Self::Error> {
+            let encoded = signing_id.credential.mls_encode_to_vec().unwrap();
+            Ok(encoded)
+        }
+
+        /// Always fails instead of answering.
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        async fn valid_successor(
+            &self,
+            _old_identity: &SigningIdentity,
+            _new_identity: &SigningIdentity,
+            _extensions: &ExtensionList,
+        ) -> Result<bool, Self::Error> {
+            Err(TestFailureError)
+        }
+
+        /// Only the basic credential type is advertised.
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        fn supported_types(&self) -> Vec<crate::identity::CredentialType> {
+            vec![BasicCredential::credential_type()]
+        }
+    }
+}
diff --git a/src/tree_kem/lifetime.rs b/src/tree_kem/lifetime.rs
new file mode 100644
index 0000000..d508ad6
--- /dev/null
+++ b/src/tree_kem/lifetime.rs
@@ -0,0 +1,119 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{client::MlsError, time::MlsTime};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+/// Validity window of a key package, in seconds since the Unix epoch.
+///
+/// Both bounds are inclusive (see `within_lifetime`).
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, Default)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[non_exhaustive]
+pub struct Lifetime {
+    /// First second (inclusive) at which the key package is valid.
+    pub not_before: u64,
+    /// Last second (inclusive) at which the key package is valid.
+    pub not_after: u64,
+}
+
+impl Lifetime {
+    /// Creates a lifetime from explicit bounds (seconds since the Unix epoch).
+    pub fn new(not_before: u64, not_after: u64) -> Lifetime {
+        Lifetime {
+            not_before,
+            not_after,
+        }
+    }
+
+    /// Creates a lifetime that starts one hour before "now" and ends `s`
+    /// seconds after "now".
+    ///
+    /// The one-hour backdate tolerates clock differences between machines.
+    /// Without the `std` feature there is no clock, and "now" is fixed at
+    /// 3600 seconds after the epoch, giving `not_before == 0`.
+    ///
+    /// # Errors
+    /// `MlsError::TimeOverflow` if `now + s` does not fit in a `u64`.
+    pub fn seconds(s: u64) -> Result<Self, MlsError> {
+        #[cfg(feature = "std")]
+        let now = MlsTime::now().seconds_since_epoch();
+        // There is no clock on no_std, this is here just so that we can run tests.
+        #[cfg(not(feature = "std"))]
+        let now = 3600u64;
+
+        let not_after = now.checked_add(s).ok_or(MlsError::TimeOverflow)?;
+
+        Ok(Lifetime {
+            // Subtract 1 hour to address time difference between machines.
+            // Saturate so a clock reading within the first hour of the epoch
+            // cannot underflow.
+            not_before: now.saturating_sub(3600),
+            not_after,
+        })
+    }
+
+    /// Like [`Lifetime::seconds`], with the duration given in days.
+    pub fn days(d: u32) -> Result<Self, MlsError> {
+        // Widen before multiplying: `d * 86_400` overflows a u32 once d > 49_710,
+        // while any u32 day count times 86_400 fits in a u64.
+        Self::seconds(u64::from(d) * 86_400)
+    }
+
+    /// Like [`Lifetime::days`], counting each year as 365 days.
+    pub fn years(y: u8) -> Result<Self, MlsError> {
+        Self::days(365 * u32::from(y))
+    }
+
+    /// Whether `time` falls within `[not_before, not_after]` (inclusive).
+    pub(crate) fn within_lifetime(&self, time: MlsTime) -> bool {
+        (self.not_before..=self.not_after).contains(&time.seconds_since_epoch())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use core::time::Duration;
+
+    use super::*;
+    use assert_matches::assert_matches;
+
+    /// An `MlsTime` that is `secs` seconds after the epoch.
+    fn at(secs: u64) -> MlsTime {
+        MlsTime::from_duration_since_epoch(Duration::from_secs(secs))
+    }
+
+    /// Width of a lifetime window: `not_after - not_before`.
+    fn span(lifetime: &Lifetime) -> u64 {
+        lifetime.not_after - lifetime.not_before
+    }
+
+    #[test]
+    fn test_lifetime_overflow() {
+        assert_matches!(Lifetime::seconds(u64::MAX), Err(MlsError::TimeOverflow))
+    }
+
+    #[test]
+    fn test_seconds() {
+        // Requested duration plus the one-hour backdate.
+        assert_eq!(span(&Lifetime::seconds(10).unwrap()), 3610);
+    }
+
+    #[test]
+    fn test_days() {
+        const DAYS: u32 = 2;
+
+        assert_eq!(
+            span(&Lifetime::days(DAYS).unwrap()),
+            86_400 * u64::from(DAYS) + 3600
+        );
+    }
+
+    #[test]
+    fn test_years() {
+        const YEARS: u8 = 2;
+
+        assert_eq!(
+            span(&Lifetime::years(YEARS).unwrap()),
+            86_400 * 365 * u64::from(YEARS) + 3600
+        );
+    }
+
+    #[test]
+    fn test_bounds() {
+        let window = Lifetime {
+            not_before: 5,
+            not_after: 10,
+        };
+
+        // Just outside on either side.
+        for secs in [4, 11] {
+            assert!(!window.within_lifetime(at(secs)));
+        }
+
+        // Both endpoints are inclusive, as is everything in between.
+        for secs in [5, 10, 6] {
+            assert!(window.within_lifetime(at(secs)));
+        }
+    }
+}
diff --git a/src/tree_kem/math.rs b/src/tree_kem/math.rs
new file mode 100644
index 0000000..51f82ed
--- /dev/null
+++ b/src/tree_kem/math.rs
@@ -0,0 +1,383 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::{fmt::Debug, hash::Hash};
+use mls_rs_codec::{MlsDecode, MlsEncode};
+
+use super::node::LeafIndex;
+
+/// Arithmetic on node indices of the array-encoded ratchet tree.
+pub trait TreeIndex:
+    Send + Sync + Eq + Clone + Debug + Default + MlsEncode + MlsDecode + Hash + Ord
+{
+    /// Index of the root node, when `self` is interpreted as a leaf count.
+    fn root(&self) -> Self;
+
+    /// Left child of a parent node; the result is unspecified (or panics) for a leaf.
+    fn left_unchecked(&self) -> Self;
+    /// Right child of a parent node; the result is unspecified (or panics) for a leaf.
+    fn right_unchecked(&self) -> Self;
+
+    /// Parent and sibling of `self` in a tree with `leaf_count` leaves, or `None` at the root.
+    fn parent_sibling(&self, leaf_count: &Self) -> Option<ParentSibling<Self>>;
+    /// Whether `self` is a leaf node.
+    fn is_leaf(&self) -> bool;
+    /// Whether `self` is a valid node index in the tree whose root is `root`.
+    fn is_in_tree(&self, root: &Self) -> bool;
+
+    #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+    fn zero() -> Self;
+
+    /// Left child, or `None` if `self` is a leaf.
+    #[cfg(any(feature = "secret_tree_access", feature = "private_message", test))]
+    fn left(&self) -> Option<Self> {
+        if self.is_leaf() {
+            None
+        } else {
+            Some(self.left_unchecked())
+        }
+    }
+
+    /// Right child, or `None` if `self` is a leaf.
+    #[cfg(any(feature = "secret_tree_access", feature = "private_message", test))]
+    fn right(&self) -> Option<Self> {
+        if self.is_leaf() {
+            None
+        } else {
+            Some(self.right_unchecked())
+        }
+    }
+
+    /// Walks from `self` up to the root of a tree with `leaf_count` leaves. For each step it
+    /// records the parent (`path`) together with the sibling of the node just left (`copath`).
+    /// Returns an empty vector for the root itself or for an index outside the tree.
+    fn direct_copath(&self, leaf_count: &Self) -> Vec<CopathNode<Self>> {
+        if !self.is_in_tree(&leaf_count.root()) {
+            return Vec::new();
+        }
+
+        let mut steps = Vec::new();
+        let mut current = self.clone();
+
+        while let Some(ParentSibling { parent, sibling }) = current.parent_sibling(leaf_count) {
+            steps.push(CopathNode::new(parent.clone(), sibling));
+            current = parent;
+        }
+
+        steps
+    }
+}
+
+/// One step of a node's direct path: `path` is an ancestor on the direct path, and `copath` is
+/// the child of `path` that is *not* on the way up from the starting node (the sibling of the
+/// previous step).
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub struct CopathNode<T> {
+    pub path: T,
+    pub copath: T,
+}
+
+impl<T> CopathNode<T> {
+    /// Pairs a direct-path node with its copath node.
+    ///
+    /// The constructor only moves its arguments, so it places no trait bounds on `T`. The derived
+    /// `Clone`/`PartialEq`/`Eq`/`Debug` impls still apply whenever `T` provides them.
+    pub fn new(path: T, copath: T) -> CopathNode<T> {
+        CopathNode { path, copath }
+    }
+}
+
+/// A node's parent together with the node's sibling (the parent's other child).
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub struct ParentSibling<T> {
+    pub parent: T,
+    pub sibling: T,
+}
+
+impl<T> ParentSibling<T> {
+    /// Pairs a parent with a sibling.
+    ///
+    /// The constructor only moves its arguments, so it places no trait bounds on `T`. The derived
+    /// `Clone`/`PartialEq`/`Eq`/`Debug` impls still apply whenever `T` provides them.
+    pub fn new(parent: T, sibling: T) -> ParentSibling<T> {
+        ParentSibling { parent, sibling }
+    }
+}
+
+/// Implements [`TreeIndex`] for a primitive unsigned integer type using the array encoding of
+/// the tree, in which leaves have even indices and parent nodes odd indices.
+macro_rules! impl_tree_stdint {
+    ($t:ty) => {
+        impl TreeIndex for $t {
+            /// Called on a leaf count; returns `leaf_count - 1`, which is the root index when the
+            /// leaf count is a power of two (NOTE(review): assumed by callers -- confirm).
+            /// Underflows for a leaf count of 0.
+            fn root(&self) -> $t {
+                *self - 1
+            }
+
+            /// Returns the left child by clearing the bit just below the run of trailing ones.
+            /// Panics if `self` is even (a leaf) in debug builds, because
+            /// `trailing_ones() - 1` underflows; release builds return a meaningless value.
+            fn left_unchecked(&self) -> Self {
+                *self ^ (0x01 << (self.trailing_ones() - 1))
+            }
+
+            /// Returns the right child by flipping the two bits around the end of the run of
+            /// trailing ones. Panics if `self` is even (a leaf) in debug builds, because
+            /// `trailing_ones() - 1` underflows; release builds return a meaningless value.
+            fn right_unchecked(&self) -> Self {
+                *self ^ (0x03 << (self.trailing_ones() - 1))
+            }
+
+            fn parent_sibling(&self, leaf_count: &Self) -> Option<ParentSibling<Self>> {
+                if self == &leaf_count.root() {
+                    return None;
+                }
+
+                // The level of a node is its number of trailing ones. The parent one level up
+                // has bit `lvl` set and bit `lvl + 1` cleared.
+                let lvl = self.trailing_ones();
+                let p = (self & !(1 << (lvl + 1))) | (1 << lvl);
+
+                // A node smaller than its parent is the left child, so its sibling is the right one.
+                let s = if *self < p {
+                    p.right_unchecked()
+                } else {
+                    p.left_unchecked()
+                };
+
+                Some(ParentSibling::new(p, s))
+            }
+
+            fn is_leaf(&self) -> bool {
+                self & 1 == 0
+            }
+
+            /// A tree whose root is `root` has nodes `0..=2 * root`.
+            fn is_in_tree(&self, root: &Self) -> bool {
+                *self <= 2 * root
+            }
+
+            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+            fn zero() -> Self {
+                0
+            }
+        }
+    };
+}
+
+impl_tree_stdint!(u32);
+
+// `u64` indices are only needed by tests.
+#[cfg(test)]
+impl_tree_stdint!(u64);
+
+/// Returns how many times both leaf indices must be halved before they coincide, i.e. the level
+/// of their lowest common ancestor above the leaves (0 when `x == y`).
+pub fn leaf_lca_level(x: u32, y: u32) -> u32 {
+    // Shifting both right until they agree removes every bit up to and including the highest bit
+    // in which they differ, so the shift count is the bit length of `x ^ y`.
+    u32::BITS - (x ^ y).leading_zeros()
+}
+
+/// Returns the half-open range `[first, end)` of leaf indices covered by the subtree rooted at
+/// node `x`.
+pub fn subtree(x: u32) -> (LeafIndex, LeafIndex) {
+    // A node at level `k` (k trailing ones) has 2^k leaves below it.
+    let leaf_span = 1 << x.trailing_ones();
+    let first_leaf = (x + 1 - leaf_span) >> 1;
+    let end_leaf = ((x + leaf_span) >> 1) + 1;
+
+    (LeafIndex(first_leaf), LeafIndex(end_leaf))
+}
+
+/// Iterator over the node indices of a tree in breadth-first, top-down order. It yields the root
+/// first, then every level from left to right, and ends with the leaves (even indices).
+///
+/// Only the trailing zeros of `num_leaves` are used to derive the depth, so `num_leaves` is
+/// assumed to be zero or a power of two (NOTE(review): confirm callers uphold this).
+pub struct BfsIterTopDown {
+    /// Shift applied to `ctr`; equals the current level's height plus one.
+    level: usize,
+    /// Low bits shared by every node on the current level (`2^height - 1`).
+    mask: usize,
+    /// Number of nodes on the current level.
+    level_end: usize,
+    /// Position of the next node within the current level.
+    ctr: usize,
+}
+
+impl BfsIterTopDown {
+    /// Creates an iterator over every node of a tree with `num_leaves` leaves.
+    /// An empty tree (`num_leaves == 0`) yields nothing.
+    pub fn new(num_leaves: usize) -> Self {
+        if num_leaves == 0 {
+            // `0.trailing_zeros()` is the full bit width, so `1 << depth` below would overflow
+            // (panicking in debug builds, producing garbage in release builds). Start in the
+            // finished state instead: on the last level with `ctr == level_end`, the first call
+            // to `next` returns `None`.
+            return Self {
+                level: 1,
+                mask: 0,
+                level_end: 0,
+                ctr: 0,
+            };
+        }
+
+        let depth = num_leaves.trailing_zeros() as usize;
+
+        Self {
+            level: depth + 1,
+            mask: (1 << depth) - 1,
+            level_end: 1,
+            ctr: 0,
+        }
+    }
+}
+
+impl Iterator for BfsIterTopDown {
+    type Item = usize;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.ctr == self.level_end {
+            if self.level == 1 {
+                // The leaf level has been exhausted.
+                return None;
+            }
+
+            // Descend one level: twice as many nodes, each one level lower.
+            self.level_end <<= 1;
+            self.level -= 1;
+            self.ctr = 0;
+            self.mask >>= 1;
+        }
+
+        let res = Some((self.ctr << self.level) | self.mask);
+        self.ctr += 1;
+        res
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use itertools::Itertools;
+    use serde::{Deserialize, Serialize};
+
+    #[cfg(target_arch = "wasm32")]
+    use wasm_bindgen_test::wasm_bindgen_test as test;
+
+    /// One tree-math test vector: per-node expectations for a tree with `n_leaves` leaves.
+    /// `left`/`right`/`parent`/`sibling` are indexed by node and hold `None` where the relation
+    /// does not exist (children of a leaf, parent/sibling of the root).
+    #[derive(Serialize, Deserialize)]
+    struct TestCase {
+        n_leaves: u32,
+        n_nodes: u32,
+        root: u32,
+        left: Vec<Option<u32>>,
+        right: Vec<Option<u32>>,
+        parent: Vec<Option<u32>>,
+        sibling: Vec<Option<u32>>,
+    }
+
+    /// Number of nodes in a tree with `n` leaves: `2n - 1`, or 0 for an empty tree.
+    pub fn node_width(n: u32) -> u32 {
+        if n == 0 {
+            0
+        } else {
+            2 * (n - 1) + 1
+        }
+    }
+
+    // An 8-leaf tree has 15 nodes: the root (7) comes first and the leaves (even indices) last.
+    #[test]
+    fn test_bfs_iterator() {
+        let expected = [7, 3, 11, 1, 5, 9, 13, 0, 2, 4, 6, 8, 10, 12, 14];
+        let bfs = BfsIterTopDown::new(8);
+        assert_eq!(bfs.collect::<Vec<_>>(), expected);
+    }
+
+    /// Produces tree-math vectors for trees of 1, 2, 4, ..., 128 leaves from this module's own
+    /// implementation.
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn generate_tree_math_test_cases() -> Vec<TestCase> {
+        let mut test_cases = Vec::new();
+
+        for log_n_leaves in 0..8 {
+            let n_leaves = 1 << log_n_leaves;
+            let n_nodes = node_width(n_leaves);
+            let left = (0..n_nodes).map(|x| x.left()).collect::<Vec<_>>();
+            let right = (0..n_nodes).map(|x| x.right()).collect::<Vec<_>>();
+
+            // Split each node's optional (parent, sibling) pair into two parallel vectors.
+            let (parent, sibling) = (0..n_nodes)
+                .map(|x| {
+                    x.parent_sibling(&n_leaves)
+                        .map(|ps| (ps.parent, ps.sibling))
+                        .unzip()
+                })
+                .unzip();
+
+            test_cases.push(TestCase {
+                n_leaves,
+                n_nodes,
+                root: n_leaves.root(),
+                left,
+                right,
+                parent,
+                sibling,
+            })
+        }
+
+        test_cases
+    }
+
+    /// Loads the `tree_math` vectors. Presumably the macro reads a stored JSON file and falls
+    /// back to `generate_tree_math_test_cases()` when it is absent -- confirm against the
+    /// `load_test_case_json!` definition.
+    fn load_test_cases() -> Vec<TestCase> {
+        load_test_case_json!(tree_math, generate_tree_math_test_cases())
+    }
+
+    // Checks width, root, children, parent and sibling of every node against the vectors.
+    #[test]
+    fn test_tree_math() {
+        let test_cases = load_test_cases();
+
+        for case in test_cases {
+            assert_eq!(node_width(case.n_leaves), case.n_nodes);
+            assert_eq!(case.n_leaves.root(), case.root);
+
+            for x in 0..case.n_nodes {
+                assert_eq!(x.left(), case.left[x as usize]);
+                assert_eq!(x.right(), case.right[x as usize]);
+
+                let (p, s) = x
+                    .parent_sibling(&case.n_leaves)
+                    .map(|ps| (ps.parent, ps.sibling))
+                    .unzip();
+
+                assert_eq!(p, case.parent[x as usize]);
+                assert_eq!(s, case.sibling[x as usize]);
+            }
+        }
+    }
+
+    // Direct paths of nodes 0..=30 in a 16-leaf tree (root 0x0f). The root's path is empty,
+    // and every other path ends at the root.
+    #[test]
+    fn test_direct_path() {
+        let expected: Vec<Vec<u32>> = [
+            [0x01, 0x03, 0x07, 0x0f].to_vec(),
+            [0x03, 0x07, 0x0f].to_vec(),
+            [0x01, 0x03, 0x07, 0x0f].to_vec(),
+            [0x07, 0x0f].to_vec(),
+            [0x05, 0x03, 0x07, 0x0f].to_vec(),
+            [0x03, 0x07, 0x0f].to_vec(),
+            [0x05, 0x03, 0x07, 0x0f].to_vec(),
+            [0x0f].to_vec(),
+            [0x09, 0x0b, 0x07, 0x0f].to_vec(),
+            [0x0b, 0x07, 0x0f].to_vec(),
+            [0x09, 0x0b, 0x07, 0x0f].to_vec(),
+            [0x07, 0x0f].to_vec(),
+            [0x0d, 0x0b, 0x07, 0x0f].to_vec(),
+            [0x0b, 0x07, 0x0f].to_vec(),
+            [0x0d, 0x0b, 0x07, 0x0f].to_vec(),
+            [].to_vec(),
+            [0x11, 0x13, 0x17, 0x0f].to_vec(),
+            [0x13, 0x17, 0x0f].to_vec(),
+            [0x11, 0x13, 0x17, 0x0f].to_vec(),
+            [0x17, 0x0f].to_vec(),
+            [0x15, 0x13, 0x17, 0x0f].to_vec(),
+            [0x13, 0x17, 0x0f].to_vec(),
+            [0x15, 0x13, 0x17, 0x0f].to_vec(),
+            [0x0f].to_vec(),
+            [0x19, 0x1b, 0x17, 0x0f].to_vec(),
+            [0x1b, 0x17, 0x0f].to_vec(),
+            [0x19, 0x1b, 0x17, 0x0f].to_vec(),
+            [0x17, 0x0f].to_vec(),
+            [0x1d, 0x1b, 0x17, 0x0f].to_vec(),
+            [0x1b, 0x17, 0x0f].to_vec(),
+            [0x1d, 0x1b, 0x17, 0x0f].to_vec(),
+        ]
+        .to_vec();
+
+        for (i, item) in expected.iter().enumerate() {
+            let path = (i as u32)
+                .direct_copath(&16)
+                .into_iter()
+                .map(|cp| cp.path)
+                .collect_vec();
+
+            assert_eq!(item, &path)
+        }
+    }
+
+    // Copaths (the sibling at each step of the direct path) of nodes 0..=30 in a 16-leaf tree.
+    #[test]
+    fn test_copath_path() {
+        let expected: Vec<Vec<u32>> = [
+            [0x02, 0x05, 0x0b, 0x17].to_vec(),
+            [0x05, 0x0b, 0x17].to_vec(),
+            [0x00, 0x05, 0x0b, 0x17].to_vec(),
+            [0x0b, 0x17].to_vec(),
+            [0x06, 0x01, 0x0b, 0x17].to_vec(),
+            [0x01, 0x0b, 0x17].to_vec(),
+            [0x04, 0x01, 0x0b, 0x17].to_vec(),
+            [0x17].to_vec(),
+            [0x0a, 0x0d, 0x03, 0x17].to_vec(),
+            [0x0d, 0x03, 0x17].to_vec(),
+            [0x08, 0x0d, 0x03, 0x17].to_vec(),
+            [0x03, 0x17].to_vec(),
+            [0x0e, 0x09, 0x03, 0x17].to_vec(),
+            [0x09, 0x03, 0x17].to_vec(),
+            [0x0c, 0x09, 0x03, 0x17].to_vec(),
+            [].to_vec(),
+            [0x12, 0x15, 0x1b, 0x07].to_vec(),
+            [0x15, 0x1b, 0x07].to_vec(),
+            [0x10, 0x15, 0x1b, 0x07].to_vec(),
+            [0x1b, 0x07].to_vec(),
+            [0x16, 0x11, 0x1b, 0x07].to_vec(),
+            [0x11, 0x1b, 0x07].to_vec(),
+            [0x14, 0x11, 0x1b, 0x07].to_vec(),
+            [0x07].to_vec(),
+            [0x1a, 0x1d, 0x13, 0x07].to_vec(),
+            [0x1d, 0x13, 0x07].to_vec(),
+            [0x18, 0x1d, 0x13, 0x07].to_vec(),
+            [0x13, 0x07].to_vec(),
+            [0x1e, 0x19, 0x13, 0x07].to_vec(),
+            [0x19, 0x13, 0x07].to_vec(),
+            [0x1c, 0x19, 0x13, 0x07].to_vec(),
+        ]
+        .to_vec();
+
+        for (i, item) in expected.iter().enumerate() {
+            let copath = (i as u32)
+                .direct_copath(&16)
+                .into_iter()
+                .map(|cp| cp.copath)
+                .collect_vec();
+
+            assert_eq!(item, &copath)
+        }
+    }
+}
diff --git a/src/tree_kem/mod.rs b/src/tree_kem/mod.rs
new file mode 100644
index 0000000..430ee16
--- /dev/null
+++ b/src/tree_kem/mod.rs
@@ -0,0 +1,1490 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+#[cfg(feature = "std")]
+use core::fmt::Display;
+use itertools::Itertools;
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::extension::ExtensionList;
+
+use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider};
+
+#[cfg(feature = "tree_index")]
+use mls_rs_core::identity::SigningIdentity;
+
+use math as tree_math;
+use node::{LeafIndex, NodeIndex, NodeVec};
+
+use self::leaf_node::LeafNode;
+
+use crate::client::MlsError;
+use crate::crypto::{self, CipherSuiteProvider, HpkeSecretKey};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::group::proposal::{AddProposal, UpdateProposal};
+
+#[cfg(any(test, feature = "by_ref_proposal"))]
+use crate::group::proposal::RemoveProposal;
+
+use crate::group::proposal_filter::ProposalBundle;
+use crate::tree_kem::tree_hash::TreeHashes;
+
+mod capabilities;
+pub(crate) mod hpke_encryption;
+mod lifetime;
+pub(crate) mod math;
+pub mod node;
+pub mod parent_hash;
+pub mod path_secret;
+mod private;
+mod tree_hash;
+pub mod tree_validator;
+pub mod update_path;
+
+pub use capabilities::*;
+pub use lifetime::*;
+pub(crate) use private::*;
+pub use update_path::*;
+
+use tree_index::*;
+
+pub mod kem;
+pub mod leaf_node;
+pub mod leaf_node_validator;
+mod tree_index;
+
+#[cfg(feature = "std")]
+pub(crate) mod tree_utils;
+
+#[cfg(test)]
+mod interop_test_vectors;
+
+#[cfg(feature = "custom_proposal")]
+use crate::group::proposal::ProposalType;
+
+/// Public state of the ratchet tree: the node vector plus cached tree hashes and, with the
+/// `tree_index` feature, a lookup index over the leaves (used e.g. for identity lookups).
+///
+/// NOTE(review): the field set/order feeds the derived `MlsEncode`/`MlsDecode` impls, so it
+/// presumably must not change without considering stored/encoded state.
+#[derive(Clone, Debug, MlsEncode, MlsDecode, MlsSize, Default)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct TreeKemPublic {
+    // Not serialized by serde; rebuilt on demand by `initialize_index_if_necessary`.
+    #[cfg(feature = "tree_index")]
+    #[cfg_attr(feature = "serde", serde(skip))]
+    index: TreeIndex,
+    pub(crate) nodes: NodeVec,
+    tree_hashes: TreeHashes,
+}
+
+impl PartialEq for TreeKemPublic {
+    /// Two trees are equal when their node vectors are equal. The cached tree hashes and the
+    /// optional index are not compared.
+    fn eq(&self, other: &Self) -> bool {
+        let (ours, theirs) = (&self.nodes, &other.nodes);
+        ours == theirs
+    }
+}
+
+impl TreeKemPublic {
+    /// Creates an empty tree with default (empty) nodes, tree hashes and index.
+    pub fn new() -> TreeKemPublic {
+        Default::default()
+    }
+
+    /// Builds a tree around an already-decoded node vector.
+    ///
+    /// Tree hashes start at their default value. With the `tree_index` feature, the leaf index
+    /// is populated from every non-empty leaf, querying `identity_provider` for each one.
+    /// Without the feature, `identity_provider` and `extensions` are unused.
+    ///
+    /// # Errors
+    /// Propagates any error returned while indexing a leaf.
+    #[cfg_attr(not(feature = "tree_index"), allow(unused))]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn import_node_data<IP>(
+        nodes: NodeVec,
+        identity_provider: &IP,
+        extensions: &ExtensionList,
+    ) -> Result<TreeKemPublic, MlsError>
+    where
+        IP: IdentityProvider,
+    {
+        let mut tree = TreeKemPublic {
+            nodes,
+            ..Default::default()
+        };
+
+        #[cfg(feature = "tree_index")]
+        tree.initialize_index_if_necessary(identity_provider, extensions)
+            .await?;
+
+        Ok(tree)
+    }
+
+    /// Rebuilds the leaf index from the current nodes if it is not initialized yet (e.g. after
+    /// serde deserialization, which skips the `index` field); otherwise does nothing.
+    ///
+    /// # Errors
+    /// Propagates any error returned while indexing a leaf.
+    #[cfg(feature = "tree_index")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn initialize_index_if_necessary<IP: IdentityProvider>(
+        &mut self,
+        identity_provider: &IP,
+        extensions: &ExtensionList,
+    ) -> Result<(), MlsError> {
+        if !self.index.is_initialized() {
+            self.index = TreeIndex::new();
+
+            for (leaf_index, leaf) in self.nodes.non_empty_leaves() {
+                index_insert(
+                    &mut self.index,
+                    leaf,
+                    leaf_index,
+                    identity_provider,
+                    extensions,
+                )
+                .await?;
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Looks up, via the index, the leaf whose identity equals `identity`.
+    #[cfg(feature = "tree_index")]
+    pub(crate) fn get_leaf_node_with_identity(&self, identity: &[u8]) -> Option<LeafIndex> {
+        self.index.get_leaf_index_with_identity(identity)
+    }
+
+    /// Looks up the first non-empty leaf whose identity (as reported by `id_provider`) equals
+    /// `identity`, by scanning every leaf.
+    ///
+    /// # Errors
+    /// `MlsError::IdentityProviderError` if the provider fails for any scanned leaf.
+    #[cfg(not(feature = "tree_index"))]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn get_leaf_node_with_identity<I: IdentityProvider>(
+        &self,
+        identity: &[u8],
+        id_provider: &I,
+        extensions: &ExtensionList,
+    ) -> Result<Option<LeafIndex>, MlsError> {
+        for (i, leaf) in self.nodes.non_empty_leaves() {
+            let leaf_id = id_provider
+                .identity(&leaf.signing_identity, extensions)
+                .await
+                .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+
+            if leaf_id == identity {
+                return Ok(Some(i));
+            }
+        }
+
+        Ok(None)
+    }
+
+    /// Creates a one-member public tree containing `leaf_node`, plus the matching private tree
+    /// that holds `secret_key` for leaf 0.
+    ///
+    /// # Errors
+    /// Propagates errors from adding the leaf (e.g. indexing failures).
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn derive<I: IdentityProvider>(
+        leaf_node: LeafNode,
+        secret_key: HpkeSecretKey,
+        identity_provider: &I,
+        extensions: &ExtensionList,
+    ) -> Result<(TreeKemPublic, TreeKemPrivate), MlsError> {
+        let mut public_tree = TreeKemPublic::new();
+
+        public_tree
+            .add_leaf(leaf_node, identity_provider, extensions, None)
+            .await?;
+
+        let private_tree = TreeKemPrivate::new_self_leaf(LeafIndex(0), secret_key);
+
+        Ok((public_tree, private_tree))
+    }
+
+    /// Number of leaf slots in the tree, blank or not.
+    pub fn total_leaf_count(&self) -> u32 {
+        self.nodes.total_leaf_count()
+    }
+
+    /// Number of non-blank leaves.
+    #[cfg(any(test, all(feature = "custom_proposal", feature = "tree_index")))]
+    pub fn occupied_leaf_count(&self) -> u32 {
+        self.nodes.occupied_leaf_count()
+    }
+
+    /// Borrows the leaf at `index`; errors if the slot does not hold a leaf node.
+    pub fn get_leaf_node(&self, index: LeafIndex) -> Result<&LeafNode, MlsError> {
+        self.nodes.borrow_as_leaf(index)
+    }
+
+    /// Index of the first non-empty leaf equal to `leaf_node`, if any.
+    pub fn find_leaf_node(&self, leaf_node: &LeafNode) -> Option<LeafIndex> {
+        self.nodes.non_empty_leaves().find_map(
+            |(index, node)| {
+                if node == leaf_node {
+                    Some(index)
+                } else {
+                    None
+                }
+            },
+        )
+    }
+
+    /// Whether every occupied leaf lists `proposal_type` among its supported proposals.
+    #[cfg(feature = "custom_proposal")]
+    pub fn can_support_proposal(&self, proposal_type: ProposalType) -> bool {
+        #[cfg(feature = "tree_index")]
+        return self.index.count_supporting_proposal(proposal_type) == self.occupied_leaf_count();
+
+        #[cfg(not(feature = "tree_index"))]
+        self.nodes
+            .non_empty_leaves()
+            .all(|(_, l)| l.capabilities.proposals.contains(&proposal_type))
+    }
+
+    /// Test helper: adds each leaf and then recomputes the tree hashes for the added leaves.
+    /// Each search for an empty slot starts at the previously added index.
+    #[cfg(test)]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn add_leaves<I: IdentityProvider, CP: CipherSuiteProvider>(
+        &mut self,
+        leaf_nodes: Vec<LeafNode>,
+        id_provider: &I,
+        cipher_suite_provider: &CP,
+    ) -> Result<Vec<LeafIndex>, MlsError> {
+        let mut start = LeafIndex(0);
+        let mut added = vec![];
+
+        for leaf in leaf_nodes.into_iter() {
+            start = self
+                .add_leaf(leaf, id_provider, &Default::default(), Some(start))
+                .await?;
+            added.push(start);
+        }
+
+        self.update_hashes(&added, cipher_suite_provider).await?;
+
+        Ok(added)
+    }
+
+    /// Iterates over `(index, leaf)` for every non-blank leaf.
+    pub fn non_empty_leaves(&self) -> impl Iterator<Item = (LeafIndex, &LeafNode)> + '_ {
+        self.nodes.non_empty_leaves()
+    }
+
+    /// Iterates over every leaf slot, yielding `None` for blank slots.
+    #[cfg(feature = "prior_epoch")]
+    pub fn leaves(&self) -> impl Iterator<Item = Option<&LeafNode>> + '_ {
+        self.nodes.leaves()
+    }
+
+    /// Sets the public key of the parent node at `index` and clears its unmerged leaves. The slot
+    /// is obtained through `borrow_or_fill_node_as_parent`, which presumably creates a parent
+    /// node when the slot is blank -- confirm.
+    pub(crate) fn update_node(
+        &mut self,
+        pub_key: crypto::HpkePublicKey,
+        index: NodeIndex,
+    ) -> Result<(), MlsError> {
+        self.nodes
+            .borrow_or_fill_node_as_parent(index, &pub_key)
+            .map(|p| {
+                p.public_key = pub_key;
+                p.unmerged_leaves = vec![];
+            })
+    }
+
+    /// Applies a validated update path sent by `sender`, which involves:
+    /// - replacing the sender's leaf;
+    /// - overwriting the public keys of the direct-path nodes present in the path;
+    /// - re-indexing the new leaf;
+    /// - updating the parent hashes via `update_parent_hashes(sender, true, ..)`.
+    ///
+    /// # Errors
+    /// Fails if `sender` is not a leaf, or if node updates, identity lookups, indexing or
+    /// parent-hash processing fail.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn apply_update_path<IP, CP>(
+        &mut self,
+        sender: LeafIndex,
+        update_path: &ValidatedUpdatePath,
+        extensions: &ExtensionList,
+        identity_provider: IP,
+        cipher_suite_provider: &CP,
+    ) -> Result<(), MlsError>
+    where
+        IP: IdentityProvider,
+        CP: CipherSuiteProvider,
+    {
+        // Install the new leaf node
+        let existing_leaf = self.nodes.borrow_as_leaf_mut(sender)?;
+
+        #[cfg(feature = "tree_index")]
+        let original_leaf_node = existing_leaf.clone();
+
+        #[cfg(feature = "tree_index")]
+        let original_identity = identity_provider
+            .identity(&original_leaf_node.signing_identity, extensions)
+            .await
+            .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+
+        *existing_leaf = update_path.leaf_node.clone();
+
+        // Update the rest of the nodes on the direct path
+        let path = self.nodes.direct_copath(sender);
+
+        for (node, pn) in update_path.nodes.iter().zip(path) {
+            node.as_ref()
+                .map(|n| self.update_node(n.public_key.clone(), pn.path))
+                .transpose()?;
+        }
+
+        #[cfg(feature = "tree_index")]
+        self.index.remove(&original_leaf_node, &original_identity);
+
+        index_insert(
+            #[cfg(feature = "tree_index")]
+            &mut self.index,
+            #[cfg(not(feature = "tree_index"))]
+            &self.nodes,
+            &update_path.leaf_node,
+            sender,
+            &identity_provider,
+            extensions,
+        )
+        .await?;
+
+        // Verify the parent hash of the new sender leaf node and update the parent hash values
+        // in the local tree
+        self.update_parent_hashes(sender, true, cipher_suite_provider)
+            .await?;
+
+        Ok(())
+    }
+
+    /// Records `index` as an unmerged leaf in every parent node on its direct path. Blank or
+    /// non-parent slots are skipped, so this currently always returns `Ok`.
+    fn update_unmerged(&mut self, index: LeafIndex) -> Result<(), MlsError> {
+        // For a given leaf index, find parent nodes and add the leaf to the unmerged leaf
+        self.nodes.direct_copath(index).into_iter().for_each(|i| {
+            if let Ok(p) = self.nodes.borrow_as_parent_mut(i.path) {
+                p.unmerged_leaves.push(index)
+            }
+        });
+
+        Ok(())
+    }
+
+    /// Applies the removes, then the updates, then the adds in `proposal_bundle` to the tree, and
+    /// recomputes the affected tree hashes. Returns the leaf indices assigned to added members.
+    ///
+    /// With `filter == false`, any failure is returned as an error. With `filter == true`,
+    /// failing proposals are removed from `proposal_bundle` instead, with these exceptions:
+    /// by-value removes and adds, and by-value updates of a blank leaf, still return an error.
+    /// If an update can be neither applied nor individually rolled back, every update is
+    /// reverted. The tree and index then return to their state right after the removes, and
+    /// all update proposals (and their senders) are dropped from the bundle.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn batch_edit<I, CP>(
+        &mut self,
+        proposal_bundle: &mut ProposalBundle,
+        extensions: &ExtensionList,
+        id_provider: &I,
+        cipher_suite_provider: &CP,
+        filter: bool,
+    ) -> Result<Vec<LeafIndex>, MlsError>
+    where
+        I: IdentityProvider,
+        CP: CipherSuiteProvider,
+    {
+        // Apply removes (they commute with updates because they don't touch the same leaves)
+        for i in (0..proposal_bundle.remove_proposals().len()).rev() {
+            let index = proposal_bundle.remove_proposals()[i].proposal.to_remove;
+            let res = self.nodes.blank_leaf_node(index);
+
+            if res.is_ok() {
+                // This shouldn't fail if `blank_leaf_node` succeeded.
+                self.nodes.blank_direct_path(index)?;
+            }
+
+            #[cfg(feature = "tree_index")]
+            if let Ok(old_leaf) = &res {
+                // If this fails, it's not because the proposal is bad.
+                let identity =
+                    identity(&old_leaf.signing_identity, id_provider, extensions).await?;
+
+                self.index.remove(old_leaf, &identity);
+            }
+
+            if proposal_bundle.remove_proposals()[i].is_by_value() || !filter {
+                res?;
+            } else if res.is_err() {
+                proposal_bundle.remove::<RemoveProposal>(i);
+            }
+        }
+
+        // Snapshot the index while it still contains the leaves targeted by updates, so that the
+        // "revert all" path below can restore it exactly. Taking the snapshot after the old
+        // leaves were removed would lose their index entries on revert.
+        #[cfg(feature = "tree_index")]
+        let index_before_updates = self.index.clone();
+
+        // Remove from the tree old leaves from updates
+        let mut partial_updates = vec![];
+        let senders = proposal_bundle.update_senders.iter().copied();
+
+        for (i, (p, index)) in proposal_bundle.updates.iter().zip(senders).enumerate() {
+            let new_leaf = p.proposal.leaf_node.clone();
+
+            match self.nodes.blank_leaf_node(index) {
+                Ok(old_leaf) => {
+                    #[cfg(feature = "tree_index")]
+                    let old_id =
+                        identity(&old_leaf.signing_identity, id_provider, extensions).await?;
+
+                    #[cfg(feature = "tree_index")]
+                    self.index.remove(&old_leaf, &old_id);
+
+                    partial_updates.push((index, old_leaf, new_leaf, i));
+                }
+                _ => {
+                    if !filter || !p.is_by_reference() {
+                        return Err(MlsError::UpdatingNonExistingMember);
+                    }
+                }
+            }
+        }
+
+        let mut removed_leaves = vec![];
+        let mut updated_indices = vec![];
+        let mut bad_indices = vec![];
+
+        // Apply updates one by one. If there's an update which we can neither apply nor revert,
+        // we revert all updates. An explicit iterator is kept so that the revert path can also
+        // restore the updates that have not been processed yet.
+        let mut pending = partial_updates.into_iter();
+
+        while let Some((index, old_leaf, new_leaf, i)) = pending.next() {
+            #[cfg(feature = "tree_index")]
+            let res =
+                index_insert(&mut self.index, &new_leaf, index, id_provider, extensions).await;
+
+            #[cfg(not(feature = "tree_index"))]
+            let res = index_insert(&self.nodes, &new_leaf, index, id_provider, extensions).await;
+
+            let err = res.is_err();
+
+            if !filter {
+                res?;
+            }
+
+            if !err {
+                self.nodes.insert_leaf(index, new_leaf);
+                removed_leaves.push(old_leaf);
+                updated_indices.push(index);
+            } else {
+                #[cfg(feature = "tree_index")]
+                let res =
+                    index_insert(&mut self.index, &old_leaf, index, id_provider, extensions).await;
+
+                #[cfg(not(feature = "tree_index"))]
+                let res =
+                    index_insert(&self.nodes, &old_leaf, index, id_provider, extensions).await;
+
+                if res.is_ok() {
+                    self.nodes.insert_leaf(index, old_leaf);
+                    bad_indices.push(i);
+                } else {
+                    // Revert all updates and stop. We're already in the "filter" case, so we
+                    // don't throw an error. The tree must end up exactly as it was before any
+                    // update: restore the pre-removal index snapshot, then put every old leaf
+                    // back -- applied updates, the current one, and those not yet processed.
+                    // (Leaves recorded in `bad_indices` already hold their old leaf.)
+                    #[cfg(feature = "tree_index")]
+                    {
+                        self.index = index_before_updates;
+                    }
+
+                    removed_leaves
+                        .into_iter()
+                        .zip(updated_indices.iter())
+                        .for_each(|(leaf, index)| self.nodes.insert_leaf(*index, leaf));
+
+                    self.nodes.insert_leaf(index, old_leaf);
+
+                    for (index, old_leaf, _, _) in pending.by_ref() {
+                        self.nodes.insert_leaf(index, old_leaf);
+                    }
+
+                    updated_indices = vec![];
+                    break;
+                }
+            }
+        }
+
+        // If we managed to update something, blank direct paths
+        updated_indices
+            .iter()
+            .try_for_each(|index| self.nodes.blank_direct_path(*index).map(|_| ()))?;
+
+        // Remove rejected updates from applied proposals
+        if updated_indices.is_empty() {
+            // This takes care of the "revert all" scenario. `update_senders` is cleared as well
+            // so that it stays parallel to `updates`, as in the per-index removal below.
+            proposal_bundle.updates = vec![];
+            proposal_bundle.update_senders = vec![];
+        } else {
+            for i in bad_indices.into_iter().rev() {
+                proposal_bundle.remove::<UpdateProposal>(i);
+                proposal_bundle.update_senders.remove(i);
+            }
+        }
+
+        // Apply adds
+        let mut start = LeafIndex(0);
+        let mut added = vec![];
+        let mut bad_indexes = vec![];
+
+        for i in 0..proposal_bundle.additions.len() {
+            let leaf = proposal_bundle.additions[i]
+                .proposal
+                .key_package
+                .leaf_node
+                .clone();
+
+            let res = self
+                .add_leaf(leaf, id_provider, extensions, Some(start))
+                .await;
+
+            if let Ok(index) = res {
+                start = index;
+                added.push(start);
+            } else if proposal_bundle.additions[i].is_by_value() || !filter {
+                res?;
+            } else {
+                bad_indexes.push(i);
+            }
+        }
+
+        for i in bad_indexes.into_iter().rev() {
+            proposal_bundle.remove::<AddProposal>(i);
+        }
+
+        self.nodes.trim();
+
+        let updated_leaves = proposal_bundle
+            .remove_proposals()
+            .iter()
+            .map(|p| p.proposal.to_remove)
+            .chain(updated_indices)
+            .chain(added.iter().copied())
+            .collect_vec();
+
+        self.update_hashes(&updated_leaves, cipher_suite_provider)
+            .await?;
+
+        Ok(added)
+    }
+
+    /// Applies the removes and then the adds in `proposal_bundle` without any filtering; the
+    /// first failure is returned as an error. Recomputes the affected tree hashes and returns
+    /// the leaf indices assigned to added members.
+    #[cfg(not(feature = "by_ref_proposal"))]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn batch_edit_lite<I, CP>(
+        &mut self,
+        proposal_bundle: &ProposalBundle,
+        extensions: &ExtensionList,
+        id_provider: &I,
+        cipher_suite_provider: &CP,
+    ) -> Result<Vec<LeafIndex>, MlsError>
+    where
+        I: IdentityProvider,
+        CP: CipherSuiteProvider,
+    {
+        // Apply removes
+        for p in &proposal_bundle.removals {
+            let index = p.proposal.to_remove;
+
+            #[cfg(feature = "tree_index")]
+            {
+                // If this fails, it's not because the proposal is bad.
+                let old_leaf = self.nodes.blank_leaf_node(index)?;
+
+                let identity =
+                    identity(&old_leaf.signing_identity, id_provider, extensions).await?;
+
+                self.index.remove(&old_leaf, &identity);
+            }
+
+            #[cfg(not(feature = "tree_index"))]
+            self.nodes.blank_leaf_node(index)?;
+
+            self.nodes.blank_direct_path(index)?;
+        }
+
+        // Apply adds
+        let mut start = LeafIndex(0);
+        let mut added = vec![];
+
+        for p in &proposal_bundle.additions {
+            let leaf = p.proposal.key_package.leaf_node.clone();
+            start = self
+                .add_leaf(leaf, id_provider, extensions, Some(start))
+                .await?;
+            added.push(start);
+        }
+
+        self.nodes.trim();
+
+        let updated_leaves = proposal_bundle
+            .remove_proposals()
+            .iter()
+            .map(|p| p.proposal.to_remove)
+            .chain(added.iter().copied())
+            .collect_vec();
+
+        self.update_hashes(&updated_leaves, cipher_suite_provider)
+            .await?;
+
+        Ok(added)
+    }
+
+    /// Inserts `leaf` into the slot returned by `next_empty_leaf(start)` (`start` defaults to
+    /// leaf 0). The leaf is indexed first, then marked unmerged in its direct-path parents.
+    /// Returns the chosen index.
+    ///
+    /// # Errors
+    /// Propagates indexing failures; the tree is not modified in that case.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn add_leaf<I: IdentityProvider>(
+        &mut self,
+        leaf: LeafNode,
+        id_provider: &I,
+        extensions: &ExtensionList,
+        start: Option<LeafIndex>,
+    ) -> Result<LeafIndex, MlsError> {
+        let index = self.nodes.next_empty_leaf(start.unwrap_or(LeafIndex(0)));
+
+        #[cfg(feature = "tree_index")]
+        index_insert(&mut self.index, &leaf, index, id_provider, extensions).await?;
+
+        #[cfg(not(feature = "tree_index"))]
+        index_insert(&self.nodes, &leaf, index, id_provider, extensions).await?;
+
+        self.nodes.insert_leaf(index, leaf);
+        self.update_unmerged(index)?;
+
+        Ok(index)
+    }
+}
+
+/// Resolves the application-level identity of `signing_id` through `provider`, wrapping any
+/// provider failure in `MlsError::IdentityProviderError`.
+#[cfg(feature = "tree_index")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn identity<I: IdentityProvider>(
+    signing_id: &SigningIdentity,
+    provider: &I,
+    extensions: &ExtensionList,
+) -> Result<Vec<u8>, MlsError> {
+    let lookup = provider.identity(signing_id, extensions).await;
+
+    lookup.map_err(|err| MlsError::IdentityProviderError(err.into_any_error()))
+}
+
+#[cfg(feature = "std")]
+impl Display for TreeKemPublic {
+    /// Renders the node vector as ASCII art via `tree_utils::build_ascii_tree`.
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        let ascii_tree = tree_utils::build_ascii_tree(&self.nodes);
+        write!(f, "{}", ascii_tree)
+    }
+}
+
+#[cfg(test)]
+use crate::group::{proposal::Proposal, proposal_filter::ProposalSource, Sender};
+
+#[cfg(test)]
+impl TreeKemPublic {
+    /// Test helper: replaces the leaf at `leaf_index` with `leaf_node`. It does so by running a
+    /// single by-value update proposal through `batch_edit` with filtering enabled.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn update_leaf<I, CP>(
+        &mut self,
+        leaf_index: u32,
+        leaf_node: LeafNode,
+        identity_provider: &I,
+        cipher_suite_provider: &CP,
+    ) -> Result<(), MlsError>
+    where
+        I: IdentityProvider,
+        CP: CipherSuiteProvider,
+    {
+        let mut bundle = ProposalBundle::default();
+
+        bundle.add(
+            Proposal::Update(UpdateProposal { leaf_node }),
+            Sender::Member(leaf_index),
+            ProposalSource::ByValue,
+        );
+
+        bundle.update_senders = vec![LeafIndex(leaf_index)];
+
+        self.batch_edit(
+            &mut bundle,
+            &Default::default(),
+            identity_provider,
+            cipher_suite_provider,
+            true,
+        )
+        .await
+        .map(|_| ())
+    }
+
+    /// Test helper: removes the leaves at `indexes` through by-value remove proposals sent by
+    /// member 0. Returns each removed index paired with the leaf it held before the removal.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn remove_leaves<I, CP>(
+        &mut self,
+        indexes: Vec<LeafIndex>,
+        identity_provider: &I,
+        cipher_suite_provider: &CP,
+    ) -> Result<Vec<(LeafIndex, LeafNode)>, MlsError>
+    where
+        I: IdentityProvider,
+        CP: CipherSuiteProvider,
+    {
+        // Keep the pre-removal tree so the removed leaves can still be read afterwards.
+        let snapshot = self.clone();
+
+        let mut bundle = ProposalBundle::default();
+
+        for &to_remove in indexes.iter() {
+            bundle.add(
+                Proposal::Remove(RemoveProposal { to_remove }),
+                Sender::Member(0),
+                ProposalSource::ByValue,
+            );
+        }
+
+        #[cfg(feature = "by_ref_proposal")]
+        self.batch_edit(
+            &mut bundle,
+            &Default::default(),
+            identity_provider,
+            cipher_suite_provider,
+            true,
+        )
+        .await?;
+
+        #[cfg(not(feature = "by_ref_proposal"))]
+        self.batch_edit_lite(
+            &bundle,
+            &Default::default(),
+            identity_provider,
+            cipher_suite_provider,
+        )
+        .await?;
+
+        let mut removed = Vec::with_capacity(bundle.removals.len());
+
+        for p in bundle.removals.iter() {
+            let index = p.proposal.to_remove;
+            removed.push((index, snapshot.get_leaf_node(index)?.clone()));
+        }
+
+        Ok(removed)
+    }
+
+    /// Test helper: borrows every non-blank leaf, in index order.
+    pub fn get_leaf_nodes(&self) -> Vec<&LeafNode> {
+        self.non_empty_leaves().map(|(_, leaf)| leaf).collect()
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use crate::crypto::test_utils::TestCryptoProvider;
+    use crate::signer::Signable;
+    use alloc::vec::Vec;
+    use alloc::{format, vec};
+    use mls_rs_core::crypto::CipherSuiteProvider;
+    use mls_rs_core::group::Capabilities;
+    use mls_rs_core::identity::BasicCredential;
+
+    use crate::identity::test_utils::get_test_signing_identity;
+    use crate::{
+        cipher_suite::CipherSuite,
+        crypto::{HpkeSecretKey, SignatureSecretKey},
+        identity::basic::BasicIdentityProvider,
+        tree_kem::leaf_node::test_utils::get_basic_test_node_sig_key,
+    };
+
+    use super::leaf_node::{ConfigProperties, LeafNodeSigningContext};
+    use super::node::LeafIndex;
+    use super::Lifetime;
+    use super::{
+        leaf_node::{test_utils::get_basic_test_node, LeafNode},
+        TreeKemPrivate, TreeKemPublic,
+    };
+
+    /// A one-member test tree together with the creator's leaf and keys.
+    #[derive(Debug)]
+    pub(crate) struct TestTree {
+        pub public: TreeKemPublic,
+        pub private: TreeKemPrivate,
+        pub creator_leaf: LeafNode,
+        pub creator_signing_key: SignatureSecretKey,
+        pub creator_hpke_secret: HpkeSecretKey,
+    }
+
+    /// Builds a one-member tree from a basic "creator" leaf using `TreeKemPublic::derive` with
+    /// `BasicIdentityProvider` and no extensions. Panics if derivation fails.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn get_test_tree(cipher_suite: CipherSuite) -> TestTree {
+        let (creator_leaf, creator_hpke_secret, creator_signing_key) =
+            get_basic_test_node_sig_key(cipher_suite, "creator").await;
+
+        let (test_public, test_private) = TreeKemPublic::derive(
+            creator_leaf.clone(),
+            creator_hpke_secret.clone(),
+            &BasicIdentityProvider,
+            &Default::default(),
+        )
+        .await
+        .unwrap();
+
+        TestTree {
+            public: test_public,
+            private: test_private,
+            creator_leaf,
+            creator_signing_key,
+            creator_hpke_secret,
+        }
+    }
+
+    /// Three basic test leaves named "A", "B" and "C".
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn get_test_leaf_nodes(cipher_suite: CipherSuite) -> Vec<LeafNode> {
+        [
+            get_basic_test_node(cipher_suite, "A").await,
+            get_basic_test_node(cipher_suite, "B").await,
+            get_basic_test_node(cipher_suite, "C").await,
+        ]
+        .to_vec()
+    }
+
+    impl TreeKemPublic {
+        /// Compares the cached tree hashes and the index, which `PartialEq` ignores.
+        #[cfg(feature = "tree_index")]
+        pub fn equal_internals(&self, other: &TreeKemPublic) -> bool {
+            self.tree_hashes == other.tree_hashes && self.index == other.index
+        }
+    }
+
+    /// A test tree plus one optional signing key per leaf slot (`None` once a member has been
+    /// removed) and a random group id used when re-signing leaves.
+    #[derive(Debug, Clone)]
+    pub struct TreeWithSigners {
+        pub tree: TreeKemPublic,
+        pub signers: Vec<Option<SignatureSecretKey>>,
+        pub group_id: Vec<u8>,
+    }
+
+    impl TreeWithSigners {
+        /// Builds a tree with `n_leaves` members ("Alice", "Alice1", ...). After each addition,
+        /// the previously added member commits a path update. Panics on any failure.
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        pub async fn make_full_tree<P: CipherSuiteProvider>(
+            n_leaves: u32,
+            cs: &P,
+        ) -> TreeWithSigners {
+            let mut tree = TreeWithSigners {
+                tree: TreeKemPublic::new(),
+                signers: vec![],
+                group_id: cs.random_bytes_vec(cs.kdf_extract_size()).unwrap(),
+            };
+
+            tree.add_member("Alice", cs).await;
+
+            // A adds B, B adds C, C adds D etc.
+            for i in 1..n_leaves {
+                tree.add_member(&format!("Alice{i}"), cs).await;
+                tree.update_committer_path(i - 1, cs).await;
+            }
+
+            tree
+        }
+
+        /// Inserts a fresh leaf for `name` into the first empty slot, marks it unmerged on its
+        /// direct path and records its signer. Panics if `signers` is shorter than the slot
+        /// index (which would mean it is out of sync with the tree).
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        pub async fn add_member<P: CipherSuiteProvider>(&mut self, name: &str, cs: &P) {
+            let (leaf, signer) = make_leaf(name, cs).await;
+            let index = self.tree.nodes.next_empty_leaf(LeafIndex(0));
+            self.tree.nodes.insert_leaf(index, leaf);
+            self.tree.update_unmerged(index).unwrap();
+            let index = *index as usize;
+
+            match self.signers.len() {
+                l if l == index => self.signers.push(Some(signer)),
+                l if l > index => self.signers[index] = Some(signer),
+                _ => panic!("signer tree size mismatch"),
+            }
+        }
+
+        /// Blanks member `member`'s direct path and leaf, then forgets its signer. Tree hashes
+        /// are not recomputed here.
+        #[cfg(feature = "rfc_compliant")]
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        pub fn remove_member(&mut self, member: u32) {
+            self.tree
+                .nodes
+                .blank_direct_path(LeafIndex(member))
+                .unwrap();
+
+            self.tree.nodes.blank_leaf_node(LeafIndex(member)).unwrap();
+
+            *self
+                .signers
+                .get_mut(member as usize)
+                .expect("signer tree size mismatch") = None;
+        }
+
+        /// Simulates `committer` committing a path update, in this order:
+        /// 1. assign freshly generated KEM public keys to every unfiltered direct-path node;
+        /// 2. recompute the tree hash;
+        /// 3. update parent hashes;
+        /// 4. re-sign the committer's leaf with the group id and leaf index as context.
+        ///
+        /// The tree hash is recomputed after each step. Panics on any failure.
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        pub async fn update_committer_path<P: CipherSuiteProvider>(
+            &mut self,
+            committer: u32,
+            cs: &P,
+        ) {
+            let committer = LeafIndex(committer);
+
+            let path = self.tree.nodes.direct_copath(committer);
+            let filtered = self.tree.nodes.filtered(committer).unwrap();
+
+            for (n, f) in path.into_iter().zip(filtered) {
+                if !f {
+                    self.tree
+                        .update_node(cs.kem_generate().await.unwrap().1, n.path)
+                        .unwrap();
+                }
+            }
+
+            // Clearing the cached hashes presumably forces `tree_hash` to recompute them from
+            // scratch -- confirm against `tree_hash`.
+            self.tree.tree_hashes.current = vec![];
+            self.tree.tree_hash(cs).await.unwrap();
+
+            self.tree
+                .update_parent_hashes(committer, false, cs)
+                .await
+                .unwrap();
+
+            self.tree.tree_hashes.current = vec![];
+            self.tree.tree_hash(cs).await.unwrap();
+
+            let context = LeafNodeSigningContext {
+                group_id: Some(&self.group_id),
+                leaf_index: Some(*committer),
+            };
+
+            let signer = self.signers[*committer as usize].as_ref().unwrap();
+
+            self.tree
+                .nodes
+                .borrow_as_leaf_mut(committer)
+                .unwrap()
+                .sign(cs, signer, &context)
+                .await
+                .unwrap();
+
+            self.tree.tree_hashes.current = vec![];
+            self.tree.tree_hash(cs).await.unwrap();
+        }
+    }
+
+    /// Generates a leaf for `name` with a basic credential, every cipher suite supported by
+    /// `TestCryptoProvider`, no extensions and a one-year lifetime. Returns the leaf with its
+    /// signing key. Panics on any failure.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn make_leaf<P: CipherSuiteProvider>(
+        name: &str,
+        cs: &P,
+    ) -> (LeafNode, SignatureSecretKey) {
+        let (signing_identity, signature_key) =
+            get_test_signing_identity(cs.cipher_suite(), name.as_bytes()).await;
+
+        let capabilities = Capabilities {
+            credentials: vec![BasicCredential::credential_type()],
+            cipher_suites: TestCryptoProvider::all_supported_cipher_suites(),
+            ..Default::default()
+        };
+
+        let properties = ConfigProperties {
+            capabilities,
+            extensions: Default::default(),
+        };
+
+        let (leaf, _) = LeafNode::generate(
+            cs,
+            properties,
+            signing_identity,
+            &signature_key,
+            Lifetime::years(1).unwrap(),
+        )
+        .await
+        .unwrap();
+
+        (leaf, signature_key)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::client::test_utils::TEST_CIPHER_SUITE;
+    use crate::crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider};
+
+    #[cfg(feature = "custom_proposal")]
+    use crate::group::proposal::ProposalType;
+
+    use crate::identity::basic::BasicIdentityProvider;
+    use crate::tree_kem::leaf_node::LeafNode;
+    use crate::tree_kem::node::{LeafIndex, Node, NodeIndex, NodeTypeResolver, Parent};
+    use crate::tree_kem::parent_hash::ParentHash;
+    use crate::tree_kem::test_utils::{get_test_leaf_nodes, get_test_tree};
+    use crate::tree_kem::{MlsError, TreeKemPublic};
+    use alloc::vec;
+    use alloc::vec::Vec;
+    use assert_matches::assert_matches;
+
+    #[cfg(feature = "by_ref_proposal")]
+    use alloc::boxed::Box;
+
+    #[cfg(feature = "by_ref_proposal")]
+    use crate::{
+        client::test_utils::TEST_PROTOCOL_VERSION,
+        group::{
+            proposal::{Proposal, RemoveProposal, UpdateProposal},
+            proposal_filter::{ProposalBundle, ProposalSource},
+            proposal_ref::ProposalRef,
+            Sender,
+        },
+        key_package::test_utils::test_key_package,
+    };
+
+    // Needed by the `by_ref_proposal` tests and by `custom_proposal_support`; the feature
+    // names here must match those tests' `cfg` attributes exactly.
+    #[cfg(any(feature = "by_ref_proposal", feature = "custom_proposal"))]
+    use crate::tree_kem::leaf_node::test_utils::get_basic_test_node;
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_derive() {
+        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+            let test_tree = get_test_tree(cipher_suite).await;
+
+            assert_eq!(
+                test_tree.public.nodes[0],
+                Some(Node::Leaf(test_tree.creator_leaf.clone()))
+            );
+
+            assert_eq!(test_tree.private.self_index, LeafIndex(0));
+
+            assert_eq!(
+                test_tree.private.secret_keys[0],
+                Some(test_tree.creator_hpke_secret)
+            );
+        }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_import_export() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let mut test_tree = get_test_tree(TEST_CIPHER_SUITE).await;
+
+        let additional_key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+        test_tree
+            .public
+            .add_leaves(
+                additional_key_packages,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+            )
+            .await
+            .unwrap();
+
+        let imported = TreeKemPublic::import_node_data(
+            test_tree.public.nodes.clone(),
+            &BasicIdentityProvider,
+            &Default::default(),
+        )
+        .await
+        .unwrap();
+
+        assert_eq!(test_tree.public.nodes, imported.nodes);
+
+        #[cfg(feature = "tree_index")]
+        assert_eq!(test_tree.public.index, imported.index);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_add_leaf() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let mut tree = TreeKemPublic::new();
+
+        let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+        let res = tree
+            .add_leaves(
+                leaf_nodes.clone(),
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+            )
+            .await
+            .unwrap();
+
+        // The leaf count should be equal to the number of packages we added
+        assert_eq!(res.len(), leaf_nodes.len());
+        assert_eq!(tree.occupied_leaf_count(), leaf_nodes.len() as u32);
+
+        // Each added package should be at the proper index and searchable in the tree
+        res.into_iter().zip(leaf_nodes.clone()).for_each(|(r, kp)| {
+            assert_eq!(tree.get_leaf_node(r).unwrap(), &kp);
+        });
+
+        // Verify the underlying state
+        #[cfg(feature = "tree_index")]
+        assert_eq!(tree.index.len(), tree.occupied_leaf_count() as usize);
+
+        assert_eq!(tree.nodes.len(), 5);
+        assert_eq!(tree.nodes[0], leaf_nodes[0].clone().into());
+        assert_eq!(tree.nodes[1], None);
+        assert_eq!(tree.nodes[2], leaf_nodes[1].clone().into());
+        assert_eq!(tree.nodes[3], None);
+        assert_eq!(tree.nodes[4], leaf_nodes[2].clone().into());
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_get_key_packages() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let mut tree = TreeKemPublic::new();
+
+        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+        tree.add_leaves(
+            key_packages.clone(),
+            &BasicIdentityProvider,
+            &cipher_suite_provider,
+        )
+        .await
+        .unwrap();
+
+        // The tree started empty, so its leaves must be exactly the added nodes, in insertion
+        // order. (Previously this compared the result with a copy of itself and could not fail.)
+        let expected: Vec<&LeafNode> = key_packages.iter().collect();
+        assert_eq!(tree.get_leaf_nodes(), expected);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_add_leaf_duplicate() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let mut tree = TreeKemPublic::new();
+
+        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+        tree.add_leaves(
+            key_packages.clone(),
+            &BasicIdentityProvider,
+            &cipher_suite_provider,
+        )
+        .await
+        .unwrap();
+
+        let res = tree
+            .add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
+            .await;
+
+        assert_matches!(res, Err(MlsError::DuplicateLeafData(_)));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_add_leaf_empty_leaf() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+        tree.add_leaves(
+            [key_packages[0].clone()].to_vec(),
+            &BasicIdentityProvider,
+            &cipher_suite_provider,
+        )
+        .await
+        .unwrap();
+
+        // Blank the creator's leaf so the next add can reuse slot 0
+        tree.nodes[0] = None;
+
+        tree.add_leaves(
+            [key_packages[1].clone()].to_vec(),
+            &BasicIdentityProvider,
+            &cipher_suite_provider,
+        )
+        .await
+        .unwrap();
+
+        assert_eq!(tree.nodes[0], key_packages[1].clone().into());
+        assert_eq!(tree.nodes[1], None);
+        assert_eq!(tree.nodes[2], key_packages[0].clone().into());
+        assert_eq!(tree.nodes.len(), 3)
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_add_leaf_unmerged() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+        tree.add_leaves(
+            [key_packages[0].clone(), key_packages[1].clone()].to_vec(),
+            &BasicIdentityProvider,
+            &cipher_suite_provider,
+        )
+        .await
+        .unwrap();
+
+        tree.nodes[3] = Parent {
+            public_key: vec![].into(),
+            parent_hash: ParentHash::empty(),
+            unmerged_leaves: vec![],
+        }
+        .into();
+
+        tree.add_leaves(
+            [key_packages[2].clone()].to_vec(),
+            &BasicIdentityProvider,
+            &cipher_suite_provider,
+        )
+        .await
+        .unwrap();
+
+        assert_eq!(
+            tree.nodes[3].as_parent().unwrap().unmerged_leaves,
+            vec![LeafIndex(3)]
+        )
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_update_leaf() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        // Create a tree
+        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+
+        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+        tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
+            .await
+            .unwrap();
+
+        let original_leaf_index = LeafIndex(1);
+
+        // Fill the updated leaf's direct path so we can detect it being cleared by the update
+        tree.nodes
+            .direct_copath(original_leaf_index)
+            .iter()
+            .for_each(|n| {
+                tree.nodes
+                    .borrow_or_fill_node_as_parent(n.path, &b"pub_key".to_vec().into())
+                    .unwrap();
+            });
+
+        let original_size = tree.occupied_leaf_count();
+
+        let updated_leaf = get_basic_test_node(TEST_CIPHER_SUITE, "A").await;
+
+        tree.update_leaf(
+            *original_leaf_index,
+            updated_leaf.clone(),
+            &BasicIdentityProvider,
+            &cipher_suite_provider,
+        )
+        .await
+        .unwrap();
+
+        // The tree should not have grown due to an update
+        assert_eq!(tree.occupied_leaf_count(), original_size);
+
+        // The cache of tree package indexes should not have grown
+        #[cfg(feature = "tree_index")]
+        assert_eq!(tree.index.len() as u32, tree.occupied_leaf_count());
+
+        // The key package should be updated in the tree
+        assert_eq!(
+            tree.get_leaf_node(original_leaf_index).unwrap(),
+            &updated_leaf
+        );
+
+        // Verify that the updated leaf's direct path has been cleared
+        tree.nodes
+            .direct_copath(original_leaf_index)
+            .iter()
+            .for_each(|n| {
+                assert!(tree.nodes[n.path as usize].is_none());
+            });
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_update_leaf_not_found() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        // Create a tree
+        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+
+        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+        tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
+            .await
+            .unwrap();
+
+        let new_key_package = get_basic_test_node(TEST_CIPHER_SUITE, "new").await;
+
+        let res = tree
+            .update_leaf(
+                128,
+                new_key_package,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+            )
+            .await;
+
+        assert_matches!(res, Err(MlsError::UpdatingNonExistingMember));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_remove_leaf() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        // Create a tree
+        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+        let indexes = tree
+            .add_leaves(
+                key_packages.clone(),
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+            )
+            .await
+            .unwrap();
+
+        let original_leaf_count = tree.occupied_leaf_count();
+
+        // Remove every leaf that was just added
+        let expected_result: Vec<(LeafIndex, LeafNode)> =
+            indexes.clone().into_iter().zip(key_packages).collect();
+
+        let res = tree
+            .remove_leaves(
+                indexes.clone(),
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+            )
+            .await
+            .unwrap();
+
+        // The order may change
+        assert!(res.iter().all(|x| expected_result.contains(x)));
+        assert!(expected_result.iter().all(|x| res.contains(x)));
+
+        // The leaves should be removed from the tree
+        assert_eq!(
+            tree.occupied_leaf_count(),
+            original_leaf_count - indexes.len() as u32
+        );
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_remove_leaf_middle() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        // Create a tree
+        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+        let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+        let to_remove = tree
+            .add_leaves(
+                leaf_nodes.clone(),
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+            )
+            .await
+            .unwrap()[0];
+
+        let original_leaf_count = tree.occupied_leaf_count();
+
+        let res = tree
+            .remove_leaves(
+                vec![to_remove],
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+            )
+            .await
+            .unwrap();
+
+        assert_eq!(res, vec![(to_remove, leaf_nodes[0].clone())]);
+
+        // The leaf count should have been reduced by 1
+        assert_eq!(tree.occupied_leaf_count(), original_leaf_count - 1);
+
+        // There should be a blank in the tree
+        assert_eq!(
+            tree.nodes.get(NodeIndex::from(to_remove) as usize).unwrap(),
+            &None
+        );
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_create_blanks() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        // Create a tree
+        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+
+        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+        tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
+            .await
+            .unwrap();
+
+        let original_leaf_count = tree.occupied_leaf_count();
+
+        let to_remove = vec![LeafIndex(2)];
+
+        // Remove the leaf from the tree
+        tree.remove_leaves(to_remove, &BasicIdentityProvider, &cipher_suite_provider)
+            .await
+            .unwrap();
+
+        // The occupied leaf count should have been reduced by 1
+        assert_eq!(tree.occupied_leaf_count(), original_leaf_count - 1);
+
+        // The total leaf count should remain unchanged
+        assert_eq!(tree.total_leaf_count(), original_leaf_count);
+
+        // The location of key_packages[1] (leaf 2, after the creator at leaf 0) should now be blank
+        let removed_location = tree
+            .nodes
+            .get(NodeIndex::from(LeafIndex(2)) as usize)
+            .unwrap();
+
+        assert_eq!(removed_location, &None);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_remove_leaf_failure() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        // Create a tree
+        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+
+        let res = tree
+            .remove_leaves(
+                vec![LeafIndex(128)],
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+            )
+            .await;
+
+        assert_matches!(res, Err(MlsError::InvalidNodeIndex(256)));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_find_leaf_node() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        // Create a tree
+        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+
+        let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+        tree.add_leaves(
+            leaf_nodes.clone(),
+            &BasicIdentityProvider,
+            &cipher_suite_provider,
+        )
+        .await
+        .unwrap();
+
+        // Find each node; the creator occupies leaf 0, so added leaves start at leaf 1
+        for (i, leaf_node) in leaf_nodes.iter().enumerate() {
+            let expected_index = LeafIndex(i as u32 + 1);
+            assert_eq!(tree.find_leaf_node(leaf_node), Some(expected_index));
+        }
+    }
+
+    // TODO add test for the lite version
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn batch_edit_works() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+        let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+        tree.add_leaves(leaf_nodes, &BasicIdentityProvider, &cipher_suite_provider)
+            .await
+            .unwrap();
+
+        let mut bundle = ProposalBundle::default();
+
+        let kp = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "D").await;
+        let add = Proposal::Add(Box::new(kp.into()));
+
+        bundle.add(add, Sender::Member(0), ProposalSource::ByValue);
+
+        let update = UpdateProposal {
+            leaf_node: get_basic_test_node(TEST_CIPHER_SUITE, "A").await,
+        };
+
+        let update = Proposal::Update(update);
+        let pref = ProposalRef::new_fake(vec![1, 2, 3]);
+
+        bundle.add(update, Sender::Member(1), ProposalSource::ByReference(pref));
+
+        bundle.update_senders = vec![LeafIndex(1)];
+
+        let remove = RemoveProposal {
+            to_remove: LeafIndex(2),
+        };
+
+        let remove = Proposal::Remove(remove);
+
+        bundle.add(remove, Sender::Member(0), ProposalSource::ByValue);
+
+        tree.batch_edit(
+            &mut bundle,
+            &Default::default(),
+            &BasicIdentityProvider,
+            &cipher_suite_provider,
+            true,
+        )
+        .await
+        .unwrap();
+
+        assert_eq!(bundle.add_proposals().len(), 1);
+        assert_eq!(bundle.remove_proposals().len(), 1);
+        assert_eq!(bundle.update_proposals().len(), 1);
+    }
+
+    #[cfg(feature = "custom_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn custom_proposal_support() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let mut tree = TreeKemPublic::new();
+
+        let test_proposal_type = ProposalType::from(42);
+
+        let mut leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+        leaf_nodes
+            .iter_mut()
+            .for_each(|n| n.capabilities.proposals.push(test_proposal_type));
+
+        tree.add_leaves(leaf_nodes, &BasicIdentityProvider, &cipher_suite_provider)
+            .await
+            .unwrap();
+
+        assert!(tree.can_support_proposal(test_proposal_type));
+        assert!(!tree.can_support_proposal(ProposalType::from(43)));
+
+        let test_node = get_basic_test_node(TEST_CIPHER_SUITE, "another").await;
+
+        tree.add_leaves(
+            vec![test_node],
+            &BasicIdentityProvider,
+            &cipher_suite_provider,
+        )
+        .await
+        .unwrap();
+
+        assert!(!tree.can_support_proposal(test_proposal_type));
+    }
+}
diff --git a/src/tree_kem/node.rs b/src/tree_kem/node.rs
new file mode 100644
index 0000000..8b7372f
--- /dev/null
+++ b/src/tree_kem/node.rs
@@ -0,0 +1,577 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::leaf_node::LeafNode;
+use crate::client::MlsError;
+use crate::crypto::HpkePublicKey;
+use crate::tree_kem::math as tree_math;
+use crate::tree_kem::parent_hash::ParentHash;
+use alloc::vec;
+use alloc::vec::Vec;
+use core::hash::Hash;
+use core::ops::{Deref, DerefMut};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use tree_math::{CopathNode, TreeIndex};
+
+/// A populated parent (non-leaf) node of the ratchet tree.
+///
+/// NOTE(review): field order is presumably the wire order used by the derived
+/// `MlsEncode`/`MlsDecode` impls — do not reorder without checking compatibility.
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct Parent {
+ /// HPKE public key held by this node.
+ pub public_key: HpkePublicKey,
+ /// Parent hash stored in this node.
+ pub parent_hash: ParentHash,
+ /// Leaves recorded as unmerged at this node; `NodeVec::get_resolution_index` lists them
+ /// right after the node itself.
+ pub unmerged_leaves: Vec<LeafIndex>,
+}
+
+/// Index of a leaf, counting leaves only. Leaf `i` lives at node index `2 * i` in the
+/// flattened tree (see the `From<LeafIndex> for NodeIndex` conversion).
+#[derive(
+ Clone, Copy, Debug, Ord, PartialEq, PartialOrd, Hash, Eq, MlsSize, MlsEncode, MlsDecode,
+)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct LeafIndex(pub(crate) u32);
+
+impl LeafIndex {
+    /// Wraps the raw leaf position `i`.
+    pub fn new(i: u32) -> Self {
+        LeafIndex(i)
+    }
+}
+
+/// Exposes the raw `u32` leaf position.
+impl Deref for LeafIndex {
+    type Target = u32;
+
+    fn deref(&self) -> &Self::Target {
+        let LeafIndex(raw) = self;
+        raw
+    }
+}
+
+/// Borrowed counterpart of `From<LeafIndex> for NodeIndex`; `LeafIndex` is `Copy`.
+impl From<&LeafIndex> for NodeIndex {
+    fn from(leaf_index: &LeafIndex) -> Self {
+        NodeIndex::from(*leaf_index)
+    }
+}
+
+/// Leaves occupy the even slots of the flattened tree, so leaf `i` is node `2 * i`.
+impl From<LeafIndex> for NodeIndex {
+    fn from(leaf_index: LeafIndex) -> Self {
+        let LeafIndex(position) = leaf_index;
+        2 * position
+    }
+}
+
+/// Position of a node in the flattened (array) ratchet tree: leaves sit at even indices and
+/// parents at odd ones (see `NodeVec::is_leaf`).
+pub(crate) type NodeIndex = u32;
+
+/// A non-blank ratchet-tree node: either a leaf or a parent. Blank nodes are represented
+/// as `None` in `NodeVec`.
+///
+/// NOTE(review): the explicit `u8` discriminants (`1` = leaf, `2` = parent) are presumably
+/// the node-type tag written by the derived codec — confirm before changing them.
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[allow(clippy::large_enum_variant)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+//TODO: Research if this should actually be a Box<Leaf> for memory / performance reasons
+pub(crate) enum Node {
+ Leaf(LeafNode) = 1u8,
+ Parent(Parent) = 2u8,
+}
+
+impl Node {
+    /// Returns the HPKE public key carried by this node, regardless of its variant.
+    pub fn public_key(&self) -> &HpkePublicKey {
+        match self {
+            Self::Leaf(leaf) => &leaf.public_key,
+            Self::Parent(parent) => &parent.public_key,
+        }
+    }
+}
+
+/// Wraps a parent as an occupied node slot.
+impl From<Parent> for Option<Node> {
+    fn from(p: Parent) -> Self {
+        Some(Node::Parent(p))
+    }
+}
+
+/// Wraps a leaf as an occupied node slot.
+impl From<LeafNode> for Option<Node> {
+    fn from(l: LeafNode) -> Self {
+        Some(Node::Leaf(l))
+    }
+}
+
+/// Lifts a parent into the `Node` enum.
+impl From<Parent> for Node {
+    fn from(p: Parent) -> Self {
+        Self::Parent(p)
+    }
+}
+
+/// Lifts a leaf into the `Node` enum.
+impl From<LeafNode> for Node {
+    fn from(l: LeafNode) -> Self {
+        Self::Leaf(l)
+    }
+}
+
+/// Typed, fallible accessors for a node slot that may be blank.
+pub(crate) trait NodeTypeResolver {
+ /// Borrows the slot as a parent node; fails if it is blank or holds a leaf.
+ fn as_parent(&self) -> Result<&Parent, MlsError>;
+ /// Mutably borrows the slot as a parent node; fails if it is blank or holds a leaf.
+ fn as_parent_mut(&mut self) -> Result<&mut Parent, MlsError>;
+ /// Borrows the slot as a leaf node; fails if it is blank or holds a parent.
+ fn as_leaf(&self) -> Result<&LeafNode, MlsError>;
+ /// Mutably borrows the slot as a leaf node; fails if it is blank or holds a parent.
+ fn as_leaf_mut(&mut self) -> Result<&mut LeafNode, MlsError>;
+ /// Borrows the node in the slot; fails if the slot is blank.
+ fn as_non_empty(&self) -> Result<&Node, MlsError>;
+}
+
+/// A wrong-variant or blank slot yields `MlsError::ExpectedNode` from the typed accessors;
+/// `as_non_empty` reports a blank slot as `MlsError::UnexpectedEmptyNode`.
+impl NodeTypeResolver for Option<Node> {
+    fn as_parent(&self) -> Result<&Parent, MlsError> {
+        match self {
+            Some(Node::Parent(parent)) => Ok(parent),
+            Some(Node::Leaf(_)) | None => Err(MlsError::ExpectedNode),
+        }
+    }
+
+    fn as_parent_mut(&mut self) -> Result<&mut Parent, MlsError> {
+        match self {
+            Some(Node::Parent(parent)) => Ok(parent),
+            Some(Node::Leaf(_)) | None => Err(MlsError::ExpectedNode),
+        }
+    }
+
+    fn as_leaf(&self) -> Result<&LeafNode, MlsError> {
+        match self {
+            Some(Node::Leaf(leaf)) => Ok(leaf),
+            Some(Node::Parent(_)) | None => Err(MlsError::ExpectedNode),
+        }
+    }
+
+    fn as_leaf_mut(&mut self) -> Result<&mut LeafNode, MlsError> {
+        match self {
+            Some(Node::Leaf(leaf)) => Ok(leaf),
+            Some(Node::Parent(_)) | None => Err(MlsError::ExpectedNode),
+        }
+    }
+
+    fn as_non_empty(&self) -> Result<&Node, MlsError> {
+        match self {
+            Some(node) => Ok(node),
+            None => Err(MlsError::UnexpectedEmptyNode),
+        }
+    }
+}
+
+/// Flattened array form of the ratchet tree: slot `i` holds the node with `NodeIndex` `i`,
+/// and `None` marks a blank node. Trailing blank slots may be absent (see `trim`); reads
+/// past the end are treated as blank by `borrow_node`.
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode, Default)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct NodeVec(Vec<Option<Node>>);
+
+/// Wraps an already-flattened list of node slots without validation.
+impl From<Vec<Option<Node>>> for NodeVec {
+    fn from(x: Vec<Option<Node>>) -> Self {
+        Self(x)
+    }
+}
+
+/// Read access to the underlying slot vector.
+impl Deref for NodeVec {
+    type Target = Vec<Option<Node>>;
+
+    fn deref(&self) -> &Self::Target {
+        let NodeVec(slots) = self;
+        slots
+    }
+}
+
+/// Write access to the underlying slot vector.
+impl DerefMut for NodeVec {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        let NodeVec(slots) = self;
+        slots
+    }
+}
+
+impl NodeVec {
+    /// Returns how many leaf slots currently hold a leaf node.
+    #[cfg(any(test, all(feature = "custom_proposal", feature = "tree_index")))]
+    pub fn occupied_leaf_count(&self) -> u32 {
+        self.non_empty_leaves().count() as u32
+    }
+
+    /// Returns the number of leaf slots covered by the vector, rounded up to a power of two.
+    pub fn total_leaf_count(&self) -> u32 {
+        (self.len() as u32 / 2 + 1).next_power_of_two()
+    }
+
+    /// Borrows the slot at `index`.
+    ///
+    /// An index accepted by `validate_index` but lying past the end of the vector is reported
+    /// as a blank node (`&None`) rather than as an error.
+    #[inline]
+    pub fn borrow_node(&self, index: NodeIndex) -> Result<&Option<Node>, MlsError> {
+        Ok(self.get(self.validate_index(index)?).unwrap_or(&None))
+    }
+
+    /// Converts `index` to `usize`, rejecting it with `MlsError::InvalidNodeIndex` when it is
+    /// not below `len().next_power_of_two()`.
+    fn validate_index(&self, index: NodeIndex) -> Result<usize, MlsError> {
+        if (index as usize) >= self.len().next_power_of_two() {
+            Err(MlsError::InvalidNodeIndex(index))
+        } else {
+            Ok(index as usize)
+        }
+    }
+
+    /// Iterates mutably over the blank leaf slots, paired with their leaf index.
+    #[cfg(test)]
+    fn empty_leaves(&mut self) -> impl Iterator<Item = (LeafIndex, &mut Option<Node>)> {
+        self.iter_mut()
+            .step_by(2)
+            .enumerate()
+            .filter(|(_, n)| n.is_none())
+            .map(|(i, n)| (LeafIndex(i as u32), n))
+    }
+
+    /// Iterates over the occupied leaves in increasing leaf-index order.
+    pub fn non_empty_leaves(&self) -> impl Iterator<Item = (LeafIndex, &LeafNode)> + '_ {
+        self.leaves()
+            .enumerate()
+            .filter_map(|(i, l)| l.map(|l| (LeafIndex(i as u32), l)))
+    }
+
+    /// Iterates over the populated parent slots (odd node indices) with their node index.
+    pub fn non_empty_parents(&self) -> impl Iterator<Item = (NodeIndex, &Parent)> + '_ {
+        self.iter()
+            .enumerate()
+            .skip(1)
+            .step_by(2)
+            .map(|(i, n)| (i as NodeIndex, n))
+            .filter_map(|(i, n)| n.as_parent().ok().map(|p| (i, p)))
+    }
+
+    /// Iterates over every leaf slot (even node indices), yielding `None` for blank slots.
+    pub fn leaves(&self) -> impl Iterator<Item = Option<&LeafNode>> + '_ {
+        self.iter().step_by(2).map(|n| n.as_leaf().ok())
+    }
+
+    /// Returns the direct path of leaf `index`, each entry paired with its copath sibling,
+    /// computed for a tree of `total_leaf_count()` leaves.
+    pub fn direct_copath(&self, index: LeafIndex) -> Vec<CopathNode<NodeIndex>> {
+        NodeIndex::from(index).direct_copath(&self.total_leaf_count())
+    }
+
+    /// Computes the filtered-direct-path mask for leaf `index`.
+    ///
+    /// The filtered direct path of a node is obtained from the node's direct path by removing
+    /// all nodes whose child on the node's copath has an empty resolution. The result holds
+    /// one flag per direct-path node, in `direct_copath` order; `true` marks a node that is
+    /// filtered out. Currently never returns `Err`.
+    pub fn filtered(&self, index: LeafIndex) -> Result<Vec<bool>, MlsError> {
+        Ok(NodeIndex::from(index)
+            .direct_copath(&self.total_leaf_count())
+            .into_iter()
+            .map(|cp| self.is_resolution_empty(cp.copath))
+            .collect())
+    }
+
+    /// Returns whether the slot at `index` is blank; fails for an invalid index.
+    #[inline]
+    pub fn is_blank(&self, index: NodeIndex) -> Result<bool, MlsError> {
+        self.borrow_node(index).map(|n| n.is_none())
+    }
+
+    /// Returns whether `index` is a leaf position (leaves sit at even node indices).
+    #[inline]
+    pub fn is_leaf(&self, index: NodeIndex) -> bool {
+        index % 2 == 0
+    }
+
+    /// Blanks a previously filled leaf slot and returns the leaf that was there.
+    ///
+    /// Fails with `MlsError::RemovingNonExistingMember` if the slot is blank, past the end
+    /// of the vector, or not a leaf.
+    pub fn blank_leaf_node(&mut self, leaf_index: LeafIndex) -> Result<LeafNode, MlsError> {
+        let node_index = self.validate_index(leaf_index.into())?;
+
+        match self.get_mut(node_index).and_then(Option::take) {
+            Some(Node::Leaf(l)) => Ok(l),
+            _ => Err(MlsError::RemovingNonExistingMember),
+        }
+    }
+
+    /// Blanks every stored node on the direct path of `leaf`. Slots past the end of the
+    /// vector are skipped. Currently never returns `Err`.
+    pub fn blank_direct_path(&mut self, leaf: LeafIndex) -> Result<(), MlsError> {
+        for i in self.direct_copath(leaf) {
+            if let Some(n) = self.get_mut(i.path as usize) {
+                *n = None
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Removes trailing slots until the last slot is non-blank (or the vector is empty).
+    pub fn trim(&mut self) {
+        while self.last() == Some(&None) {
+            self.pop();
+        }
+    }
+
+    /// Borrows the slot at `node_index` as a parent node.
+    pub fn borrow_as_parent(&self, node_index: NodeIndex) -> Result<&Parent, MlsError> {
+        self.borrow_node(node_index).and_then(|n| n.as_parent())
+    }
+
+    /// Mutably borrows the slot at `node_index` as a parent node.
+    ///
+    /// Unlike `borrow_as_parent`, an index past the end of the vector is an error
+    /// (`MlsError::InvalidNodeIndex`).
+    pub fn borrow_as_parent_mut(&mut self, node_index: NodeIndex) -> Result<&mut Parent, MlsError> {
+        let index = self.validate_index(node_index)?;
+
+        self.get_mut(index)
+            .ok_or(MlsError::InvalidNodeIndex(node_index))?
+            .as_parent_mut()
+    }
+
+    /// Mutably borrows leaf `index`; an index past the end of the vector is an error.
+    pub fn borrow_as_leaf_mut(&mut self, index: LeafIndex) -> Result<&mut LeafNode, MlsError> {
+        let node_index = NodeIndex::from(index);
+        let index = self.validate_index(node_index)?;
+
+        self.get_mut(index)
+            .ok_or(MlsError::InvalidNodeIndex(node_index))?
+            .as_leaf_mut()
+    }
+
+    /// Borrows leaf `index`; a blank or missing slot fails with `MlsError::ExpectedNode`.
+    pub fn borrow_as_leaf(&self, index: LeafIndex) -> Result<&LeafNode, MlsError> {
+        let node_index = NodeIndex::from(index);
+        self.borrow_node(node_index).and_then(|n| n.as_leaf())
+    }
+
+    /// Returns the parent at `node_index`. If the slot is blank, it is first filled with a
+    /// parent holding `public_key`, an empty parent hash and no unmerged leaves. The vector
+    /// is extended with blank slots if needed.
+    ///
+    /// Fails if `node_index` is invalid or the slot holds a leaf.
+    pub fn borrow_or_fill_node_as_parent(
+        &mut self,
+        node_index: NodeIndex,
+        public_key: &HpkePublicKey,
+    ) -> Result<&mut Parent, MlsError> {
+        let index = self.validate_index(node_index)?;
+
+        while self.len() <= index {
+            self.push(None);
+        }
+
+        self.get_mut(index)
+            .ok_or(MlsError::InvalidNodeIndex(node_index))
+            .and_then(|n| {
+                if n.is_none() {
+                    *n = Parent {
+                        public_key: public_key.clone(),
+                        parent_hash: ParentHash::empty(),
+                        unmerged_leaves: vec![],
+                    }
+                    .into();
+                }
+                n.as_parent_mut()
+            })
+    }
+
+    /// Returns the resolution of `index` as node indices.
+    ///
+    /// The subtree is walked left to right. Each populated node is emitted followed
+    /// immediately by its unmerged leaves, in list order. Blank parents are replaced by the
+    /// resolutions of their children, and blank leaves contribute nothing. Currently never
+    /// returns `Err`.
+    pub fn get_resolution_index(&self, index: NodeIndex) -> Result<Vec<NodeIndex>, MlsError> {
+        let mut indexes = vec![index];
+        let mut resolution = vec![];
+
+        while let Some(index) = indexes.pop() {
+            if let Some(Some(node)) = self.get(index as usize) {
+                resolution.push(index);
+
+                if let Node::Parent(p) = node {
+                    resolution.extend(p.unmerged_leaves.iter().map(NodeIndex::from));
+                }
+            } else if !index.is_leaf() {
+                // Push right first so the left child is popped (visited) first.
+                indexes.push(index.right_unchecked());
+                indexes.push(index.left_unchecked());
+            }
+        }
+
+        Ok(resolution)
+    }
+
+    /// Returns the position of `to_find` in the resolution of `index`, using the same order
+    /// as `get_resolution_index`. Returns `None` if it is absent.
+    ///
+    /// With `to_find == None`, returns `Some(0)` as soon as any populated node is found. The
+    /// result is therefore `Some` exactly when the resolution is non-empty.
+    pub fn find_in_resolution(
+        &self,
+        index: NodeIndex,
+        to_find: Option<NodeIndex>,
+    ) -> Option<usize> {
+        let mut indexes = vec![index];
+        let mut resolution_len = 0;
+
+        while let Some(index) = indexes.pop() {
+            if let Some(Some(node)) = self.get(index as usize) {
+                if Some(index) == to_find || to_find.is_none() {
+                    return Some(resolution_len);
+                }
+
+                resolution_len += 1;
+
+                if let Node::Parent(p) = node {
+                    // `indexes` is a LIFO stack: push the unmerged leaves in reverse so they
+                    // are popped in list order, matching `get_resolution_index`. Pushing them
+                    // unreversed would swap positions whenever there are two or more.
+                    indexes.extend(p.unmerged_leaves.iter().rev().map(NodeIndex::from));
+                }
+            } else if !index.is_leaf() {
+                // Push right first so the left child is popped (visited) first.
+                indexes.push(index.right_unchecked());
+                indexes.push(index.left_unchecked());
+            }
+        }
+
+        None
+    }
+
+    /// Returns whether the resolution of `index` contains no nodes.
+    pub fn is_resolution_empty(&self, index: NodeIndex) -> bool {
+        self.find_in_resolution(index, None).is_none()
+    }
+
+    /// Returns the first blank leaf slot at or after `start`. If every stored leaf slot from
+    /// `start` on is occupied, returns the leaf index just past the end of the vector.
+    pub(crate) fn next_empty_leaf(&self, start: LeafIndex) -> LeafIndex {
+        let mut n = NodeIndex::from(start) as usize;
+
+        while n < self.len() {
+            if self.0[n].is_none() {
+                return LeafIndex((n as u32) >> 1);
+            }
+
+            n += 2;
+        }
+
+        LeafIndex((self.len() as u32 + 1) >> 1)
+    }
+
+    /// Stores `leaf` at leaf `index`, replacing whatever the slot held.
+    ///
+    /// If `index` lies beyond the current end of the vector, the vector is first extended
+    /// with blank slots so the leaf slot exists. For the next leaf past the end, as returned
+    /// by `next_empty_leaf`, this appends exactly one blank parent slot and the leaf slot, or
+    /// only the leaf slot when the vector is empty. Previously an index further out panicked.
+    pub fn insert_leaf(&mut self, index: LeafIndex, leaf: LeafNode) {
+        let node_index = (*index as usize) << 1;
+
+        if node_index >= self.len() {
+            self.0.resize_with(node_index + 1, || None);
+        }
+
+        self.0[node_index] = Some(leaf.into());
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use super::*;
+    use crate::{
+        client::test_utils::TEST_CIPHER_SUITE, tree_kem::leaf_node::test_utils::get_basic_test_node,
+    };
+
+    /// Builds a seven-slot (four-leaf) node vector for tests:
+    /// leaves A (slot 0), C (slot 4) and D (slot 6) are occupied. Leaf slot 2 and parents
+    /// 1 and 3 are blank. Parent 5 ("CD") lists leaf 2 as unmerged.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn get_test_node_vec() -> NodeVec {
+        let leaf_a = get_basic_test_node(TEST_CIPHER_SUITE, "A").await;
+        let leaf_c = get_basic_test_node(TEST_CIPHER_SUITE, "C").await;
+
+        let parent_cd = Parent {
+            public_key: b"CD".to_vec().into(),
+            parent_hash: ParentHash::empty(),
+            unmerged_leaves: vec![LeafIndex(2)],
+        };
+
+        let leaf_d = get_basic_test_node(TEST_CIPHER_SUITE, "D").await;
+
+        let slots: Vec<Option<Node>> = vec![
+            leaf_a.into(),
+            None,
+            None,
+            None,
+            leaf_c.into(),
+            parent_cd.into(),
+            leaf_d.into(),
+        ];
+
+        NodeVec::from(slots)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{
+        client::test_utils::TEST_CIPHER_SUITE,
+        tree_kem::{
+            leaf_node::test_utils::get_basic_test_node, node::test_utils::get_test_node_vec,
+        },
+    };
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn node_key_getters() {
+        let parent_node = Node::from(Parent {
+            public_key: b"pub".to_vec().into(),
+            parent_hash: ParentHash::empty(),
+            unmerged_leaves: vec![],
+        });
+
+        let leaf = get_basic_test_node(TEST_CIPHER_SUITE, "B").await;
+        let leaf_node = Node::from(leaf.clone());
+
+        // Both variants expose the key they carry.
+        assert_eq!(parent_node.public_key().as_ref(), b"pub");
+        assert_eq!(leaf_node.public_key(), &leaf.public_key);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_empty_leaves() {
+        let mut nodes = get_test_node_vec().await;
+        let mut reference = get_test_node_vec().await;
+
+        // Only leaf 1 (slot 2) is blank in the fixture.
+        let found: Vec<(LeafIndex, &mut Option<Node>)> = nodes.empty_leaves().collect();
+        let wanted = [(LeafIndex(1), &mut reference[2])];
+
+        assert_eq!(wanted.as_ref(), found.as_slice());
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_direct_path() {
+        let nodes = get_test_node_vec().await;
+
+        // Tree math has its own tests; here we only check NodeVec delegates with 4 leaves.
+        let wanted = 0.direct_copath(&4);
+        let got = nodes.direct_copath(LeafIndex(0));
+
+        assert_eq!(got, wanted);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_filtered_direct_path_co_path() {
+        let nodes = get_test_node_vec().await;
+
+        // Leaf 1 is blank, so node 1 is filtered out; parent 5 keeps node 3 in the path.
+        let got = nodes.filtered(LeafIndex(0)).unwrap();
+
+        assert_eq!(got, [true, false]);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_get_parent_node() {
+        let mut nodes = get_test_node_vec().await;
+        let past_end = nodes.len() as u32;
+
+        // Slot 0 holds a leaf, so asking for a parent must fail.
+        assert!(nodes.borrow_as_parent_mut(0).is_err());
+
+        // A slot past the end of the vector must fail as well.
+        assert!(nodes.borrow_as_parent_mut(past_end).is_err());
+
+        // Slot 5 is the "CD" parent from the fixture.
+        let mut wanted = Parent {
+            public_key: b"CD".to_vec().into(),
+            parent_hash: ParentHash::empty(),
+            unmerged_leaves: vec![LeafIndex(2)],
+        };
+
+        assert_eq!(nodes.borrow_as_parent_mut(5).unwrap(), &mut wanted);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_get_resolution() {
+        let nodes = get_test_node_vec().await;
+
+        let of_parent = nodes.get_resolution_index(5).unwrap();
+        let of_blank_leaf = nodes.get_resolution_index(2).unwrap();
+        let of_root = nodes.get_resolution_index(3).unwrap();
+
+        assert_eq!(&of_parent, &[5, 4]);
+        assert!(of_blank_leaf.is_empty());
+        assert_eq!(&of_root, &[0, 5, 4]);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_get_or_fill_existing() {
+        let mut nodes = get_test_node_vec().await;
+        let mut copy = nodes.clone();
+
+        // An existing parent is returned untouched, whatever key is passed in.
+        let existing = nodes[5].as_parent_mut().unwrap();
+        let returned = copy
+            .borrow_or_fill_node_as_parent(5, &Vec::new().into())
+            .unwrap();
+
+        assert_eq!(returned, existing);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_get_or_fill_empty() {
+        let mut nodes = get_test_node_vec().await;
+
+        // Blank slot 1 gets a fresh parent built from the supplied key.
+        let returned = nodes
+            .borrow_or_fill_node_as_parent(1, &vec![0u8; 4].into())
+            .unwrap();
+
+        let mut wanted = Parent {
+            public_key: vec![0u8; 4].into(),
+            parent_hash: ParentHash::empty(),
+            unmerged_leaves: vec![],
+        };
+
+        assert_eq!(returned, &mut wanted);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_leaf_count() {
+        let nodes = get_test_node_vec().await;
+        let occupied = nodes.occupied_leaf_count();
+
+        assert_eq!(nodes.len(), 7);
+        assert_eq!(occupied, 3);
+        assert_eq!(nodes.non_empty_leaves().count(), occupied as usize);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_total_leaf_count() {
+        let nodes = get_test_node_vec().await;
+
+        // Three occupied leaves, but the tree has four leaf slots.
+        assert_eq!(nodes.occupied_leaf_count(), 3);
+        assert_eq!(nodes.total_leaf_count(), 4);
+    }
+}
diff --git a/src/tree_kem/parent_hash.rs b/src/tree_kem/parent_hash.rs
new file mode 100644
index 0000000..f04157a
--- /dev/null
+++ b/src/tree_kem/parent_hash.rs
@@ -0,0 +1,431 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::crypto::{CipherSuiteProvider, HpkePublicKey};
+use crate::tree_kem::math as tree_math;
+use crate::tree_kem::node::{LeafIndex, Node, NodeIndex};
+use crate::tree_kem::TreeKemPublic;
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+use tree_math::TreeIndex;
+
+use super::leaf_node::LeafNodeSource;
+
+#[cfg(feature = "std")]
+use std::collections::HashSet;
+
+#[cfg(not(feature = "std"))]
+use alloc::collections::BTreeSet;
+
+/// Value that is MLS-encoded and then hashed by [`ParentHash::new`].
+///
+/// Do not reorder or rename the fields: the derived `MlsEncode` output (presumably in
+/// declaration order) is exactly what gets hashed.
+#[derive(Clone, Debug, MlsSize, MlsEncode)]
+struct ParentHashInput<'a> {
+    /// HPKE public key of the parent node whose hash is being computed.
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    public_key: &'a HpkePublicKey,
+    /// The `parent_hash` value held by (or being assigned to) that same parent node.
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    parent_hash: &'a [u8],
+    /// Tree hash of the parent's child on the copath side; callers in this file pass an entry
+    /// of the tree-hash table indexed by that child's node index.
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    original_sibling_tree_hash: &'a [u8],
+}
+
+/// A parent-hash value: the raw output of the cipher suite hash over a [`ParentHashInput`].
+///
+/// Stored in parent nodes and in leaves whose source is `LeafNodeSource::Commit`. Encoded with
+/// `mls_rs_codec::byte_vec`, and with `mls_rs_core::vec_serde` when the `serde` feature is on.
+#[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq, Eq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct ParentHash(
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+    Vec<u8>,
+);
+
+/// Formats the hash via `mls_rs_core::debug::pretty_bytes`, labelled `ParentHash`.
+impl Debug for ParentHash {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let bytes = mls_rs_core::debug::pretty_bytes(&self.0);
+        bytes.named("ParentHash").fmt(f)
+    }
+}
+
+/// Wraps raw hash bytes as a [`ParentHash`] without copying or validating them.
+impl From<Vec<u8>> for ParentHash {
+    fn from(bytes: Vec<u8>) -> Self {
+        ParentHash(bytes)
+    }
+}
+
+/// Exposes the underlying hash bytes as a `Vec<u8>`.
+impl Deref for ParentHash {
+    type Target = Vec<u8>;
+
+    fn deref(&self) -> &Self::Target {
+        let ParentHash(bytes) = self;
+        bytes
+    }
+}
+
+impl ParentHash {
+    /// Computes the parent hash of a parent node from its HPKE public key, the `parent_hash`
+    /// value associated with it, and the tree hash of its copath child.
+    ///
+    /// # Errors
+    /// Propagates an encoding failure of the hash input, and returns
+    /// [`MlsError::CryptoProviderError`] if the provider's hash function fails.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn new<P: CipherSuiteProvider>(
+        cipher_suite_provider: &P,
+        public_key: &HpkePublicKey,
+        parent_hash: &ParentHash,
+        original_sibling_tree_hash: &[u8],
+    ) -> Result<Self, MlsError> {
+        let encoded = ParentHashInput {
+            public_key,
+            parent_hash,
+            original_sibling_tree_hash,
+        }
+        .mls_encode_to_vec()?;
+
+        cipher_suite_provider
+            .hash(&encoded)
+            .await
+            .map(Self)
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+    }
+
+    /// Returns the zero-length parent hash.
+    pub fn empty() -> Self {
+        Self::from(Vec::new())
+    }
+
+    /// Returns `true` when `hash` is byte-for-byte equal to `self`.
+    pub fn matches(&self, hash: &ParentHash) -> bool {
+        //TODO: Constant time equals
+        self == hash
+    }
+}
+
+impl Node {
+    /// Returns a copy of the parent hash recorded in this node.
+    ///
+    /// Parent nodes always carry one; leaves carry one only when their source is
+    /// `LeafNodeSource::Commit`, otherwise `None` is returned.
+    fn get_parent_hash(&self) -> Option<ParentHash> {
+        let recorded = match self {
+            Node::Parent(parent) => &parent.parent_hash,
+            Node::Leaf(leaf) => match &leaf.leaf_node_source {
+                LeafNodeSource::Commit(hash) => hash,
+                _ => return None,
+            },
+        };
+
+        Some(recorded.clone())
+    }
+}
+
+impl TreeKemPublic {
+    /// Recomputes the parent hashes on the direct path of leaf `index`, writes each one into the
+    /// corresponding parent node, and returns the parent hash the leaf itself must carry.
+    ///
+    /// Copath entries with an empty resolution are skipped, leaving their parent untouched.
+    /// The loop walks `direct_copath(index)` in reverse; presumably that list is ordered from
+    /// the leaf towards the root, making this a top-down pass (TODO confirm).
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn parent_hash_for_leaf<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite_provider: &P,
+        index: LeafIndex,
+    ) -> Result<ParentHash, MlsError> {
+        // Value stored into the next processed parent; the first one processed gets it empty.
+        let mut hash = ParentHash::empty();
+
+        for node in self.nodes.direct_copath(index).into_iter().rev() {
+            if self.nodes.is_resolution_empty(node.copath) {
+                continue;
+            }
+
+            let parent = self.nodes.borrow_as_parent_mut(node.path)?;
+
+            let calculated = ParentHash::new(
+                cipher_suite_provider,
+                &parent.public_key,
+                &hash,
+                &self.tree_hashes.current[node.copath as usize],
+            )
+            .await?;
+
+            // The parent stores the incoming value; its own hash is handed to the next node.
+            (parent.parent_hash, hash) = (hash, calculated);
+        }
+
+        Ok(hash)
+    }
+
+    /// Recomputes all parent hashes on the direct path of leaf `index`.
+    ///
+    /// Tree hashes touched by `index` are refreshed first because parent hashes are computed
+    /// from them. If `verify_leaf_hash` is `true`, the leaf must already hold a
+    /// `LeafNodeSource::Commit` parent hash equal to the computed one; otherwise the computed
+    /// value is written into the leaf. Tree hashes are refreshed again at the end.
+    ///
+    /// Note that the parent nodes on the path are rewritten before the leaf is checked, so a
+    /// verification error leaves them modified.
+    ///
+    /// # Errors
+    /// * [`MlsError::ParentHashMismatch`] if verification is requested and the hashes differ.
+    /// * [`MlsError::InvalidLeafNodeSource`] if verification is requested and the leaf source is
+    ///   not `Commit`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn update_parent_hashes<P: CipherSuiteProvider>(
+        &mut self,
+        index: LeafIndex,
+        verify_leaf_hash: bool,
+        cipher_suite_provider: &P,
+    ) -> Result<(), MlsError> {
+        // First update the relevant original hashes used for parent hash computation.
+        self.update_hashes(&[index], cipher_suite_provider).await?;
+
+        let leaf_hash = self
+            .parent_hash_for_leaf(cipher_suite_provider, index)
+            .await?;
+
+        let leaf = self.nodes.borrow_as_leaf_mut(index)?;
+
+        if verify_leaf_hash {
+            // Verify the parent hash of the new sender leaf node and update the parent hash values
+            // in the local tree
+            if let LeafNodeSource::Commit(parent_hash) = &leaf.leaf_node_source {
+                if !leaf_hash.matches(parent_hash) {
+                    return Err(MlsError::ParentHashMismatch);
+                }
+            } else {
+                return Err(MlsError::InvalidLeafNodeSource);
+            }
+        } else {
+            leaf.leaf_node_source = LeafNodeSource::Commit(leaf_hash);
+        }
+
+        // Update hashes after changes to the tree.
+        self.update_hashes(&[index], cipher_suite_provider).await
+    }
+
+    /// Checks that every non-blank parent node is covered by exactly one valid parent-hash
+    /// chain that starts at a non-empty leaf.
+    ///
+    /// # Errors
+    /// Returns [`MlsError::ParentHashMismatch`] if a parent is reached by more than one chain,
+    /// if the unmerged-leaves/resolution condition fails, or if some non-blank parent is not
+    /// covered by any chain. Tree access and hashing errors are propagated.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(super) async fn validate_parent_hashes<P: CipherSuiteProvider>(
+        &self,
+        cipher_suite_provider: &P,
+    ) -> Result<(), MlsError> {
+        let original_hashes = self.compute_original_hashes(cipher_suite_provider).await?;
+
+        let nodes_to_validate = self
+            .nodes
+            .non_empty_parents()
+            .map(|(node_index, _)| node_index);
+
+        #[cfg(feature = "std")]
+        let mut nodes_to_validate = nodes_to_validate.collect::<HashSet<_>>();
+        #[cfg(not(feature = "std"))]
+        let mut nodes_to_validate = nodes_to_validate.collect::<BTreeSet<_>>();
+
+        let num_leaves = self.total_leaf_count();
+
+        // For each leaf l, validate all non-blank nodes on the chain from l up the tree.
+        for (leaf_index, _) in self.nodes.non_empty_leaves() {
+            let mut n = NodeIndex::from(leaf_index);
+
+            'chain: while let Some(mut ps) = n.parent_sibling(&num_leaves) {
+                // Find the first non-blank ancestor p of n and p's co-path child s.
+                while self.nodes.is_blank(ps.parent)? {
+                    // Passing a blank root ends only this chain. Returning from the function
+                    // here would skip the remaining leaves and the final coverage check.
+                    let Some(ps_parent) = ps.parent.parent_sibling(&num_leaves) else {
+                        break 'chain;
+                    };
+
+                    ps = ps_parent;
+                }
+
+                // Check is n's parent_hash field matches the parent hash of p with co-path child s.
+                let p_parent = self.nodes.borrow_as_parent(ps.parent)?;
+
+                let n_node = self
+                    .nodes
+                    .borrow_node(n)?
+                    .as_ref()
+                    .ok_or(MlsError::ExpectedNode)?;
+
+                let calculated = ParentHash::new(
+                    cipher_suite_provider,
+                    &p_parent.public_key,
+                    &p_parent.parent_hash,
+                    &original_hashes[ps.sibling as usize],
+                )
+                .await?;
+
+                if n_node.get_parent_hash() == Some(calculated) {
+                    // Check that "n is in the resolution of c, and the intersection of p's unmerged_leaves with the subtree
+                    // under c is equal to the resolution of c with n removed".
+                    let Some(cp) = ps.sibling.parent_sibling(&num_leaves) else {
+                        return Err(MlsError::ParentHashMismatch);
+                    };
+
+                    let c = cp.sibling;
+                    let c_resolution = self.nodes.get_resolution_index(c)?.into_iter();
+
+                    #[cfg(feature = "std")]
+                    let mut c_resolution = c_resolution.collect::<HashSet<_>>();
+                    #[cfg(not(feature = "std"))]
+                    let mut c_resolution = c_resolution.collect::<BTreeSet<_>>();
+
+                    // Unmerged leaf indexes are mapped to node indexes (leaf i is node 2 * i).
+                    let p_unmerged_in_c_subtree = self
+                        .unmerged_in_subtree(ps.parent, c)?
+                        .iter()
+                        .copied()
+                        .map(|x| *x * 2);
+
+                    #[cfg(feature = "std")]
+                    let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::<HashSet<_>>();
+                    #[cfg(not(feature = "std"))]
+                    let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::<BTreeSet<_>>();
+
+                    if c_resolution.remove(&n)
+                        && c_resolution == p_unmerged_in_c_subtree
+                        && nodes_to_validate.remove(&ps.parent)
+                    {
+                        // If n's parent_hash field matches and p has not been validated yet, mark p as validated and continue.
+                        n = ps.parent;
+                    } else {
+                        // If p is validated for the second time, the check fails ("all non-blank parent nodes are covered by exactly one such chain").
+                        return Err(MlsError::ParentHashMismatch);
+                    }
+                } else {
+                    // If n's parent_hash field doesn't match, we're done with this chain.
+                    break;
+                }
+            }
+        }
+
+        // The check passes iff all non-blank nodes are validated.
+        if nodes_to_validate.is_empty() {
+            Ok(())
+        } else {
+            Err(MlsError::ParentHashMismatch)
+        }
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+
+    use super::*;
+    use crate::{
+        cipher_suite::CipherSuite,
+        crypto::test_utils::test_cipher_suite_provider,
+        identity::basic::BasicIdentityProvider,
+        tree_kem::{leaf_node::test_utils::get_basic_test_node, node::Parent},
+    };
+
+    use alloc::vec;
+
+    /// Builds a parent value with a freshly generated HPKE public key, an empty parent hash and
+    /// the supplied unmerged leaves.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn test_parent(
+        cipher_suite: CipherSuite,
+        unmerged_leaves: Vec<LeafIndex>,
+    ) -> Parent {
+        let provider = test_cipher_suite_provider(cipher_suite);
+        let (_, public_key) = provider.kem_generate().await.unwrap();
+
+        Parent {
+            public_key,
+            unmerged_leaves,
+            parent_hash: ParentHash::empty(),
+        }
+    }
+
+    /// Same as [`test_parent`], wrapped in `Node::Parent`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn test_parent_node(
+        cipher_suite: CipherSuite,
+        unmerged_leaves: Vec<LeafIndex>,
+    ) -> Node {
+        let parent = test_parent(cipher_suite, unmerged_leaves).await;
+        Node::Parent(parent)
+    }
+
+    /// Builds the tree of figure 12 of the MLS RFC: leaves A..G, the parent nodes below with
+    /// their unmerged leaves, and parent hashes filled in from leaves 0 and 4.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn get_test_tree_fig_12(cipher_suite: CipherSuite) -> TreeKemPublic {
+        let provider = test_cipher_suite_provider(cipher_suite);
+        let mut tree = TreeKemPublic::new();
+
+        let mut leaves = Vec::new();
+        for name in ["A", "B", "C", "D", "E", "F", "G"] {
+            leaves.push(get_basic_test_node(cipher_suite, name).await);
+        }
+
+        tree.add_leaves(leaves, &BasicIdentityProvider, &provider)
+            .await
+            .unwrap();
+
+        // (node index, unmerged leaves) for every populated parent in the figure.
+        let parents = [
+            (1, vec![]),
+            (3, vec![LeafIndex(3)]),
+            (7, vec![LeafIndex(3), LeafIndex(6)]),
+            (9, vec![LeafIndex(5)]),
+            (11, vec![LeafIndex(5), LeafIndex(6)]),
+        ];
+
+        for (node_index, unmerged) in parents {
+            tree.nodes[node_index] = Some(test_parent_node(cipher_suite, unmerged).await);
+        }
+
+        for leaf in [LeafIndex(0), LeafIndex(4)] {
+            tree.update_parent_hashes(leaf, false, &provider)
+                .await
+                .unwrap();
+        }
+
+        tree
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::client::test_utils::TEST_CIPHER_SUITE;
+    use crate::crypto::test_utils::test_cipher_suite_provider;
+    use crate::tree_kem::leaf_node::test_utils::get_basic_test_node;
+    use crate::tree_kem::leaf_node::LeafNodeSource;
+    use crate::tree_kem::test_utils::TreeWithSigners;
+    use crate::tree_kem::MlsError;
+    use assert_matches::assert_matches;
+
+    /// Verifying a leaf whose source is not `Commit` must fail with `InvalidLeafNodeSource`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_missing_parent_hash() {
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let mut tree = TreeWithSigners::make_full_tree(8, &provider).await.tree;
+
+        let replacement = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
+        *tree.nodes.borrow_as_leaf_mut(LeafIndex(0)).unwrap() = replacement;
+
+        let outcome = tree
+            .update_parent_hashes(LeafIndex(0), true, &provider)
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::InvalidLeafNodeSource));
+    }
+
+    /// Verifying a leaf whose recorded parent hash is wrong must fail with `ParentHashMismatch`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_parent_hash_mismatch() {
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let mut tree = TreeWithSigners::make_full_tree(8, &provider).await.tree;
+
+        let bogus = ParentHash::from(hex!("f00d"));
+        let leaf = tree.nodes.borrow_as_leaf_mut(LeafIndex(0)).unwrap();
+        leaf.leaf_node_source = LeafNodeSource::Commit(bogus);
+
+        let outcome = tree
+            .update_parent_hashes(LeafIndex(0), true, &provider)
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::ParentHashMismatch));
+    }
+
+    /// Blanking node 2 (leaf 1) of a full tree must make parent-hash validation fail.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_parent_hash_invalid() {
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let mut tree = TreeWithSigners::make_full_tree(8, &provider).await.tree;
+
+        tree.nodes[2] = None;
+
+        let outcome = tree.validate_parent_hashes(&provider).await;
+
+        assert_matches!(outcome, Err(MlsError::ParentHashMismatch));
+    }
+}
diff --git a/src/tree_kem/path_secret.rs b/src/tree_kem/path_secret.rs
new file mode 100644
index 0000000..c9fce76
--- /dev/null
+++ b/src/tree_kem/path_secret.rs
@@ -0,0 +1,265 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::crypto::{CipherSuiteProvider, HpkePublicKey, HpkeSecretKey};
+use crate::group::key_schedule::kdf_derive_secret;
+use alloc::vec;
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+use zeroize::Zeroizing;
+
+use super::hpke_encryption::HpkeEncryptable;
+
+/// Secret attached to a node of an update path; HPKE key pairs are derived from it by
+/// [`PathSecret::to_hpke_key_pair`].
+///
+/// The bytes are kept in `Zeroizing` so the buffer is wiped when the value is dropped. Encoded
+/// with `mls_rs_codec::byte_vec`, and with `mls_rs_core::zeroizing_serde` under `serde`.
+#[derive(Clone, Eq, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct PathSecret(
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+    Zeroizing<Vec<u8>>,
+);
+
+/// Formats the secret via `mls_rs_core::debug::pretty_bytes`, labelled `PathSecret`.
+///
+/// NOTE(review): whether `pretty_bytes` prints the raw bytes or a redacted form is not visible
+/// here; if it prints them, this output exposes secret material to logs.
+impl Debug for PathSecret {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let bytes = mls_rs_core::debug::pretty_bytes(&self.0);
+        bytes.named("PathSecret").fmt(f)
+    }
+}
+
+/// Exposes the secret bytes held inside the `Zeroizing` wrapper.
+impl Deref for PathSecret {
+    type Target = Vec<u8>;
+
+    fn deref(&self) -> &Self::Target {
+        &*self.0
+    }
+}
+
+/// Takes ownership of plain bytes and wraps them so they are zeroized on drop.
+impl From<Vec<u8>> for PathSecret {
+    fn from(bytes: Vec<u8>) -> Self {
+        Self(Zeroizing::new(bytes))
+    }
+}
+
+/// Wraps bytes that are already in a `Zeroizing` buffer, without copying.
+impl From<Zeroizing<Vec<u8>>> for PathSecret {
+    fn from(secret: Zeroizing<Vec<u8>>) -> Self {
+        Self(secret)
+    }
+}
+
+impl PathSecret {
+    /// Draws a fresh path secret of `kdf_extract_size()` random bytes from the provider.
+    ///
+    /// # Errors
+    /// Returns [`MlsError::CryptoProviderError`] if the provider fails to produce randomness.
+    pub fn random<P: CipherSuiteProvider>(
+        cipher_suite_provider: &P,
+    ) -> Result<PathSecret, MlsError> {
+        let len = cipher_suite_provider.kdf_extract_size();
+
+        match cipher_suite_provider.random_bytes_vec(len) {
+            Ok(bytes) => Ok(PathSecret::from(bytes)),
+            Err(e) => Err(MlsError::CryptoProviderError(e.into_any_error())),
+        }
+    }
+
+    /// Returns an all-zero secret of `kdf_extract_size()` bytes.
+    pub fn empty<P: CipherSuiteProvider>(cipher_suite_provider: &P) -> Self {
+        // Define commit_secret as the all-zero vector of the same length as a path_secret
+        let zeroes = vec![0u8; cipher_suite_provider.kdf_extract_size()];
+        Self::from(zeroes)
+    }
+}
+
+/// Path secrets are HPKE-encrypted under the `"UpdatePathNode"` label.
+impl HpkeEncryptable for PathSecret {
+    const ENCRYPT_LABEL: &'static str = "UpdatePathNode";
+
+    /// Rebuilds a path secret from decrypted bytes, taking ownership of the buffer.
+    fn from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError> {
+        let secret = PathSecret::from(bytes);
+        Ok(secret)
+    }
+
+    /// Returns a copy of the secret bytes.
+    ///
+    /// NOTE(review): the returned `Vec` is a plain copy and is not zeroized on drop.
+    fn get_bytes(&self) -> Result<Vec<u8>, MlsError> {
+        let copy = self.0.to_vec();
+        Ok(copy)
+    }
+}
+
+impl PathSecret {
+    /// Derives the node's HPKE key pair: a node secret is derived from this path secret with
+    /// the `"node"` label, then passed to the provider's `kem_derive`.
+    ///
+    /// # Errors
+    /// Propagates derivation errors and maps `kem_derive` failures to
+    /// [`MlsError::CryptoProviderError`].
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn to_hpke_key_pair<P: CipherSuiteProvider>(
+        &self,
+        cs: &P,
+    ) -> Result<(HpkeSecretKey, HpkePublicKey), MlsError> {
+        let derived = kdf_derive_secret(cs, self, b"node").await?;
+        // Keep the intermediate node secret in a wiping buffer.
+        let node_secret = Zeroizing::new(derived);
+
+        let key_pair = cs.kem_derive(&node_secret).await;
+        key_pair.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+    }
+}
+
+/// Produces a chain of path secrets. The first output is the seed if one was given, otherwise a
+/// random secret; every later output is derived from the previous one with the `"path"` label.
+#[derive(Clone, Debug)]
+pub struct PathSecretGenerator<'a, P> {
+    /// Provider used for randomness and key derivation.
+    cipher_suite_provider: &'a P,
+    /// Most recently returned secret; the next one is derived from it.
+    last: Option<PathSecret>,
+    /// Seed returned as-is by the first `next_secret` call, if set.
+    starting_with: Option<PathSecret>,
+}
+
+impl<'a, P: CipherSuiteProvider> PathSecretGenerator<'a, P> {
+    /// Creates a generator whose first secret will be random.
+    pub fn new(cipher_suite_provider: &'a P) -> Self {
+        PathSecretGenerator {
+            starting_with: None,
+            last: None,
+            cipher_suite_provider,
+        }
+    }
+
+    /// Creates a generator whose first output is `secret` itself.
+    pub fn starting_with(cipher_suite_provider: &'a P, secret: PathSecret) -> Self {
+        let mut generator = Self::new(cipher_suite_provider);
+        generator.starting_with = Some(secret);
+        generator
+    }
+
+    /// Returns the next secret in the chain: the seed if one is pending, otherwise the secret
+    /// derived from the previous output with the `"path"` label, otherwise a random secret.
+    ///
+    /// # Errors
+    /// Propagates derivation and randomness errors from the provider.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn next_secret(&mut self) -> Result<PathSecret, MlsError> {
+        let secret = match self.starting_with.take() {
+            Some(seed) => seed,
+            None => match self.last.take() {
+                Some(previous) => PathSecret::from(
+                    kdf_derive_secret(self.cipher_suite_provider, &previous, b"path").await?,
+                ),
+                None => PathSecret::random(self.cipher_suite_provider)?,
+            },
+        };
+
+        self.last = Some(secret.clone());
+
+        Ok(secret)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{
+        cipher_suite::CipherSuite,
+        client::test_utils::TEST_CIPHER_SUITE,
+        crypto::test_utils::{
+            test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
+        },
+    };
+
+    use super::*;
+
+    use alloc::string::String;
+
+    #[cfg(target_arch = "wasm32")]
+    use wasm_bindgen_test::wasm_bindgen_test as test;
+
+    /// One test vector: a cipher suite id and ten consecutive generator outputs, hex encoded.
+    #[derive(serde::Deserialize, serde::Serialize)]
+    struct TestCase {
+        cipher_suite: u16,
+        generations: Vec<String>,
+    }
+
+    impl TestCase {
+        /// Produces vectors for every cipher suite from a randomly seeded generator.
+        #[cfg(not(mls_build_async))]
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        fn generate() -> Vec<TestCase> {
+            let mut cases = Vec::new();
+
+            for cipher_suite in CipherSuite::all() {
+                let provider = test_cipher_suite_provider(cipher_suite);
+                let mut generator = PathSecretGenerator::new(&provider);
+
+                let mut generations = Vec::new();
+                for _ in 0..10 {
+                    let secret = generator.next_secret().unwrap();
+                    generations.push(hex::encode(&*secret));
+                }
+
+                cases.push(TestCase {
+                    cipher_suite: cipher_suite.into(),
+                    generations,
+                });
+            }
+
+            cases
+        }
+
+        #[cfg(mls_build_async)]
+        fn generate() -> Vec<TestCase> {
+            panic!("Tests cannot be generated in async mode");
+        }
+    }
+
+    fn load_test_cases() -> Vec<TestCase> {
+        load_test_case_json!(path_secret, TestCase::generate())
+    }
+
+    /// Seeding a generator with the first vector entry must reproduce the whole chain.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_path_secret_generation() {
+        for test_case in load_test_cases() {
+            let Some(provider) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
+                continue;
+            };
+
+            let seed = PathSecret::from(hex::decode(&test_case.generations[0]).unwrap());
+            let mut generator = PathSecretGenerator::starting_with(&provider, seed);
+
+            for expected in test_case.generations.iter() {
+                let produced = generator.next_secret().await.unwrap();
+                assert_eq!(expected, &hex::encode(&*produced));
+            }
+        }
+    }
+
+    /// Unseeded generators must start from different random secrets.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_first_path_is_random() {
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let reference = PathSecretGenerator::new(&provider)
+            .next_secret()
+            .await
+            .unwrap();
+
+        for _ in 0..100 {
+            let candidate = PathSecretGenerator::new(&provider)
+                .next_secret()
+                .await
+                .unwrap();
+            assert_ne!(reference, candidate);
+        }
+    }
+
+    /// A seeded generator returns the seed first, then something different.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_starting_with() {
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let seed = PathSecret::random(&provider).unwrap();
+
+        let mut generator = PathSecretGenerator::starting_with(&provider, seed.clone());
+
+        let first = generator.next_secret().await.unwrap();
+        let second = generator.next_secret().await.unwrap();
+
+        assert_eq!(seed, first);
+        assert_ne!(first, second);
+    }
+
+    #[test]
+    fn test_empty_path_secret() {
+        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+            let provider = test_cipher_suite_provider(cipher_suite);
+            let zeroes = vec![0u8; provider.kdf_extract_size()];
+
+            assert_eq!(PathSecret::empty(&provider), PathSecret::from(zeroes))
+        }
+    }
+
+    #[test]
+    fn test_random_path_secret() {
+        let provider = test_cipher_suite_provider(CipherSuite::P256_AES128);
+        let reference = PathSecret::random(&provider).unwrap();
+
+        for _ in 0..100 {
+            let candidate = PathSecret::random(&provider).unwrap();
+            assert_ne!(candidate, reference);
+        }
+    }
+}
diff --git a/src/tree_kem/private.rs b/src/tree_kem/private.rs
new file mode 100644
index 0000000..1cc72ee
--- /dev/null
+++ b/src/tree_kem/private.rs
@@ -0,0 +1,310 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+use alloc::{vec, vec::Vec};
+
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::crypto::HpkeSecretKey;
+
+use crate::{client::MlsError, crypto::CipherSuiteProvider};
+
+use super::{
+ math::leaf_lca_level,
+ node::LeafIndex,
+ path_secret::{PathSecret, PathSecretGenerator},
+ TreeKemPublic,
+};
+
+/// Private key material one member holds for the ratchet tree.
+#[derive(Clone, Debug, MlsEncode, MlsDecode, MlsSize, Eq, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[non_exhaustive]
+pub struct TreeKemPrivate {
+    /// Leaf index of the member owning these keys.
+    pub self_index: LeafIndex,
+    /// HPKE secret keys. Index 0 holds the leaf key; index `i + 1` holds the key for the path
+    /// node of the `i`-th `direct_copath` entry (see `update_secrets`). `None` means unknown.
+    pub secret_keys: Vec<Option<HpkeSecretKey>>,
+}
+
+impl TreeKemPrivate {
+    /// Creates the state of a member that so far only knows its own leaf secret.
+    pub fn new_self_leaf(self_index: LeafIndex, leaf_secret: HpkeSecretKey) -> Self {
+        TreeKemPrivate {
+            self_index,
+            secret_keys: vec![Some(leaf_secret)],
+        }
+    }
+
+    /// Creates a placeholder state with `self_index` 0 and no secret keys at all.
+    pub fn new_for_external() -> Self {
+        TreeKemPrivate {
+            self_index: LeafIndex(0),
+            secret_keys: Default::default(),
+        }
+    }
+
+    /// Stores the secret keys for the direct-path nodes this member shares with `signer_index`,
+    /// derived from `path_secret`.
+    ///
+    /// Starting at the lowest common ancestor of the two leaves, each non-filtered node on the
+    /// direct path receives the next output of a [`PathSecretGenerator`] seeded with
+    /// `path_secret`. The derived public key must equal the one stored in `public_tree`.
+    ///
+    /// # Errors
+    /// Returns [`MlsError::PubKeyMismatch`] if a path node is blank or its public key differs
+    /// from the derived one; tree and crypto provider errors are propagated.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn update_secrets<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite_provider: &P,
+        signer_index: LeafIndex,
+        path_secret: PathSecret,
+        public_tree: &TreeKemPublic,
+    ) -> Result<(), MlsError> {
+        // Identify the lowest common
+        // ancestor of the leaves at index and at GroupInfo.signer_index. Set the private key
+        // for this node to the private key derived from the path_secret.
+        // NOTE(review): confirm `leaf_lca_level` can never return less than 2 here (e.g. when
+        // signer_index == self_index); otherwise this subtraction underflows.
+        let lca_index = leaf_lca_level(self.self_index.into(), signer_index.into()) as usize - 2;
+
+        // For each parent of the common ancestor, up to the root of the tree, derive a new
+        // path secret and set the private key for the node to the private key derived from the
+        // path secret. The private key MUST be the private key that corresponds to the public
+        // key in the node.
+
+        let mut node_secret_gen =
+            PathSecretGenerator::starting_with(cipher_suite_provider, path_secret);
+
+        let path = public_tree.nodes.direct_copath(self.self_index);
+        let filtered = &public_tree.nodes.filtered(self.self_index)?;
+        self.secret_keys.resize(path.len() + 1, None);
+
+        for (i, (n, f)) in path.iter().zip(filtered).enumerate().skip(lca_index) {
+            if *f {
+                continue;
+            }
+
+            let secret = node_secret_gen.next_secret().await?;
+
+            let expected_pub_key = public_tree
+                .nodes
+                .borrow_node(n.path)?
+                .as_ref()
+                .map(|n| n.public_key())
+                .ok_or(MlsError::PubKeyMismatch)?;
+
+            let (secret_key, public_key) = secret.to_hpke_key_pair(cipher_suite_provider).await?;
+
+            if expected_pub_key != &public_key {
+                return Err(MlsError::PubKeyMismatch);
+            }
+
+            // It's ok to use index directly because of the resize above
+            self.secret_keys[i + 1] = Some(secret_key);
+        }
+
+        Ok(())
+    }
+
+    /// Replaces the leaf secret with `new_leaf` and forgets every other stored key.
+    ///
+    /// The number of slots is preserved. A state with no slots at all (as created by
+    /// `new_for_external`) gets a leaf slot instead of panicking on an out-of-bounds index.
+    #[cfg(feature = "by_ref_proposal")]
+    pub fn update_leaf(&mut self, new_leaf: HpkeSecretKey) {
+        self.secret_keys.fill(None);
+
+        match self.secret_keys.first_mut() {
+            Some(leaf_slot) => *leaf_slot = Some(new_leaf),
+            None => self.secret_keys.push(Some(new_leaf)),
+        }
+    }
+}
+
+#[cfg(test)]
+impl TreeKemPrivate {
+    /// Test-only constructor: the given leaf index and no secret keys.
+    pub fn new(self_index: LeafIndex) -> Self {
+        Self {
+            secret_keys: Vec::new(),
+            self_index,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use assert_matches::assert_matches;
+
+    use crate::{
+        cipher_suite::CipherSuite,
+        client::test_utils::TEST_CIPHER_SUITE,
+        crypto::test_utils::test_cipher_suite_provider,
+        group::test_utils::{get_test_group_context, random_bytes},
+        identity::basic::BasicIdentityProvider,
+        tree_kem::{
+            kem::TreeKem,
+            leaf_node::test_utils::{
+                default_properties, get_basic_test_node, get_basic_test_node_sig_key,
+            },
+            math::TreeIndex,
+            node::LeafIndex,
+        },
+    };
+
+    use super::*;
+
+    /// Derives an HPKE secret key for the test cipher suite from 32 random bytes.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn random_hpke_secret_key() -> HpkeSecretKey {
+        let (secret, _) = test_cipher_suite_provider(TEST_CIPHER_SUITE)
+            .kem_derive(&random_bytes(32))
+            .await
+            .unwrap();
+
+        secret
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_create_self_leaf() {
+        let secret = random_hpke_secret_key().await;
+
+        let self_index = LeafIndex(42);
+
+        let private_key = TreeKemPrivate::new_self_leaf(self_index, secret.clone());
+
+        assert_eq!(private_key.self_index, self_index);
+        assert_eq!(private_key.secret_keys.len(), 1);
+        assert_eq!(private_key.secret_keys[0].as_ref().unwrap(), &secret)
+    }
+
+    // Create a ratchet tree for Alice, Bob and Charlie, then let Alice generate an update path.
+    // Returns (public tree, Charlie's private key, Alice's private key, path secret for Charlie).
+    // The ratchet tree returned has leaf indexes as [alice, bob, charlie]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn update_secrets_setup(
+        cipher_suite: CipherSuite,
+    ) -> (TreeKemPublic, TreeKemPrivate, TreeKemPrivate, PathSecret) {
+        let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+        let (alice_leaf, alice_hpke_secret, alice_signing) =
+            get_basic_test_node_sig_key(cipher_suite, "alice").await;
+
+        let bob_leaf = get_basic_test_node(cipher_suite, "bob").await;
+
+        let (charlie_leaf, charlie_hpke_secret, _charlie_signing) =
+            get_basic_test_node_sig_key(cipher_suite, "charlie").await;
+
+        // Create a new public tree with Alice
+        let (mut public_tree, mut alice_private) = TreeKemPublic::derive(
+            alice_leaf,
+            alice_hpke_secret,
+            &BasicIdentityProvider,
+            &Default::default(),
+        )
+        .await
+        .unwrap();
+
+        // Add bob and charlie to the tree
+        public_tree
+            .add_leaves(
+                vec![bob_leaf, charlie_leaf],
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+            )
+            .await
+            .unwrap();
+
+        // Alice's secret key is longer now
+        alice_private.secret_keys.resize(3, None);
+
+        // Generate an update path for Alice. This whole module is test-only, so the last
+        // argument needs no separate `#[cfg(test)]` gate.
+        let encap_gen = TreeKem::new(&mut public_tree, &mut alice_private)
+            .encap(
+                &mut get_test_group_context(42, cipher_suite).await,
+                &[],
+                &alice_signing,
+                default_properties(),
+                None,
+                &cipher_suite_provider,
+                &Default::default(),
+            )
+            .await
+            .unwrap();
+
+        // Get a path secret from Alice for Charlie
+        let path_secret = encap_gen.path_secrets[1].clone().unwrap();
+
+        // Private key for Charlie
+        let charlie_private = TreeKemPrivate::new_self_leaf(LeafIndex(2), charlie_hpke_secret);
+
+        (public_tree, charlie_private, alice_private, path_secret)
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_update_secrets() {
+        let cipher_suite = TEST_CIPHER_SUITE;
+
+        let (public_tree, mut charlie_private, alice_private, path_secret) =
+            update_secrets_setup(cipher_suite).await;
+
+        let existing_private = charlie_private.secret_keys.first().cloned().unwrap();
+
+        // Add the secrets for Charlie to his private key
+        charlie_private
+            .update_secrets(
+                &test_cipher_suite_provider(cipher_suite),
+                LeafIndex(0),
+                path_secret,
+                &public_tree,
+            )
+            .await
+            .unwrap();
+
+        // Make sure that Charlie's private key didn't lose keys
+        assert_eq!(charlie_private.secret_keys.len(), 3);
+
+        // Check that the intersection of the secret keys of Alice and Charlie matches.
+        // The intersection contains only the root.
+        assert_eq!(alice_private.secret_keys[2], charlie_private.secret_keys[2]);
+
+        assert_eq!(
+            charlie_private.secret_keys[0].as_ref(),
+            existing_private.as_ref()
+        );
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_update_secrets_key_mismatch() {
+        let cipher_suite = TEST_CIPHER_SUITE;
+
+        let (mut public_tree, mut charlie_private, _, path_secret) =
+            update_secrets_setup(cipher_suite).await;
+
+        // Sabotage the public tree
+        public_tree
+            .nodes
+            .borrow_as_parent_mut(public_tree.total_leaf_count().root())
+            .unwrap()
+            .public_key = random_bytes(32).into();
+
+        // Add the secrets for Charlie to his private key
+        let res = charlie_private
+            .update_secrets(
+                &test_cipher_suite_provider(cipher_suite),
+                LeafIndex(0),
+                path_secret,
+                &public_tree,
+            )
+            .await;
+
+        assert_matches!(res, Err(MlsError::PubKeyMismatch));
+    }
+
+    /// Builds a private state for `self_index` whose every slot (leaf plus each node on the
+    /// leaf's direct path in a tree of `leaf_count` leaves) holds the same random key.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn setup_direct_path(self_index: LeafIndex, leaf_count: u32) -> TreeKemPrivate {
+        let secret = random_hpke_secret_key().await;
+
+        let mut private_key = TreeKemPrivate::new_self_leaf(self_index, secret.clone());
+
+        // Size from `self_index`'s own direct path rather than leaf 0's.
+        let path_len = crate::tree_kem::node::NodeIndex::from(self_index)
+            .direct_copath(&leaf_count)
+            .len();
+
+        private_key.secret_keys = (0..path_len + 1).map(|_| Some(secret.clone())).collect();
+
+        private_key
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_update_leaf() {
+        let self_leaf = LeafIndex(42);
+        let mut private_key = setup_direct_path(self_leaf, 128).await;
+
+        let new_secret = random_hpke_secret_key().await;
+
+        private_key.update_leaf(new_secret.clone());
+
+        // The update operation should have removed all the other keys in our direct path we
+        // previously added
+        assert!(private_key.secret_keys.iter().skip(1).all(|n| n.is_none()));
+
+        // The secret key for our leaf should have been updated accordingly
+        assert_eq!(private_key.secret_keys.first().unwrap(), &Some(new_secret));
+    }
+}
diff --git a/src/tree_kem/tree_hash.rs b/src/tree_kem/tree_hash.rs
new file mode 100644
index 0000000..d9115e3
--- /dev/null
+++ b/src/tree_kem/tree_hash.rs
@@ -0,0 +1,432 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::leaf_node::LeafNode;
+use super::node::{LeafIndex, NodeVec};
+use super::tree_math::BfsIterTopDown;
+use crate::client::MlsError;
+use crate::crypto::CipherSuiteProvider;
+use crate::tree_kem::math as tree_math;
+use crate::tree_kem::node::Parent;
+use crate::tree_kem::TreeKemPublic;
+use alloc::collections::VecDeque;
+use alloc::vec;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use itertools::Itertools;
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+use tree_math::TreeIndex;
+
+use core::ops::Deref;
+/// A node's tree hash (the raw digest bytes); encoded via the `byte_vec` / `vec_serde` helpers.
+#[derive(Clone, Default, MlsSize, MlsEncode, MlsDecode, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct TreeHash(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ Vec<u8>,
+);
+/// Shows the hash as pretty-printed bytes labelled `TreeHash`.
+impl Debug for TreeHash {
+ fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let bytes: &[u8] = &self.0;
+
+ mls_rs_core::debug::pretty_bytes(bytes).named("TreeHash").fmt(formatter)
+ }
+}
+/// Exposes the raw hash bytes as a slice.
+impl Deref for TreeHash {
+ type Target = [u8];
+
+ fn deref(&self) -> &[u8] {
+ self.0.as_slice()
+ }
+}
+/// Cached tree hashes indexed by node index (`current[n]` is the hash of node `n`).
+#[derive(Clone, Debug, Default, MlsSize, MlsEncode, MlsDecode, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct TreeHashes {
+ pub current: Vec<TreeHash>,
+}
+/// Hash input for a leaf slot; `leaf_node` is `None` when there is no leaf to include.
+#[derive(Debug, MlsSize, MlsEncode)]
+struct LeafNodeHashInput<'a> {
+ leaf_index: LeafIndex,
+ leaf_node: Option<&'a LeafNode>,
+}
+/// Hash input for a parent slot: the (possibly absent) node plus its children's tree hashes.
+#[derive(Debug, MlsSize, MlsEncode)]
+struct ParentNodeTreeHashInput<'a> {
+ parent_node: Option<&'a Parent>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ left_hash: &'a [u8],
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ right_hash: &'a [u8],
+}
+/// Tagged tree-hash input. NOTE(review): the `u8` discriminants presumably become the encoded node-type tag, so keep them fixed.
+#[derive(Debug, MlsSize, MlsEncode)]
+#[repr(u8)]
+enum TreeHashInput<'a> {
+ Leaf(LeafNodeHashInput<'a>) = 1u8,
+ Parent(ParentNodeTreeHashInput<'a>) = 2u8,
+}
+// Tree hash computation for `TreeKemPublic`, cached in `self.tree_hashes`.
+impl TreeKemPublic {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[inline(never)]
+ pub async fn tree_hash<P: CipherSuiteProvider>(
+ &mut self,
+ cipher_suite_provider: &P,
+ ) -> Result<Vec<u8>, MlsError> {
+ self.initialize_hashes(cipher_suite_provider).await?; // no-op when hashes are already cached
+ let root = self.total_leaf_count().root();
+ Ok(self.tree_hashes.current[root as usize].to_vec()) // the tree hash is the root node's hash
+ }
+
+ // Refresh cached hashes after leaves change (e.g. updates and removes): recomputes the
+ // direct paths of `updated_leaves` plus any trailing leaves that are not cached yet.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn update_hashes<P: CipherSuiteProvider>(
+ &mut self,
+ updated_leaves: &[LeafIndex],
+ cipher_suite_provider: &P,
+ ) -> Result<(), MlsError> {
+ let num_leaves = self.total_leaf_count();
+ // Walk back from the last leaf, collecting leaves whose hash slot does not exist yet.
+ let trailing_blanks = (0..num_leaves)
+ .rev()
+ .map_while(|l| {
+ self.tree_hashes
+ .current
+ .get(2 * l as usize)
+ .is_none()
+ .then_some(LeafIndex(l))
+ })
+ .collect::<Vec<_>>();
+
+ // Update the current hashes for direct paths of all modified leaves.
+ tree_hash(
+ &mut self.tree_hashes.current,
+ &self.nodes,
+ Some([updated_leaves, &trailing_blanks].concat()),
+ &[],
+ num_leaves,
+ cipher_suite_provider,
+ )
+ .await?;
+
+ Ok(())
+ }
+
+ // Initialize all hashes after creating / importing a tree.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn initialize_hashes<P>(&mut self, cipher_suite_provider: &P) -> Result<(), MlsError>
+ where
+ P: CipherSuiteProvider,
+ {
+ if self.tree_hashes.current.is_empty() { // an empty cache means nothing was computed yet
+ let num_leaves = self.total_leaf_count();
+ // `None` hashes every leaf; nothing is filtered.
+ tree_hash(
+ &mut self.tree_hashes.current,
+ &self.nodes,
+ None,
+ &[],
+ num_leaves,
+ cipher_suite_provider,
+ )
+ .await?;
+ }
+
+ Ok(())
+ }
+ /// Unmerged leaves of `node_unmerged` that fall in the half-open leaf range of `subtree_root`.
+ pub(crate) fn unmerged_in_subtree(
+ &self,
+ node_unmerged: u32,
+ subtree_root: u32,
+ ) -> Result<&[LeafIndex], MlsError> {
+ let unmerged = &self.nodes.borrow_as_parent(node_unmerged)?.unmerged_leaves;
+ let (left, right) = tree_math::subtree(subtree_root);
+ let mut start = 0; // NOTE(review): both scans assume `unmerged` is sorted by leaf index
+ while start < unmerged.len() && unmerged[start] < left {
+ start += 1;
+ }
+ let mut end = start;
+ while end < unmerged.len() && unmerged[end] < right {
+ end += 1;
+ }
+ Ok(&unmerged[start..end])
+ }
+ /// True when both nodes are non-blank and `ancestor`'s unmerged leaves inside `descendant`'s subtree differ from `descendant`'s own.
+ fn different_unmerged(&self, ancestor: u32, descendant: u32) -> Result<bool, MlsError> {
+ Ok(!self.nodes.is_blank(ancestor)?
+ && !self.nodes.is_blank(descendant)?
+ && self.unmerged_in_subtree(ancestor, descendant)?
+ != self.nodes.borrow_as_parent(descendant)?.unmerged_leaves)
+ }
+ /// Per-node tree hashes with the unmerged leaves of the relevant ancestor filtered out ("original" hashes).
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn compute_original_hashes<P: CipherSuiteProvider>(
+ &self,
+ cipher_suite: &P,
+ ) -> Result<Vec<TreeHash>, MlsError> {
+ let num_leaves = self.nodes.total_leaf_count() as usize;
+ let root = (num_leaves as u32).root();
+ // All per-node arrays are indexed by node index; a tree with n leaves has 2n - 1 nodes.
+ // The value `filtered_sets[n]` is a list of all ancestors `a` of `n` s.t. we have to compute
+ // the tree hash of `n` with the unmerged leaves of `a` filtered out.
+ let mut filtered_sets = vec![vec![]; num_leaves * 2 - 1];
+ filtered_sets[root as usize].push(root);
+ let mut tree_hashes = vec![vec![]; num_leaves * 2 - 1];
+ // `tree_hashes[a]` is only filled for ancestors `a` that get pushed onto some filter set.
+ let bfs_iter = BfsIterTopDown::new(num_leaves).skip(1);
+ // Top-down order so a parent's set is final before children copy it; `skip(1)` presumably drops the root.
+ for n in bfs_iter {
+ let Some(ps) = (n as u32).parent_sibling(&(num_leaves as u32)) else {
+ break;
+ };
+ // Inherit the parent's filter set.
+ let p = ps.parent;
+ filtered_sets[n] = filtered_sets[p as usize].clone();
+ // Also filter by `p` when its unmerged leaves differ from the last filtering ancestor's within `p`.
+ if self.different_unmerged(*filtered_sets[p as usize].last().unwrap(), p)? {
+ filtered_sets[n].push(p);
+
+ // Compute tree hash of `n` without unmerged leaves of `p`. This also computes the tree hash
+ // for any descendants of `n` added to `filtered_sets` later via `clone`.
+ let (start_leaf, end_leaf) = tree_math::subtree(n as u32);
+
+ tree_hash(
+ &mut tree_hashes[p as usize],
+ &self.nodes,
+ Some((*start_leaf..*end_leaf).map(LeafIndex).collect_vec()),
+ &self.nodes.borrow_as_parent(p)?.unmerged_leaves,
+ num_leaves as u32,
+ cipher_suite,
+ )
+ .await?;
+ }
+ }
+
+ // Set the `original_hashes` based on the computed `tree_hashes`.
+ let mut original_hashes = vec![TreeHash::default(); num_leaves * 2 - 1];
+
+ // If root has unmerged leaves, we recompute its original hash. Else, we can use the current hash.
+ let root_original = if !self.nodes.is_blank(root)? && !self.nodes.is_leaf(root) {
+ let root_unmerged = &self.nodes.borrow_as_parent(root)?.unmerged_leaves;
+
+ if !root_unmerged.is_empty() {
+ let mut hashes = vec![];
+
+ tree_hash(
+ &mut hashes,
+ &self.nodes,
+ None,
+ root_unmerged,
+ num_leaves as u32,
+ cipher_suite,
+ )
+ .await?;
+
+ Some(hashes)
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+ // Each node uses its last filtering ancestor; a blank or root ancestor falls back to `root_original` / current.
+ for (i, hash) in original_hashes.iter_mut().enumerate() {
+ let a = filtered_sets[i].last().unwrap();
+ *hash = if self.nodes.is_blank(*a)? || a == &root {
+ if let Some(root_original) = &root_original {
+ root_original[i].clone()
+ } else {
+ self.tree_hashes.current[i].clone()
+ }
+ } else {
+ tree_hashes[*a as usize][i].clone()
+ }
+ }
+
+ Ok(original_hashes)
+ }
+}
+/// Recomputes `hashes` for `leaves_to_update` (all leaves when `None`) and their ancestors, treating `filtered_leaves` as absent.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn tree_hash<P: CipherSuiteProvider>(
+ hashes: &mut Vec<TreeHash>,
+ nodes: &NodeVec,
+ leaves_to_update: Option<Vec<LeafIndex>>,
+ filtered_leaves: &[LeafIndex],
+ num_leaves: u32,
+ cipher_suite_provider: &P,
+) -> Result<(), MlsError> {
+ let leaves_to_update =
+ leaves_to_update.unwrap_or_else(|| (0..num_leaves).map(LeafIndex).collect::<Vec<_>>());
+
+ // Resize the array in case the tree was extended or truncated
+ hashes.resize(num_leaves as usize * 2 - 1, TreeHash::default());
+ // FIFO queue of parent nodes that still need rehashing.
+ let mut node_queue = VecDeque::with_capacity(leaves_to_update.len());
+ // Leaves outside the tree are skipped; filtered leaves, or ones that can't be borrowed (e.g. blank), hash as `None`.
+ for l in leaves_to_update.iter().filter(|l| ***l < num_leaves) {
+ let leaf = (!filtered_leaves.contains(l))
+ .then_some(nodes.borrow_as_leaf(*l).ok())
+ .flatten();
+ // Leaf `l` sits at node index `2 * l`.
+ hashes[2 * **l as usize] = TreeHash(hash_for_leaf(*l, leaf, cipher_suite_provider).await?);
+
+ if let Some(ps) = (2 * **l).parent_sibling(&num_leaves) {
+ node_queue.push_back(ps.parent);
+ }
+ }
+ // Walk upward; a parent shared by several updated leaves is queued and rehashed more than once.
+ while let Some(n) = node_queue.pop_front() {
+ let hash = TreeHash(
+ hash_for_parent(
+ nodes.borrow_as_parent(n).ok(),
+ cipher_suite_provider,
+ filtered_leaves,
+ &hashes[n.left_unchecked() as usize],
+ &hashes[n.right_unchecked() as usize],
+ )
+ .await?,
+ );
+
+ hashes[n as usize] = hash;
+ // NOTE(review): FIFO order hashes children before parents only if all leaves share one depth — confirm.
+ if let Some(ps) = n.parent_sibling(&num_leaves) {
+ node_queue.push_back(ps.parent);
+ }
+ }
+
+ Ok(())
+}
+/// Tree hash of a leaf slot; `leaf_node` is `None` when the caller has no leaf to include.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn hash_for_leaf<P: CipherSuiteProvider>(
+ leaf_index: LeafIndex,
+ leaf_node: Option<&LeafNode>,
+ cipher_suite_provider: &P,
+) -> Result<Vec<u8>, MlsError> {
+ let encoded = TreeHashInput::Leaf(LeafNodeHashInput { leaf_index, leaf_node })
+ .mls_encode_to_vec()?;
+
+ // Digest the tagged encoding with the cipher suite's hash function.
+ let digest = cipher_suite_provider.hash(&encoded).await;
+ digest.map_err(|err| {
+ let any_error = err.into_any_error();
+ MlsError::CryptoProviderError(any_error)
+ })
+}
+/// Tree hash of a parent (or absent) node together with its children's tree hashes.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn hash_for_parent<P: CipherSuiteProvider>(
+ parent_node: Option<&Parent>,
+ cipher_suite_provider: &P,
+ filtered: &[LeafIndex],
+ left_hash: &[u8],
+ right_hash: &[u8],
+) -> Result<Vec<u8>, MlsError> {
+ // Drop `filtered` leaves from `unmerged_leaves`, cloning the parent only when one is present.
+ let filtered_parent = parent_node
+ .filter(|p| p.unmerged_leaves.iter().any(|l| filtered.contains(l)))
+ .map(|p| {
+ let mut p = p.clone();
+ p.unmerged_leaves.retain(|unmerged_index| !filtered.contains(unmerged_index));
+ p
+ });
+ let input = TreeHashInput::Parent(ParentNodeTreeHashInput {
+ parent_node: filtered_parent.as_ref().or(parent_node),
+ left_hash,
+ right_hash,
+ });
+
+ cipher_suite_provider
+ .hash(&input.mls_encode_to_vec()?)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+}
+// Checks `tree_hash` against JSON test vectors loaded via `load_test_case_json!`.
+#[cfg(test)]
+mod tests {
+ use mls_rs_codec::MlsDecode;
+
+ use crate::{
+ cipher_suite::CipherSuite,
+ crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
+ identity::basic::BasicIdentityProvider,
+ tree_kem::{node::NodeVec, parent_hash::test_utils::get_test_tree_fig_12},
+ };
+
+ use super::*;
+ // One vector: cipher suite id, MLS-encoded tree nodes and the expected root tree hash (hex in JSON).
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct TestCase {
+ cipher_suite: u16,
+ #[serde(with = "hex::serde")]
+ tree_data: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ tree_hash: Vec<u8>,
+ }
+ // Vector generation: one case per cipher suite, built from the `get_test_tree_fig_12` tree.
+ impl TestCase {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn generate() -> Vec<TestCase> {
+ let mut test_cases = Vec::new();
+
+ for cipher_suite in CipherSuite::all() {
+ let mut tree = get_test_tree_fig_12(cipher_suite).await;
+
+ test_cases.push(TestCase {
+ cipher_suite: cipher_suite.into(),
+ tree_data: tree.nodes.mls_encode_to_vec().unwrap(),
+ tree_hash: tree
+ .tree_hash(&test_cipher_suite_provider(cipher_suite))
+ .await
+ .unwrap(),
+ })
+ }
+
+ test_cases
+ }
+ }
+ // NOTE(review): `load_test_case_json!` presumably falls back to `generate` when no stored vectors exist.
+ #[cfg(mls_build_async)]
+ async fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(tree_hash, TestCase::generate().await)
+ }
+
+ #[cfg(not(mls_build_async))]
+ fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(tree_hash, TestCase::generate())
+ }
+ // Decodes each vector's tree, recomputes its tree hash and compares it with the stored one.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_tree_hash() {
+ let cases = load_test_cases().await;
+ // Cipher suites without an available provider are skipped.
+ for one_case in cases {
+ let Some(cs_provider) = try_test_cipher_suite_provider(one_case.cipher_suite) else {
+ continue;
+ };
+
+ let mut tree = TreeKemPublic::import_node_data(
+ NodeVec::mls_decode(&mut &*one_case.tree_data).unwrap(),
+ &BasicIdentityProvider,
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ let calculated_hash = tree.tree_hash(&cs_provider).await.unwrap();
+
+ assert_eq!(calculated_hash, one_case.tree_hash);
+ }
+ }
+}
diff --git a/src/tree_kem/tree_index.rs b/src/tree_kem/tree_index.rs
new file mode 100644
index 0000000..4e6731a
--- /dev/null
+++ b/src/tree_kem/tree_index.rs
@@ -0,0 +1,505 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::*;
+#[cfg(feature = "tree_index")]
+use core::fmt::{self, Debug};
+
+#[cfg(all(feature = "tree_index", feature = "custom_proposal"))]
+use crate::group::proposal::ProposalType;
+
+#[cfg(feature = "tree_index")]
+use crate::identity::CredentialType;
+
+#[cfg(feature = "tree_index")]
+use mls_rs_core::crypto::SignaturePublicKey;
+
+#[cfg(all(feature = "tree_index", feature = "std"))]
+use itertools::Itertools;
+
+#[cfg(all(feature = "tree_index", not(feature = "std")))]
+use alloc::collections::{btree_map::Entry, BTreeMap};
+
+#[cfg(all(feature = "tree_index", feature = "std"))]
+use std::collections::{hash_map::Entry, HashMap};
+
+#[cfg(all(feature = "tree_index", not(feature = "std")))]
+use alloc::collections::BTreeSet;
+
+#[cfg(feature = "tree_index")]
+use mls_rs_core::crypto::HpkePublicKey;
+/// Identity bytes (as returned by the identity provider) used as a lookup key in `TreeIndex`.
+#[cfg(feature = "tree_index")]
+#[derive(Clone, Default, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, Hash, PartialOrd, Ord)]
+pub struct Identifier(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>);
+/// Shows the identity as pretty-printed bytes labelled `Identifier`.
+#[cfg(feature = "tree_index")]
+impl Debug for Identifier {
+ fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let bytes: &[u8] = &self.0;
+
+ mls_rs_core::debug::pretty_bytes(bytes).named("Identifier").fmt(formatter)
+ }
+}
+/// Leaf lookup tables and capability counters (`std` build, backed by `HashMap`).
+#[cfg(all(feature = "tree_index", feature = "std"))]
+#[derive(Clone, Debug, Default, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+pub struct TreeIndex {
+ credential_signature_key: HashMap<SignaturePublicKey, LeafIndex>,
+ hpke_key: HashMap<HpkePublicKey, LeafIndex>,
+ identities: HashMap<Identifier, LeafIndex>,
+ credential_type_counters: HashMap<CredentialType, TypeCounter>, // per type: leaves supporting / using it
+ #[cfg(feature = "custom_proposal")]
+ proposal_type_counter: HashMap<ProposalType, u32>, // leaves whose capabilities list each proposal type
+}
+/// Leaf lookup tables and capability counters (`no_std` build, backed by `BTreeMap`).
+#[cfg(all(feature = "tree_index", not(feature = "std")))]
+#[derive(Clone, Debug, Default, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+pub struct TreeIndex {
+ credential_signature_key: BTreeMap<SignaturePublicKey, LeafIndex>,
+ hpke_key: BTreeMap<HpkePublicKey, LeafIndex>,
+ identities: BTreeMap<Identifier, LeafIndex>,
+ credential_type_counters: BTreeMap<CredentialType, TypeCounter>, // per type: leaves supporting / using it
+ #[cfg(feature = "custom_proposal")]
+ proposal_type_counter: BTreeMap<ProposalType, u32>, // leaves whose capabilities list each proposal type
+}
+/// Resolves `new_leaf`'s identity and inserts the leaf into `tree_index` (which rejects duplicates).
+#[cfg(feature = "tree_index")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(super) async fn index_insert<I: IdentityProvider>(
+ tree_index: &mut TreeIndex,
+ new_leaf: &LeafNode,
+ new_leaf_idx: LeafIndex,
+ id_provider: &I,
+ extensions: &ExtensionList,
+) -> Result<(), MlsError> {
+ let identity_result = id_provider.identity(&new_leaf.signing_identity, extensions).await;
+
+ // Provider failures surface as `IdentityProviderError`.
+ let identity = identity_result.map_err(|err| MlsError::IdentityProviderError(err.into_any_error()))?;
+
+ tree_index.insert(new_leaf_idx, new_leaf, identity)
+}
+/// Fallback without `TreeIndex`: checks `new_leaf` against every other non-empty leaf in `nodes`.
+#[cfg(not(feature = "tree_index"))]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(super) async fn index_insert<I: IdentityProvider>(
+ nodes: &NodeVec,
+ new_leaf: &LeafNode,
+ new_leaf_idx: LeafIndex,
+ id_provider: &I,
+ extensions: &ExtensionList,
+) -> Result<(), MlsError> {
+ let new_id = id_provider
+ .identity(&new_leaf.signing_identity, extensions)
+ .await
+ .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+
+ // Loop-invariant: the credential type the new leaf itself uses.
+ let new_cred_type = new_leaf.signing_identity.credential.credential_type();
+
+ for (i, leaf) in nodes.non_empty_leaves() {
+ // The slot being filled is not compared against itself.
+ if i == new_leaf_idx {
+ continue;
+ }
+ // HPKE and signature keys must be unique across leaves.
+ if new_leaf.public_key == leaf.public_key
+ || new_leaf.signing_identity.signature_key == leaf.signing_identity.signature_key
+ {
+ return Err(MlsError::DuplicateLeafData(*i));
+ }
+
+ let id = id_provider
+ .identity(&leaf.signing_identity, extensions)
+ .await
+ .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+
+ if new_id == id {
+ return Err(MlsError::DuplicateLeafData(*i));
+ }
+
+ // Each leaf must support the credential type the other one uses.
+ let cred_type = leaf.signing_identity.credential.credential_type();
+
+ if !new_leaf.capabilities.credentials.contains(&cred_type) {
+ return Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf);
+ }
+
+ if !leaf.capabilities.credentials.contains(&new_cred_type) {
+ return Err(MlsError::CredentialTypeOfNewLeafIsUnsupported);
+ }
+ }
+
+ Ok(())
+}
+// `TreeIndex` bookkeeping: uniqueness of leaf keys/identities and capability support counts.
+#[cfg(feature = "tree_index")]
+impl TreeIndex {
+ pub fn new() -> Self {
+ Default::default()
+ }
+ /// True once at least one identity has been recorded.
+ pub fn is_initialized(&self) -> bool {
+ !self.identities.is_empty()
+ }
+ /// Records `leaf_node` at `index`; on any error the index is left unchanged.
+ fn insert(
+ &mut self,
+ index: LeafIndex,
+ leaf_node: &LeafNode,
+ identity: Vec<u8>,
+ ) -> Result<(), MlsError> {
+ let old_leaf_count = self.credential_signature_key.len();
+ // Signature key, HPKE key and identity must each be unused by other leaves.
+ let pub_key = leaf_node.signing_identity.signature_key.clone();
+ let credential_entry = self.credential_signature_key.entry(pub_key);
+
+ if let Entry::Occupied(entry) = credential_entry {
+ return Err(MlsError::DuplicateLeafData(**entry.get()));
+ }
+
+ let hpke_entry = self.hpke_key.entry(leaf_node.public_key.clone());
+
+ if let Entry::Occupied(entry) = hpke_entry {
+ return Err(MlsError::DuplicateLeafData(**entry.get()));
+ }
+
+ let identity_entry = self.identities.entry(Identifier(identity));
+ if let Entry::Occupied(entry) = identity_entry {
+ return Err(MlsError::DuplicateLeafData(**entry.get()));
+ }
+ // The new leaf must support every credential type currently in use.
+ let in_use_cred_type_unsupported_by_new_leaf = self
+ .credential_type_counters
+ .iter()
+ .filter_map(|(cred_type, counters)| Some(*cred_type).filter(|_| counters.used > 0))
+ .find(|cred_type| !leaf_node.capabilities.credentials.contains(cred_type));
+
+ if in_use_cred_type_unsupported_by_new_leaf.is_some() {
+ return Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf);
+ }
+ // Every existing leaf must support the new leaf's credential type.
+ let new_leaf_cred_type = leaf_node.signing_identity.credential.credential_type();
+ // Read the counter without inserting, so a rejected leaf does not leave an empty entry behind.
+ let supported_by = self
+ .credential_type_counters
+ .get(&new_leaf_cred_type)
+ .map_or(0, |counters| counters.supported);
+
+ if supported_by != old_leaf_count as u32 {
+ return Err(MlsError::CredentialTypeOfNewLeafIsUnsupported);
+ }
+ // All checks passed; nothing below can fail, so the index is mutated from here on.
+ self.credential_type_counters.entry(new_leaf_cred_type).or_default().used += 1;
+
+ let credential_type_iter = leaf_node.capabilities.credentials.iter().copied();
+
+ #[cfg(feature = "std")]
+ let credential_type_iter = credential_type_iter.unique();
+
+ #[cfg(not(feature = "std"))]
+ let credential_type_iter = credential_type_iter.collect::<BTreeSet<_>>().into_iter();
+
+ // Count each distinct supported credential type once.
+ credential_type_iter.for_each(|cred_type| {
+ self.credential_type_counters
+ .entry(cred_type)
+ .or_default()
+ .supported += 1;
+ });
+
+ #[cfg(feature = "custom_proposal")]
+ {
+ let proposal_type_iter = leaf_node.capabilities.proposals.iter().copied();
+
+ #[cfg(feature = "std")]
+ let proposal_type_iter = proposal_type_iter.unique();
+
+ #[cfg(not(feature = "std"))]
+ let proposal_type_iter = proposal_type_iter.collect::<BTreeSet<_>>().into_iter();
+
+ // Proposal type counter update
+ proposal_type_iter.for_each(|proposal_type| {
+ *self.proposal_type_counter.entry(proposal_type).or_default() += 1;
+ });
+ }
+ // Finally claim the identity, signature key and HPKE key for `index`.
+ identity_entry.or_insert(index);
+ credential_entry.or_insert(index);
+ hpke_entry.or_insert(index);
+
+ Ok(())
+ }
+ /// Leaf index registered for `identity`, if any.
+ pub(crate) fn get_leaf_index_with_identity(&self, identity: &[u8]) -> Option<LeafIndex> {
+ self.identities.get(&Identifier(identity.to_vec())).copied()
+ }
+ /// Forgets `leaf_node`'s keys; counters change only if `identity` was indexed.
+ pub fn remove(&mut self, leaf_node: &LeafNode, identity: &[u8]) {
+ let existed = self
+ .identities
+ .remove(&Identifier(identity.to_vec()))
+ .is_some();
+
+ self.credential_signature_key
+ .remove(&leaf_node.signing_identity.signature_key);
+
+ self.hpke_key.remove(&leaf_node.public_key);
+ // Keys are dropped unconditionally; counters only when the identity was present.
+ if !existed {
+ return;
+ }
+
+ // Decrement credential type counters
+ let leaf_cred_type = leaf_node.signing_identity.credential.credential_type();
+ // NOTE(review): assumes `leaf_node` is the node that was inserted; a mismatch could underflow.
+ if let Some(counters) = self.credential_type_counters.get_mut(&leaf_cred_type) {
+ counters.used -= 1;
+ }
+
+ let credential_type_iter = leaf_node.capabilities.credentials.iter();
+
+ #[cfg(feature = "std")]
+ let credential_type_iter = credential_type_iter.unique();
+
+ #[cfg(not(feature = "std"))]
+ let credential_type_iter = credential_type_iter.collect::<BTreeSet<_>>().into_iter();
+
+ credential_type_iter.for_each(|cred_type| {
+ if let Some(counters) = self.credential_type_counters.get_mut(cred_type) {
+ counters.supported -= 1;
+ }
+ });
+
+ #[cfg(feature = "custom_proposal")]
+ {
+ let proposal_type_iter = leaf_node.capabilities.proposals.iter();
+
+ #[cfg(feature = "std")]
+ let proposal_type_iter = proposal_type_iter.unique();
+
+ #[cfg(not(feature = "std"))]
+ let proposal_type_iter = proposal_type_iter.collect::<BTreeSet<_>>().into_iter();
+
+ // Decrement proposal type counters
+ proposal_type_iter.for_each(|proposal_type| {
+ if let Some(supported) = self.proposal_type_counter.get_mut(proposal_type) {
+ *supported -= 1;
+ }
+ })
+ }
+ }
+ /// Number of indexed leaves whose capabilities include `proposal_type`.
+ #[cfg(feature = "custom_proposal")]
+ pub fn count_supporting_proposal(&self, proposal_type: ProposalType) -> u32 {
+ self.proposal_type_counter
+ .get(&proposal_type)
+ .copied()
+ .unwrap_or_default()
+ }
+ /// Number of indexed leaves (one signature key each).
+ #[cfg(test)]
+ pub fn len(&self) -> usize {
+ self.credential_signature_key.len()
+ }
+}
+/// Leaf counts for one credential type: how many list it as supported and how many use it.
+#[cfg(feature = "tree_index")]
+#[derive(Clone, Debug, Default, PartialEq, MlsEncode, MlsDecode, MlsSize)]
+struct TypeCounter {
+ supported: u32,
+ used: u32,
+}
+// Unit tests for `TreeIndex` insertion, duplicate rejection, removal and proposal counters.
+#[cfg(feature = "tree_index")]
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{
+ client::test_utils::TEST_CIPHER_SUITE,
+ tree_kem::leaf_node::test_utils::{get_basic_test_node, get_test_client_identity},
+ };
+ use alloc::format;
+ use assert_matches::assert_matches;
+
+ #[derive(Clone, Debug)]
+ struct TestData {
+ pub leaf_node: LeafNode,
+ pub index: LeafIndex,
+ }
+ /// Basic test leaf named `foo{index}`, paired with its intended leaf index.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn get_test_data(index: LeafIndex) -> TestData {
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let leaf_node = get_basic_test_node(cipher_suite, &format!("foo{}", index.0)).await;
+
+ TestData { leaf_node, index }
+ }
+ /// Ten test leaves (indices 0..10) and an index with all of them inserted.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_setup() -> (Vec<TestData>, TreeIndex) {
+ let mut test_data = Vec::new();
+
+ for i in 0..10 {
+ test_data.push(get_test_data(LeafIndex(i)).await);
+ }
+
+ let mut test_index = TreeIndex::new();
+
+ test_data.clone().into_iter().for_each(|d| {
+ test_index
+ .insert(
+ d.index,
+ &d.leaf_node,
+ get_test_client_identity(&d.leaf_node),
+ )
+ .unwrap()
+ });
+
+ (test_data, test_index)
+ }
+ /// Every inserted signature key and HPKE key maps back to its leaf index.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_insert() {
+ let (test_data, test_index) = test_setup().await;
+
+ assert_eq!(test_index.credential_signature_key.len(), test_data.len());
+ assert_eq!(test_index.hpke_key.len(), test_data.len());
+
+ test_data.into_iter().enumerate().for_each(|(i, d)| {
+ let pub_key = d.leaf_node.signing_identity.signature_key;
+
+ assert_eq!(
+ test_index.credential_signature_key.get(&pub_key),
+ Some(&LeafIndex(i as u32))
+ );
+
+ assert_eq!(
+ test_index.hpke_key.get(&d.leaf_node.public_key),
+ Some(&LeafIndex(i as u32))
+ );
+ })
+ }
+ /// Reusing leaf 1's signing identity is rejected with index 1 and leaves the index unchanged.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_insert_duplicate_credential_key() {
+ let (test_data, mut test_index) = test_setup().await;
+
+ let before_error = test_index.clone();
+
+ let mut new_key_package = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
+ new_key_package.signing_identity = test_data[1].leaf_node.signing_identity.clone();
+
+ let res = test_index.insert(
+ test_data[1].index,
+ &new_key_package,
+ get_test_client_identity(&new_key_package),
+ );
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(index))
+ if index == *test_data[1].index);
+
+ assert_eq!(before_error, test_index);
+ }
+ /// Reusing leaf 1's HPKE key is rejected with index 1 and leaves the index unchanged.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_insert_duplicate_hpke_key() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let (test_data, mut test_index) = test_setup().await;
+ let before_error = test_index.clone();
+
+ let mut new_leaf_node = get_basic_test_node(cipher_suite, "foo").await;
+ new_leaf_node.public_key = test_data[1].leaf_node.public_key.clone();
+
+ let res = test_index.insert(
+ test_data[1].index,
+ &new_leaf_node,
+ get_test_client_identity(&new_leaf_node),
+ );
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(index))
+ if index == *test_data[1].index);
+
+ assert_eq!(before_error, test_index);
+ }
+ /// Removing a leaf drops both of its keys from the index.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_remove() {
+ let (test_data, mut test_index) = test_setup().await;
+
+ test_index.remove(
+ &test_data[1].leaf_node,
+ &get_test_client_identity(&test_data[1].leaf_node),
+ );
+
+ assert_eq!(
+ test_index.credential_signature_key.len(),
+ test_data.len() - 1
+ );
+
+ assert_eq!(test_index.hpke_key.len(), test_data.len() - 1);
+
+ assert_eq!(
+ test_index
+ .credential_signature_key
+ .get(&test_data[1].leaf_node.signing_identity.signature_key),
+ None
+ );
+
+ assert_eq!(
+ test_index.hpke_key.get(&test_data[1].leaf_node.public_key),
+ None
+ );
+ }
+ /// Proposal-type support counters follow inserts and removals.
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposals() {
+ let test_proposal_id = ProposalType::new(42);
+ let other_proposal_id = ProposalType::new(45);
+
+ let mut test_data_1 = get_test_data(LeafIndex(0)).await;
+
+ test_data_1
+ .leaf_node
+ .capabilities
+ .proposals
+ .push(test_proposal_id);
+
+ let mut test_data_2 = get_test_data(LeafIndex(1)).await;
+
+ test_data_2
+ .leaf_node
+ .capabilities
+ .proposals
+ .push(test_proposal_id);
+
+ test_data_2
+ .leaf_node
+ .capabilities
+ .proposals
+ .push(other_proposal_id);
+
+ let mut test_index = TreeIndex::new();
+
+ test_index
+ .insert(test_data_1.index, &test_data_1.leaf_node, vec![0])
+ .unwrap();
+
+ assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 1);
+
+ test_index
+ .insert(test_data_2.index, &test_data_2.leaf_node, vec![1])
+ .unwrap();
+
+ assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 2);
+ assert_eq!(test_index.count_supporting_proposal(other_proposal_id), 1);
+
+ test_index.remove(&test_data_2.leaf_node, &[1]);
+
+ assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 1);
+ assert_eq!(test_index.count_supporting_proposal(other_proposal_id), 0);
+ }
+}
diff --git a/src/tree_kem/tree_utils.rs b/src/tree_kem/tree_utils.rs
new file mode 100644
index 0000000..e7cdeb1
--- /dev/null
+++ b/src/tree_kem/tree_utils.rs
@@ -0,0 +1,191 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::string::String;
+use alloc::{format, vec};
+use core::borrow::BorrowMut;
+
+use debug_tree::TreeBuilder;
+
+use super::node::{NodeIndex, NodeVec};
+use crate::{client::MlsError, tree_kem::math::TreeIndex};
+/// Appends node `idx` (and, for parent nodes, its whole subtree) to `tree`.
+pub(crate) fn build_tree(
+ tree: &mut TreeBuilder,
+ nodes: &NodeVec,
+ idx: NodeIndex,
+) -> Result<(), MlsError> {
+ let blank_tag = match nodes.is_blank(idx)? {
+ true => "Blank ",
+ false => "",
+ };
+
+ // Leaves are terminal entries.
+ if nodes.is_leaf(idx) {
+ tree.add_leaf(&format!("{blank_tag}Leaf ({idx})"));
+ return Ok(());
+ }
+
+ // Parent nodes: the top of the tree is labelled "Root" instead of "Parent".
+ let mut parent_tag = if nodes.total_leaf_count().root() == idx {
+ format!("{blank_tag}Root ({idx})")
+ } else {
+ format!("{blank_tag}Parent ({idx})")
+ };
+
+ let unmerged_leaves_idxs = if let Ok(parent) = nodes.borrow_as_parent(idx) {
+ parent
+ .unmerged_leaves
+ .iter()
+ .map(|leaf_idx| format!("{}", leaf_idx.0))
+ .collect()
+ } else {
+ // Blank parent nodes cannot be borrowed as `Parent`; they list no unmerged leaves.
+ vec![]
+ };
+
+ if !unmerged_leaves_idxs.is_empty() {
+ parent_tag.push_str(&format!(" unmerged leaves idxs: {}", unmerged_leaves_idxs.join(",")));
+ }
+
+ // Nodes added before `release()` are nested under this branch.
+ let mut branch = tree.add_branch(&parent_tag);
+
+ // `idx` is not a leaf, so both unchecked child lookups are valid.
+ for child in [idx.left_unchecked(), idx.right_unchecked()] {
+ build_tree(tree, nodes, child)?;
+ }
+
+ branch.release();
+
+ Ok(())
+}
+
+pub(crate) fn build_ascii_tree(nodes: &NodeVec) -> String {
+ let leaves_count: u32 = nodes.total_leaf_count();
+ let mut tree = TreeBuilder::new();
+ build_tree(tree.borrow_mut(), nodes, leaves_count.root()).unwrap();
+ tree.string()
+}
+// Rendering tests for `build_ascii_tree` on small test trees.
+#[cfg(test)]
+mod tests {
+ use alloc::vec;
+
+ use crate::{
+ client::test_utils::TEST_CIPHER_SUITE,
+ crypto::test_utils::test_cipher_suite_provider,
+ identity::basic::BasicIdentityProvider,
+ tree_kem::{
+ node::Parent,
+ parent_hash::ParentHash,
+ test_utils::{get_test_leaf_nodes, get_test_tree},
+ },
+ };
+
+ use super::build_ascii_tree;
+ /// All leaf slots populated; every parent node is blank.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn print_fully_populated_tree() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ // Create a tree
+ let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+ let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ let tree_str = concat!(
+ "Blank Root (3)\n",
+ "├╼ Blank Parent (1)\n",
+ "│ ├╼ Leaf (0)\n",
+ "│ └╼ Leaf (2)\n",
+ "└╼ Blank Parent (5)\n",
+ " ├╼ Leaf (4)\n",
+ " └╼ Leaf (6)",
+ );
+
+ assert_eq!(tree_str, build_ascii_tree(&tree.nodes));
+ }
+ /// Removing the first newly added leaf shows node 2 as a blank leaf.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn print_tree_blank_leaves() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ // Create a tree
+ let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+ let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ let to_remove = tree
+ .add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
+ .await
+ .unwrap()[0];
+
+ tree.remove_leaves(
+ vec![to_remove],
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ let tree_str = concat!(
+ "Blank Root (3)\n",
+ "├╼ Blank Parent (1)\n",
+ "│ ├╼ Leaf (0)\n",
+ "│ └╼ Blank Leaf (2)\n",
+ "└╼ Blank Parent (5)\n",
+ " ├╼ Leaf (4)\n",
+ " └╼ Leaf (6)",
+ );
+
+ assert_eq!(tree_str, build_ascii_tree(&tree.nodes));
+ }
+ /// A non-blank root lists the leaf added after it as unmerged.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn print_tree_unmerged_leaves_on_parent() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ // Create a tree
+ let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+ let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ tree.add_leaves(
+ [key_packages[0].clone(), key_packages[1].clone()].to_vec(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+ // Make the root (node 3) a non-blank parent before adding one more leaf.
+ tree.nodes[3] = Parent {
+ public_key: vec![].into(),
+ parent_hash: ParentHash::empty(),
+ unmerged_leaves: vec![],
+ }
+ .into();
+
+ tree.add_leaves(
+ [key_packages[2].clone()].to_vec(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ let tree_str = concat!(
+ "Root (3) unmerged leaves idxs: 3\n",
+ "├╼ Blank Parent (1)\n",
+ "│ ├╼ Leaf (0)\n",
+ "│ └╼ Leaf (2)\n",
+ "└╼ Blank Parent (5)\n",
+ " ├╼ Leaf (4)\n",
+ " └╼ Leaf (6)",
+ );
+
+ assert_eq!(tree_str, build_ascii_tree(&tree.nodes));
+ }
+}
diff --git a/src/tree_kem/tree_validator.rs b/src/tree_kem/tree_validator.rs
new file mode 100644
index 0000000..26d4baf
--- /dev/null
+++ b/src/tree_kem/tree_validator.rs
@@ -0,0 +1,356 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+#[cfg(feature = "std")]
+use std::collections::HashSet;
+
+#[cfg(not(feature = "std"))]
+use alloc::{vec, vec::Vec};
+use tree_math::TreeIndex;
+
+use super::node::{Node, NodeIndex};
+use crate::client::MlsError;
+use crate::crypto::CipherSuiteProvider;
+use crate::group::GroupContext;
+use crate::iter::wrap_impl_iter;
+use crate::tree_kem::math as tree_math;
+use crate::tree_kem::{leaf_node_validator::LeafNodeValidator, TreeKemPublic};
+use mls_rs_core::identity::IdentityProvider;
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+use rayon::prelude::*;
+
+#[cfg(mls_build_async)]
+use futures::{StreamExt, TryStreamExt};
+
+/// Validates a ratchet tree against the group context it is supposed to belong to.
+///
+/// All state is borrowed: the expected tree hash and group id come from the group
+/// context, and leaf checks are delegated to a [`LeafNodeValidator`].
+pub(crate) struct TreeValidator<'a, C: IdentityProvider, CSP: CipherSuiteProvider> {
+    /// Hash the validated tree must produce (the group context's `tree_hash`).
+    expected_tree_hash: &'a [u8],
+    /// Used to revalidate every non-blank leaf of the tree.
+    leaf_node_validator: LeafNodeValidator<'a, C, CSP>,
+    /// Group id passed to each leaf revalidation.
+    group_id: &'a [u8],
+    /// Provider used for tree hashing and parent-hash verification.
+    cipher_suite_provider: &'a CSP,
+}
+
+impl<'a, C: IdentityProvider, CSP: CipherSuiteProvider> TreeValidator<'a, C, CSP> {
+ /// Creates a validator for trees of the group described by `context`.
+ ///
+ /// The expected tree hash and the group id are borrowed from `context`. The leaf node
+ /// validator is built from `cipher_suite_provider`, `identity_provider` and the
+ /// context's extensions.
+ pub fn new(
+ cipher_suite_provider: &'a CSP,
+ context: &'a GroupContext,
+ identity_provider: &'a C,
+ ) -> Self {
+ TreeValidator {
+ expected_tree_hash: &context.tree_hash,
+ leaf_node_validator: LeafNodeValidator::new(
+ cipher_suite_provider,
+ identity_provider,
+ Some(&context.extensions),
+ ),
+ group_id: &context.group_id,
+ cipher_suite_provider,
+ }
+ }
+
+ /// Runs every tree check, stopping at the first failure, in this order:
+ ///
+ /// 1. the tree hash equals the context's `tree_hash` (`MlsError::TreeHashMismatch`);
+ /// 2. `TreeKemPublic::validate_parent_hashes` succeeds;
+ /// 3. the tree is non-empty and does not end with a blank node
+ ///    (`MlsError::UnexpectedEmptyTree` / `MlsError::UnexpectedTrailingBlanks`);
+ /// 4. every non-blank leaf passes `LeafNodeValidator::revalidate`;
+ /// 5. the unmerged-leaf lists are consistent (`validate_unmerged`).
+ ///
+ /// `tree` is borrowed mutably because `validate_tree_hash` takes `&mut TreeKemPublic`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn validate(&self, tree: &mut TreeKemPublic) -> Result<(), MlsError> {
+ self.validate_tree_hash(tree).await?;
+
+ tree.validate_parent_hashes(self.cipher_suite_provider)
+ .await?;
+
+ self.validate_no_trailing_blanks(tree)?;
+ self.validate_leaves(tree).await?;
+ validate_unmerged(tree)
+ }
+
+ /// Ensures the node list is non-empty and its last node is not blank.
+ ///
+ /// Returns `MlsError::UnexpectedEmptyTree` when there are no nodes and
+ /// `MlsError::UnexpectedTrailingBlanks` when the last node is blank (`None`).
+ fn validate_no_trailing_blanks(&self, tree: &TreeKemPublic) -> Result<(), MlsError> {
+ tree.nodes
+ .last()
+ .ok_or(MlsError::UnexpectedEmptyTree)?
+ .is_some()
+ .then_some(())
+ .ok_or(MlsError::UnexpectedTrailingBlanks)
+ }
+
+ /// Recomputes the tree hash and compares it with the expected hash taken from the
+ /// group context, returning `MlsError::TreeHashMismatch` if they differ.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn validate_tree_hash(&self, tree: &mut TreeKemPublic) -> Result<(), MlsError> {
+ // Verify that the tree hash of the ratchet tree matches the expected hash (the
+ // `tree_hash` field of the group context).
+ let tree_hash = tree.tree_hash(self.cipher_suite_provider).await?;
+
+ if tree_hash != self.expected_tree_hash {
+ return Err(MlsError::TreeHashMismatch);
+ }
+
+ Ok(())
+ }
+
+ /// Revalidates every non-blank leaf with `LeafNodeValidator::revalidate`, passing the
+ /// group id and the leaf's index. Returns an error if any leaf fails.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn validate_leaves(&self, tree: &TreeKemPublic) -> Result<(), MlsError> {
+ // NOTE(review): `wrap_impl_iter` presumably yields a stream in async builds and may
+ // yield a rayon parallel iterator with the `rayon` feature; confirm in `crate::iter`.
+ let leaves = wrap_impl_iter(tree.nodes.non_empty_leaves());
+
+ // In async builds, items are wrapped in `Ok` so that the fallible stream combinator
+ // `try_for_each` (from the `TryStreamExt` import) can consume them.
+ #[cfg(mls_build_async)]
+ let leaves = leaves.map(Ok);
+
+ // NOTE(review): the braces move `leaves` into a block expression before the method
+ // call; the reason is not evident from this file, so keep the construct as is.
+ { leaves }
+ .try_for_each(|(index, leaf_node)| async move {
+ self.leaf_node_validator
+ .revalidate(leaf_node, self.group_id, *index)
+ .await
+ })
+ .await
+ }
+}
+
+/// Checks that every parent node's `unmerged_leaves` list is consistent with the tree.
+///
+/// A leaf `L` may be listed as unmerged at a parent `P` only if `L` is non-blank, `P` lies
+/// on `L`'s direct path, and every non-blank node between `L` and `P` also lists `L`.
+/// Any listing that violates this yields `MlsError::UnmergedLeavesMismatch`. Errors from
+/// node lookups (`is_blank`, `borrow_as_parent`) are propagated.
+fn validate_unmerged(tree: &TreeKemPublic) -> Result<(), MlsError> {
+    // Working copy of each node's unmerged leaves. Leaves and blank nodes contribute
+    // nothing. Entries are struck off below as they are justified.
+    let mut remaining: Vec<_> = tree
+        .nodes
+        .iter()
+        .map(|node| {
+            #[cfg(feature = "std")]
+            match node {
+                Some(Node::Parent(parent)) => {
+                    HashSet::from_iter(parent.unmerged_leaves.iter().copied())
+                }
+                _ => HashSet::new(),
+            }
+
+            #[cfg(not(feature = "std"))]
+            match node {
+                Some(Node::Parent(parent)) => parent.unmerged_leaves.clone(),
+                _ => vec![],
+            }
+        })
+        .collect();
+
+    let leaf_count = tree.total_leaf_count();
+
+    // Walk up from each non-blank leaf for as long as every node is either blank or lists
+    // the leaf as unmerged, striking the leaf off each listing node along the way.
+    for (leaf, _) in tree.nodes.non_empty_leaves() {
+        let mut current = NodeIndex::from(leaf);
+
+        while let Some(ps) = current.parent_sibling(&leaf_count) {
+            let parent = ps.parent;
+
+            if !tree.nodes.is_blank(parent)? {
+                let lists_leaf = tree
+                    .nodes
+                    .borrow_as_parent(parent)?
+                    .unmerged_leaves
+                    .contains(&leaf);
+
+                if !lists_leaf {
+                    break;
+                }
+
+                remaining[parent as usize].retain(|i| i != &leaf);
+            }
+
+            current = parent;
+        }
+    }
+
+    // Anything left over is an unmerged leaf that no valid walk accounted for.
+    let all_accounted_for = remaining.iter().all(|set| set.is_empty());
+
+    if all_accounted_for {
+        Ok(())
+    } else {
+        Err(MlsError::UnmergedLeavesMismatch)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use alloc::vec;
+    use assert_matches::assert_matches;
+
+    use super::*;
+    use crate::{
+        cipher_suite::CipherSuite,
+        client::test_utils::TEST_CIPHER_SUITE,
+        crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
+        group::test_utils::{get_test_group_context, random_bytes},
+        identity::basic::BasicIdentityProvider,
+        tree_kem::{
+            kem::TreeKem,
+            leaf_node::test_utils::{default_properties, get_basic_test_node},
+            node::{LeafIndex, Node, Parent},
+            parent_hash::{test_utils::get_test_tree_fig_12, ParentHash},
+            test_utils::get_test_tree,
+        },
+    };
+
+    /// Returns a parent node with a freshly generated KEM public key, an empty parent hash
+    /// and no unmerged leaves.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn test_parent_node(cipher_suite: CipherSuite) -> Parent {
+        let provider = test_cipher_suite_provider(cipher_suite);
+        let (_, public_key) = provider.kem_generate().await.unwrap();
+
+        Parent {
+            public_key,
+            parent_hash: ParentHash::empty(),
+            unmerged_leaves: vec![],
+        }
+    }
+
+    /// Returns the public half of a tree that `test_valid_tree` expects to pass validation.
+    ///
+    /// The test tree gets two basic leaves, parents 1 and 3 are populated, and then
+    /// `TreeKem::encap` is run with the creator's signing key.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn get_valid_tree(cipher_suite: CipherSuite) -> TreeKemPublic {
+        let provider = test_cipher_suite_provider(cipher_suite);
+        let mut tree = get_test_tree(cipher_suite).await;
+
+        let first = get_basic_test_node(cipher_suite, "leaf1").await;
+        let second = get_basic_test_node(cipher_suite, "leaf2").await;
+
+        tree.public
+            .add_leaves(vec![first, second], &BasicIdentityProvider, &provider)
+            .await
+            .unwrap();
+
+        tree.public.nodes[1] = Some(Node::Parent(test_parent_node(cipher_suite).await));
+        tree.public.nodes[3] = Some(Node::Parent(test_parent_node(cipher_suite).await));
+
+        let mut context = get_test_group_context(42, cipher_suite).await;
+
+        TreeKem::new(&mut tree.public, &mut tree.private)
+            .encap(
+                &mut context,
+                &[LeafIndex(1), LeafIndex(2)],
+                &tree.creator_signing_key,
+                default_properties(),
+                None,
+                &provider,
+                &Default::default(),
+            )
+            .await
+            .unwrap();
+
+        tree.public
+    }
+
+    /// Validates `tree` against a test group context whose `tree_hash` is set to the
+    /// tree's own hash. The tree-hash check therefore passes, and any error comes from a
+    /// later check.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn validate_with_own_hash(
+        cipher_suite: CipherSuite,
+        tree: &mut TreeKemPublic,
+    ) -> Result<(), MlsError> {
+        let provider = test_cipher_suite_provider(cipher_suite);
+        let mut context = get_test_group_context(1, cipher_suite).await;
+        context.tree_hash = tree.tree_hash(&provider).await.unwrap();
+
+        let validator = TreeValidator::new(&provider, &context, &BasicIdentityProvider);
+        validator.validate(tree).await
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_valid_tree() {
+        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+            let mut tree = get_valid_tree(cipher_suite).await;
+            validate_with_own_hash(cipher_suite, &mut tree).await.unwrap();
+        }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_tree_hash_mismatch() {
+        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+            let mut tree = get_valid_tree(cipher_suite).await;
+
+            // The context's tree hash is left at its test default, which this test expects
+            // not to match the tree.
+            let provider = test_cipher_suite_provider(cipher_suite);
+            let context = get_test_group_context(1, cipher_suite).await;
+            let validator = TreeValidator::new(&provider, &context, &BasicIdentityProvider);
+
+            let res = validator.validate(&mut tree).await;
+            assert_matches!(res, Err(MlsError::TreeHashMismatch));
+        }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_parent_hash_mismatch() {
+        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+            let mut tree = get_valid_tree(cipher_suite).await;
+
+            // Corrupt the parent hash stored at node 1.
+            tree.nodes.borrow_as_parent_mut(1).unwrap().parent_hash =
+                ParentHash::from(random_bytes(32));
+
+            let res = validate_with_own_hash(cipher_suite, &mut tree).await;
+            assert_matches!(res, Err(MlsError::ParentHashMismatch));
+        }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_key_package_validation_failure() {
+        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+            let mut tree = get_valid_tree(cipher_suite).await;
+
+            // Replace leaf 0's signature with random bytes.
+            tree.nodes.borrow_as_leaf_mut(LeafIndex(0)).unwrap().signature = random_bytes(32);
+
+            let res = validate_with_own_hash(cipher_suite, &mut tree).await;
+            assert_matches!(res, Err(MlsError::InvalidSignature));
+        }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn verify_unmerged_with_correct_tree() {
+        let tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;
+        validate_unmerged(&tree).unwrap();
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn verify_unmerged_with_blank_leaf() {
+        let mut tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;
+
+        // Blank leaf D (node 6), which is listed as unmerged at nodes 3 and 7.
+        tree.nodes[6] = None;
+
+        let res = validate_unmerged(&tree);
+        assert_matches!(res, Err(MlsError::UnmergedLeavesMismatch));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn verify_unmerged_with_broken_path() {
+        let mut tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;
+
+        // D's direct path is [3, 7]: leave it unmerged at 7 but not at 3.
+        tree.nodes.borrow_as_parent_mut(3).unwrap().unmerged_leaves.clear();
+
+        let res = validate_unmerged(&tree);
+        assert_matches!(res, Err(MlsError::UnmergedLeavesMismatch));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn verify_unmerged_with_leaf_outside_tree() {
+        let mut tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;
+
+        // List leaf E (right subtree of the root) as unmerged at node 1 (left subtree).
+        tree.nodes.borrow_as_parent_mut(1).unwrap().unmerged_leaves = vec![LeafIndex(4)];
+
+        let res = validate_unmerged(&tree);
+        assert_matches!(res, Err(MlsError::UnmergedLeavesMismatch));
+    }
+}
diff --git a/src/tree_kem/update_path.rs b/src/tree_kem/update_path.rs
new file mode 100644
index 0000000..654c21f
--- /dev/null
+++ b/src/tree_kem/update_path.rs
@@ -0,0 +1,274 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::{vec, vec::Vec};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider};
+
+use super::{
+ leaf_node::LeafNode,
+ leaf_node_validator::{LeafNodeValidator, ValidationContext},
+ node::LeafIndex,
+};
+use crate::{
+ client::MlsError,
+ crypto::{CipherSuiteProvider, HpkeCiphertext, HpkePublicKey},
+};
+use crate::{group::message_processor::ProvisionalState, time::MlsTime};
+
+/// One node of an [`UpdatePath`]: the new HPKE public key for one non-filtered node on
+/// the sender's direct path, together with the encrypted path secrets for that node.
+///
+/// NOTE(review): this type derives `MlsEncode`/`MlsDecode`, so field order presumably
+/// defines the wire encoding. Do not reorder the fields.
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct UpdatePathNode {
+ /// New HPKE public key for this direct-path node.
+ pub public_key: HpkePublicKey,
+ /// HPKE ciphertexts carrying this node's path secret.
+ /// NOTE(review): per RFC 9420 there is one ciphertext per node in the resolution of
+ /// the corresponding copath child. Confirm the ordering against the encap code.
+ pub encrypted_path_secret: Vec<HpkeCiphertext>,
+}
+
+/// The update path sent with a commit: the committer's new leaf node plus one
+/// [`UpdatePathNode`] per non-filtered node of its direct path, in direct-path order.
+///
+/// Use `validate_update_path` to check it and obtain a [`ValidatedUpdatePath`].
+/// NOTE(review): this type derives `MlsEncode`/`MlsDecode`, so field order presumably
+/// defines the wire encoding. Do not reorder the fields.
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct UpdatePath {
+ /// The committer's replacement leaf node.
+ pub leaf_node: LeafNode,
+ /// Path nodes for the non-filtered positions of the committer's direct path only.
+ pub nodes: Vec<UpdatePathNode>,
+}
+
+/// An [`UpdatePath`] that has passed `validate_update_path`, with its nodes re-expanded to
+/// positions on the sender's direct path.
+#[derive(Debug, Clone, PartialEq)]
+pub struct ValidatedUpdatePath {
+    /// The sender's new leaf node, already validated.
+    pub leaf_node: LeafNode,
+    /// One entry per direct-path position, up to the last node that was sent. `None`
+    /// marks a filtered position that the sender did not (and must not) populate.
+    pub nodes: Vec<Option<UpdatePathNode>>,
+}
+
+/// Validates the `UpdatePath` that `sender` attached to a commit and re-expands it to the
+/// sender's (unfiltered) direct path.
+///
+/// Checks performed, in order:
+/// * the new leaf node passes `LeafNodeValidator::check_if_valid` in a commit context
+///   (group id, sender index, `commit_time`);
+/// * unless the commit applies external initializations, the new leaf's signing identity
+///   must be a valid successor of the sender's current one (`MlsError::InvalidSuccessor`,
+///   or `MlsError::IdentityProviderError` if the provider fails), and its HPKE public key
+///   must differ from the current one (`MlsError::SameHpkeKey`);
+/// * the path must contain exactly one node per non-filtered node of the sender's direct
+///   path. Both too many and too few nodes yield `MlsError::WrongPathLen`.
+///
+/// In the returned `nodes`, each filtered position that precedes the last path node is
+/// `None`. Trailing filtered positions are not represented.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn validate_update_path<C: IdentityProvider, CSP: CipherSuiteProvider>(
+    identity_provider: &C,
+    cipher_suite_provider: &CSP,
+    path: UpdatePath,
+    state: &ProvisionalState,
+    sender: LeafIndex,
+    commit_time: Option<MlsTime>,
+) -> Result<ValidatedUpdatePath, MlsError> {
+    let group_context_extensions = &state.group_context.extensions;
+
+    let leaf_validator = LeafNodeValidator::new(
+        cipher_suite_provider,
+        identity_provider,
+        Some(group_context_extensions),
+    );
+
+    leaf_validator
+        .check_if_valid(
+            &path.leaf_node,
+            ValidationContext::Commit((&state.group_context.group_id, *sender, commit_time)),
+        )
+        .await?;
+
+    // The identity-continuity and key-change checks are skipped when the commit applies
+    // external initializations.
+    // NOTE(review): presumably because the sender is an external joiner with no earlier
+    // leaf of its own. Confirm.
+    let check_identity_eq = state.applied_proposals.external_initializations.is_empty();
+
+    if check_identity_eq {
+        // Borrow the sender's current leaf directly; there is no need to clone it.
+        let existing_leaf = state.public_tree.nodes.borrow_as_leaf(sender)?;
+
+        identity_provider
+            .valid_successor(
+                &existing_leaf.signing_identity,
+                &path.leaf_node.signing_identity,
+                group_context_extensions,
+            )
+            .await
+            .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?
+            .then_some(())
+            .ok_or(MlsError::InvalidSuccessor)?;
+
+        (existing_leaf.public_key != path.leaf_node.public_key)
+            .then_some(())
+            .ok_or(MlsError::SameHpkeKey(*sender))?;
+    }
+
+    // Unfilter the update path. `filtered[i]` tells whether direct-path position `i` is
+    // filtered (and therefore has no entry in `path.nodes`).
+    let filtered = state.public_tree.nodes.filtered(sender)?;
+    let mut unfiltered_nodes = vec![];
+    let mut i = 0;
+
+    for n in path.nodes {
+        // Running off the end of the direct path means the path has too many nodes.
+        while *filtered.get(i).ok_or(MlsError::WrongPathLen)? {
+            unfiltered_nodes.push(None);
+            i += 1;
+        }
+
+        unfiltered_nodes.push(Some(n));
+        i += 1;
+    }
+
+    // Every direct-path position that was not consumed must be filtered. Otherwise the
+    // sender omitted a node it was required to send (path too short).
+    if filtered.iter().skip(i).any(|is_filtered| !*is_filtered) {
+        return Err(MlsError::WrongPathLen);
+    }
+
+    Ok(ValidatedUpdatePath {
+        leaf_node: path.leaf_node,
+        nodes: unfiltered_nodes,
+    })
+}
+
+#[cfg(test)]
+mod tests {
+    use alloc::vec;
+    use alloc::vec::Vec;
+    use assert_matches::assert_matches;
+
+    use super::{UpdatePath, UpdatePathNode, ValidatedUpdatePath};
+    use crate::client::test_utils::TEST_CIPHER_SUITE;
+    use crate::crypto::test_utils::test_cipher_suite_provider;
+    use crate::crypto::HpkeCiphertext;
+    use crate::group::message_processor::ProvisionalState;
+    use crate::group::test_utils::{get_test_group_context, random_bytes, TEST_GROUP};
+    use crate::identity::basic::BasicIdentityProvider;
+    use crate::tree_kem::leaf_node::test_utils::{default_properties, get_basic_test_node_sig_key};
+    use crate::tree_kem::leaf_node::LeafNodeSource;
+    use crate::tree_kem::node::LeafIndex;
+    use crate::tree_kem::parent_hash::ParentHash;
+    use crate::tree_kem::test_utils::{get_test_leaf_nodes, get_test_tree};
+    use crate::tree_kem::validate_update_path;
+    use crate::{cipher_suite::CipherSuite, tree_kem::MlsError};
+
+    /// Builds an `UpdatePath` whose leaf node (credential `cred`) has been committed for
+    /// `TEST_GROUP` at leaf 0, followed by two identical random path nodes.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn test_update_path(cipher_suite: CipherSuite, cred: &str) -> UpdatePath {
+        let (mut leaf_node, _, signer) = get_basic_test_node_sig_key(cipher_suite, cred).await;
+        let provider = test_cipher_suite_provider(cipher_suite);
+
+        leaf_node.leaf_node_source = LeafNodeSource::Commit(ParentHash::from(hex!("beef")));
+
+        leaf_node
+            .commit(&provider, TEST_GROUP, 0, default_properties(), None, &signer)
+            .await
+            .unwrap();
+
+        let path_node = UpdatePathNode {
+            public_key: random_bytes(32).into(),
+            encrypted_path_secret: vec![HpkeCiphertext {
+                kem_output: random_bytes(32),
+                ciphertext: random_bytes(32),
+            }],
+        };
+
+        UpdatePath {
+            leaf_node,
+            nodes: vec![path_node; 2],
+        }
+    }
+
+    /// Builds a provisional state over the test tree extended with the shared test leaves,
+    /// with no applied proposals.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn test_provisional_state(cipher_suite: CipherSuite) -> ProvisionalState {
+        let provider = test_cipher_suite_provider(cipher_suite);
+        let mut public_tree = get_test_tree(cipher_suite).await.public;
+        let leaf_nodes = get_test_leaf_nodes(cipher_suite).await;
+
+        public_tree
+            .add_leaves(leaf_nodes, &BasicIdentityProvider, &provider)
+            .await
+            .unwrap();
+
+        ProvisionalState {
+            public_tree,
+            applied_proposals: Default::default(),
+            group_context: get_test_group_context(1, cipher_suite).await,
+            indexes_of_added_kpkgs: vec![],
+            external_init_index: None,
+            #[cfg(feature = "state_update")]
+            unused_proposals: vec![],
+        }
+    }
+
+    /// Runs `validate_update_path` with the basic identity provider for a commit sent by
+    /// leaf 0 without a commit time.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn validate_as_leaf_0(
+        path: UpdatePath,
+        state: &ProvisionalState,
+    ) -> Result<ValidatedUpdatePath, MlsError> {
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        validate_update_path(
+            &BasicIdentityProvider,
+            &provider,
+            path,
+            state,
+            LeafIndex(0),
+            None,
+        )
+        .await
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_valid_leaf_node() {
+        let update_path = test_update_path(TEST_CIPHER_SUITE, "creator").await;
+        let state = test_provisional_state(TEST_CIPHER_SUITE).await;
+
+        let validated = validate_as_leaf_0(update_path.clone(), &state)
+            .await
+            .unwrap();
+
+        let expected: Vec<_> = update_path.nodes.into_iter().map(Some).collect();
+
+        assert_eq!(validated.nodes, expected);
+        assert_eq!(validated.leaf_node, update_path.leaf_node);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_invalid_key_package() {
+        let mut update_path = test_update_path(TEST_CIPHER_SUITE, "creator").await;
+        update_path.leaf_node.signature = random_bytes(32);
+
+        let state = test_provisional_state(TEST_CIPHER_SUITE).await;
+        let res = validate_as_leaf_0(update_path, &state).await;
+
+        assert_matches!(res, Err(MlsError::InvalidSignature));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn validating_path_fails_with_different_identity() {
+        let update_path = test_update_path(TEST_CIPHER_SUITE, "foobar").await;
+        let state = test_provisional_state(TEST_CIPHER_SUITE).await;
+
+        let res = validate_as_leaf_0(update_path, &state).await;
+
+        assert_matches!(res, Err(MlsError::InvalidSuccessor));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn validating_path_fails_with_same_hpke_key() {
+        let update_path = test_update_path(TEST_CIPHER_SUITE, "creator").await;
+        let mut state = test_provisional_state(TEST_CIPHER_SUITE).await;
+
+        // Give the sender's current leaf the same HPKE key as the proposed new leaf.
+        let current_leaf = state
+            .public_tree
+            .nodes
+            .borrow_as_leaf_mut(LeafIndex(0))
+            .unwrap();
+        current_leaf.public_key = update_path.leaf_node.public_key.clone();
+
+        let res = validate_as_leaf_0(update_path, &state).await;
+
+        assert_matches!(res, Err(MlsError::SameHpkeKey(_)));
+    }
+}