Import 'mls-rs' crate
Request Document: go/android-rust-importing-crates
For CL Reviewers: go/android3p#cl-review
For Build Team: go/ab-third-party-imports
Bug: http://b/330708876
Test: m libmls_rs
Change-Id: Ib0a891a4d7bf582ebea9ba7a1447ea959e42e0d3
diff --git a/src/group/ciphertext_processor.rs b/src/group/ciphertext_processor.rs
new file mode 100644
index 0000000..bf70f5d
--- /dev/null
+++ b/src/group/ciphertext_processor.rs
@@ -0,0 +1,410 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use self::{
+ message_key::MessageKey,
+ reuse_guard::ReuseGuard,
+ sender_data_key::{SenderData, SenderDataAAD, SenderDataKey},
+};
+
+use super::{
+ epoch::EpochSecrets,
+ framing::{ContentType, FramedContent, Sender, WireFormat},
+ message_signature::AuthenticatedContent,
+ padding::PaddingMode,
+ secret_tree::{KeyType, MessageKeyData},
+ GroupContext,
+};
+use crate::{
+ client::MlsError,
+ tree_kem::node::{LeafIndex, NodeIndex},
+};
+use mls_rs_codec::MlsEncode;
+use mls_rs_core::{crypto::CipherSuiteProvider, error::IntoAnyError};
+use zeroize::Zeroizing;
+
+mod message_key;
+mod reuse_guard;
+mod sender_data_key;
+
+#[cfg(feature = "private_message")]
+use super::framing::{PrivateContentAAD, PrivateMessage, PrivateMessageContent};
+
+#[cfg(test)]
+pub use sender_data_key::test_utils::*;
+
+pub(crate) trait GroupStateProvider {
+    fn group_context(&self) -> &GroupContext; // group_id/epoch for message headers and AAD
+    fn self_index(&self) -> LeafIndex; // this member's leaf index
+    fn epoch_secrets_mut(&mut self) -> &mut EpochSecrets; // message-key derivation (secret_tree)
+    fn epoch_secrets(&self) -> &EpochSecrets; // read-only; supplies sender_data_secret
+}
+
+pub(crate) struct CiphertextProcessor<'a, GS, CP>
+where
+    GS: GroupStateProvider,
+    CP: CipherSuiteProvider,
+{
+    group_state: &'a mut GS, // mutable: message keys are derived via epoch_secrets_mut()
+    cipher_suite_provider: CP, // AEAD, KDF and RNG used for sealing and opening
+}
+
+impl<'a, GS, CP> CiphertextProcessor<'a, GS, CP>
+where
+    GS: GroupStateProvider,
+    CP: CipherSuiteProvider,
+{
+    pub fn new(
+        group_state: &'a mut GS,
+        cipher_suite_provider: CP,
+    ) -> CiphertextProcessor<'a, GS, CP> {
+        Self {
+            group_state,
+            cipher_suite_provider,
+        }
+    }
+    /// Derives this member's next message key of type `key_type` from the secret tree.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn next_encryption_key(
+        &mut self,
+        key_type: KeyType,
+    ) -> Result<MessageKeyData, MlsError> {
+        let self_index = NodeIndex::from(self.group_state.self_index());
+
+        self.group_state
+            .epoch_secrets_mut()
+            .secret_tree
+            .next_message_key(&self.cipher_suite_provider, self_index, key_type)
+            .await
+    }
+    /// Derives `sender`'s message key of type `key_type` for the given `generation`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn decryption_key(
+        &mut self,
+        sender: LeafIndex,
+        key_type: KeyType,
+        generation: u32,
+    ) -> Result<MessageKeyData, MlsError> {
+        let sender = NodeIndex::from(sender);
+
+        self.group_state
+            .epoch_secrets_mut()
+            .secret_tree
+            .message_key_generation(&self.cipher_suite_provider, sender, key_type, generation)
+            .await
+    }
+    /// Pads and encrypts `auth_content`, which must be sent by this member (else InvalidSender).
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn seal(
+        &mut self,
+        auth_content: AuthenticatedContent,
+        padding: PaddingMode,
+    ) -> Result<PrivateMessage, MlsError> {
+        if Sender::Member(*self.group_state.self_index()) != auth_content.content.sender {
+            return Err(MlsError::InvalidSender);
+        }
+
+        let content_type = ContentType::from(&auth_content.content.content);
+        let authenticated_data = auth_content.content.authenticated_data;
+
+        // Payload to encrypt: the content body together with its auth data
+        let private_content = PrivateMessageContent {
+            content: auth_content.content.content,
+            auth: auth_content.auth,
+        };
+        // AAD binding the content ciphertext to the message header. NOTE(review): group_id and
+        // epoch come from auth_content here but from group_context() in the returned header.
+        let aad = PrivateContentAAD {
+            group_id: auth_content.content.group_id,
+            epoch: auth_content.content.epoch,
+            content_type,
+            authenticated_data: authenticated_data.clone(),
+        };
+
+        // Generate a 4 byte reuse guard
+        let reuse_guard = ReuseGuard::random(&self.cipher_suite_provider)
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+        // Application content uses application keys; everything else uses handshake keys
+        let key_type = match &content_type {
+            ContentType::Application => KeyType::Application,
+            _ => KeyType::Handshake,
+        };
+
+        let mut serialized_private_content = private_content.mls_encode_to_vec()?;
+
+        // Apply padding to private content based on the current padding mode.
+        serialized_private_content.resize(padding.padded_size(serialized_private_content.len()), 0);
+
+        let serialized_private_content = Zeroizing::new(serialized_private_content);
+
+        // Encrypt with this member's next key; the nonce is made reuse-safe by XORing the
+        // random reuse guard into its first 4 bytes
+        let self_index = self.group_state.self_index();
+
+        let key_data = self.next_encryption_key(key_type).await?;
+        let generation = key_data.generation;
+
+        let ciphertext = MessageKey::new(key_data)
+            .encrypt(
+                &self.cipher_suite_provider,
+                &serialized_private_content,
+                &aad.mls_encode_to_vec()?,
+                &reuse_guard,
+            )
+            .await
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+        // Construct an mls sender data struct using the plaintext sender info, the generation
+        // of the key schedule encryption key, and the reuse guard used to encrypt ciphertext
+        let sender_data = SenderData {
+            sender: self_index,
+            generation,
+            reuse_guard,
+        };
+
+        let sender_data_aad = SenderDataAAD {
+            group_id: self.group_state.group_context().group_id.clone(),
+            epoch: self.group_state.group_context().epoch,
+            content_type,
+        };
+
+        // Encrypt the sender data with a key and nonce derived from the epoch's
+        // sender_data_secret and a sample of the content ciphertext
+        let sender_data_key = SenderDataKey::new(
+            &self.group_state.epoch_secrets().sender_data_secret,
+            &ciphertext,
+            &self.cipher_suite_provider,
+        )
+        .await?;
+
+        let encrypted_sender_data = sender_data_key.seal(&sender_data, &sender_data_aad).await?;
+
+        Ok(PrivateMessage {
+            group_id: self.group_state.group_context().group_id.clone(),
+            epoch: self.group_state.group_context().epoch,
+            content_type,
+            authenticated_data,
+            encrypted_sender_data,
+            ciphertext,
+        })
+    }
+    /// Decrypts a `PrivateMessage` from another member; messages from self are rejected.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn open(
+        &mut self,
+        ciphertext: &PrivateMessage,
+    ) -> Result<AuthenticatedContent, MlsError> {
+        // Derive the sender-data key from sender_data_secret and a ciphertext sample. The AAD
+        // uses this group state's group_id/epoch rather than the message header's.
+        let sender_data_aad = SenderDataAAD {
+            group_id: self.group_state.group_context().group_id.clone(),
+            epoch: self.group_state.group_context().epoch,
+            content_type: ciphertext.content_type,
+        };
+
+        let sender_data_key = SenderDataKey::new(
+            &self.group_state.epoch_secrets().sender_data_secret,
+            &ciphertext.ciphertext,
+            &self.cipher_suite_provider,
+        )
+        .await?;
+
+        let sender_data = sender_data_key
+            .open(&ciphertext.encrypted_sender_data, &sender_data_aad)
+            .await?;
+
+        if self.group_state.self_index() == sender_data.sender {
+            return Err(MlsError::CantProcessMessageFromSelf);
+        }
+
+        // Pick the key type from the public content type, mirroring seal()
+        let key_type = match &ciphertext.content_type {
+            ContentType::Application => KeyType::Application,
+            _ => KeyType::Handshake,
+        };
+
+        // Decrypt the content using the sender's key for the advertised generation
+        let key = self
+            .decryption_key(sender_data.sender, key_type, sender_data.generation)
+            .await?;
+
+        let sender = Sender::Member(*sender_data.sender);
+
+        let decrypted_content = MessageKey::new(key)
+            .decrypt(
+                &self.cipher_suite_provider,
+                &ciphertext.ciphertext,
+                &PrivateContentAAD::from(ciphertext).mls_encode_to_vec()?,
+                &sender_data.reuse_guard,
+            )
+            .await
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+        let ciphertext_content =
+            PrivateMessageContent::mls_decode(&mut &**decrypted_content, ciphertext.content_type)?;
+
+        // Rebuild the authenticated content from the public header and decrypted body
+        let auth_content = AuthenticatedContent {
+            wire_format: WireFormat::PrivateMessage,
+            content: FramedContent {
+                group_id: ciphertext.group_id.clone(),
+                epoch: ciphertext.epoch,
+                sender,
+                authenticated_data: ciphertext.authenticated_data.clone(),
+                content: ciphertext_content.content,
+            },
+            auth: ciphertext_content.auth,
+        };
+
+        Ok(auth_content)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::{
+        cipher_suite::CipherSuite,
+        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+        crypto::{
+            test_utils::{test_cipher_suite_provider, TestCryptoProvider},
+            CipherSuiteProvider,
+        },
+        group::{
+            framing::{ApplicationData, Content, Sender, WireFormat},
+            message_signature::AuthenticatedContent,
+            padding::PaddingMode,
+            test_utils::{random_bytes, test_group, TestGroup},
+        },
+        tree_kem::node::LeafIndex,
+    };
+
+    use super::{CiphertextProcessor, GroupStateProvider, MlsError};
+
+    use alloc::vec;
+    use assert_matches::assert_matches;
+
+    struct TestData {
+        group: TestGroup,
+        content: AuthenticatedContent,
+    }
+
+    fn test_processor(
+        group: &mut TestGroup,
+        cipher_suite: CipherSuite,
+    ) -> CiphertextProcessor<'_, impl GroupStateProvider, impl CipherSuiteProvider> {
+        CiphertextProcessor::new(&mut group.group, test_cipher_suite_provider(cipher_suite))
+    }
+    /// Builds a test group plus an application message signed as `Sender::Member(0)`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn test_data(cipher_suite: CipherSuite) -> TestData {
+        let provider = test_cipher_suite_provider(cipher_suite);
+
+        let group = test_group(TEST_PROTOCOL_VERSION, cipher_suite).await;
+
+        let content = AuthenticatedContent::new_signed(
+            &provider,
+            group.group.context(),
+            Sender::Member(0),
+            Content::Application(ApplicationData::from(b"test".to_vec())),
+            &group.group.signer,
+            WireFormat::PrivateMessage,
+            vec![],
+        )
+        .await
+        .unwrap();
+
+        TestData { group, content }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_encrypt_decrypt() {
+        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+            let mut test_data = test_data(cipher_suite).await;
+            let mut receiver_group = test_data.group.clone();
+
+            let mut ciphertext_processor = test_processor(&mut test_data.group, cipher_suite);
+
+            let ciphertext = ciphertext_processor
+                .seal(test_data.content.clone(), PaddingMode::StepFunction)
+                .await
+                .unwrap();
+            // Receive as a different member (leaf 1); open() rejects messages from self.
+            receiver_group.group.private_tree.self_index = LeafIndex::new(1);
+
+            let mut receiver_processor = test_processor(&mut receiver_group, cipher_suite);
+
+            let decrypted = receiver_processor.open(&ciphertext).await.unwrap();
+
+            assert_eq!(decrypted, test_data.content);
+        }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_padding_use() {
+        let mut test_data = test_data(TEST_CIPHER_SUITE).await;
+        let mut ciphertext_processor = test_processor(&mut test_data.group, TEST_CIPHER_SUITE);
+
+        let ciphertext_step = ciphertext_processor
+            .seal(test_data.content.clone(), PaddingMode::StepFunction)
+            .await
+            .unwrap();
+
+        let ciphertext_no_pad = ciphertext_processor
+            .seal(test_data.content.clone(), PaddingMode::None)
+            .await
+            .unwrap();
+
+        assert!(ciphertext_step.ciphertext.len() > ciphertext_no_pad.ciphertext.len());
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_invalid_sender() {
+        let mut test_data = test_data(TEST_CIPHER_SUITE).await;
+        test_data.content.content.sender = Sender::Member(3);
+
+        let mut ciphertext_processor = test_processor(&mut test_data.group, TEST_CIPHER_SUITE);
+
+        let res = ciphertext_processor
+            .seal(test_data.content, PaddingMode::None)
+            .await;
+
+        assert_matches!(res, Err(MlsError::InvalidSender))
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_cant_process_from_self() {
+        let mut test_data = test_data(TEST_CIPHER_SUITE).await;
+
+        let mut ciphertext_processor = test_processor(&mut test_data.group, TEST_CIPHER_SUITE);
+
+        let ciphertext = ciphertext_processor
+            .seal(test_data.content, PaddingMode::None)
+            .await
+            .unwrap();
+
+        let res = ciphertext_processor.open(&ciphertext).await;
+
+        assert_matches!(res, Err(MlsError::CantProcessMessageFromSelf))
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_decryption_error() {
+        let mut test_data = test_data(TEST_CIPHER_SUITE).await;
+        let mut receiver_group = test_data.group.clone();
+        let mut ciphertext_processor = test_processor(&mut test_data.group, TEST_CIPHER_SUITE);
+
+        let mut ciphertext = ciphertext_processor
+            .seal(test_data.content.clone(), PaddingMode::StepFunction)
+            .await
+            .unwrap();
+
+        ciphertext.ciphertext = random_bytes(ciphertext.ciphertext.len());
+        receiver_group.group.private_tree.self_index = LeafIndex::new(1);
+        // Open as another member: the sender itself would fail with CantProcessMessageFromSelf.
+        let mut receiver_processor = test_processor(&mut receiver_group, TEST_CIPHER_SUITE);
+        let res = receiver_processor.open(&ciphertext).await;
+        assert_matches!(res, Err(MlsError::CryptoProviderError(_)));
+    }
+}
diff --git a/src/group/ciphertext_processor/message_key.rs b/src/group/ciphertext_processor/message_key.rs
new file mode 100644
index 0000000..256db7d
--- /dev/null
+++ b/src/group/ciphertext_processor/message_key.rs
@@ -0,0 +1,57 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use zeroize::Zeroizing;
+
+use crate::{crypto::CipherSuiteProvider, group::secret_tree::MessageKeyData};
+
+use super::reuse_guard::ReuseGuard;
+
+#[derive(Debug, PartialEq, Eq)]
+pub struct MessageKey(MessageKeyData); // one secret-tree message key (key, nonce, generation)
+
+impl MessageKey {
+    pub(crate) fn new(key: MessageKeyData) -> MessageKey {
+        MessageKey(key)
+    }
+    /// AEAD-seals `data` with `aad`, using this key's nonce XORed with `reuse_guard`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn encrypt<P: CipherSuiteProvider>(
+        &self,
+        provider: &P,
+        data: &[u8],
+        aad: &[u8],
+        reuse_guard: &ReuseGuard,
+    ) -> Result<Vec<u8>, P::Error> {
+        provider
+            .aead_seal(
+                &self.0.key,
+                data,
+                Some(aad),
+                &reuse_guard.apply(&self.0.nonce),
+            )
+            .await
+    }
+    /// AEAD-opens `data`; `aad` and `reuse_guard` must match those used by `encrypt`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn decrypt<P: CipherSuiteProvider>(
+        &self,
+        provider: &P,
+        data: &[u8],
+        aad: &[u8],
+        reuse_guard: &ReuseGuard,
+    ) -> Result<Zeroizing<Vec<u8>>, P::Error> {
+        provider
+            .aead_open(
+                &self.0.key,
+                data,
+                Some(aad),
+                &reuse_guard.apply(&self.0.nonce),
+            )
+            .await
+    }
+}
+
+// TODO: Write test vectors
diff --git a/src/group/ciphertext_processor/reuse_guard.rs b/src/group/ciphertext_processor/reuse_guard.rs
new file mode 100644
index 0000000..10e1db1
--- /dev/null
+++ b/src/group/ciphertext_processor/reuse_guard.rs
@@ -0,0 +1,133 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+use crate::CipherSuiteProvider;
+
+const REUSE_GUARD_SIZE: usize = 4; // bytes XORed into the start of each content nonce
+
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+pub(crate) struct ReuseGuard([u8; REUSE_GUARD_SIZE]); // random value mixed into the AEAD nonce
+
+impl From<[u8; REUSE_GUARD_SIZE]> for ReuseGuard {
+    fn from(bytes: [u8; REUSE_GUARD_SIZE]) -> Self {
+        Self(bytes) // wraps raw guard bytes
+    }
+}
+
+impl From<ReuseGuard> for [u8; REUSE_GUARD_SIZE] {
+    fn from(ReuseGuard(bytes): ReuseGuard) -> Self {
+        bytes // unwraps the guard into its raw bytes
+    }
+}
+
+impl AsRef<[u8]> for ReuseGuard {
+    fn as_ref(&self) -> &[u8] {
+        self.0.as_slice()
+    }
+}
+
+impl ReuseGuard {
+    /// Draws a fresh guard from the provider's random number generator.
+    pub(crate) fn random<P: CipherSuiteProvider>(provider: &P) -> Result<Self, P::Error> {
+        let mut bytes = [0u8; REUSE_GUARD_SIZE];
+        provider.random_bytes(&mut bytes)?;
+        Ok(Self(bytes))
+    }
+
+    /// Returns a copy of `nonce` whose leading bytes (up to REUSE_GUARD_SIZE) are XORed.
+    pub(crate) fn apply(&self, nonce: &[u8]) -> Vec<u8> {
+        let mut guarded = nonce.to_vec();
+        for (dst, src) in guarded.iter_mut().zip(self.0.iter()) {
+            *dst ^= *src;
+        }
+        guarded
+    }
+}
+
+#[cfg(test)]
+mod test_utils {
+    use alloc::vec::Vec;
+
+    use super::{ReuseGuard, REUSE_GUARD_SIZE};
+
+    impl ReuseGuard {
+        pub fn new(guard: Vec<u8>) -> Self {
+            let mut bytes = [0u8; REUSE_GUARD_SIZE];
+            bytes.copy_from_slice(guard.as_slice()); // panics unless len == REUSE_GUARD_SIZE
+            ReuseGuard(bytes)
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use alloc::vec::Vec;
+    use mls_rs_core::crypto::CipherSuiteProvider;
+
+    use crate::{
+        client::test_utils::TEST_CIPHER_SUITE, crypto::test_utils::test_cipher_suite_provider,
+    };
+
+    use super::{ReuseGuard, REUSE_GUARD_SIZE};
+
+    #[test]
+    fn test_random_generation() {
+        let test_guard =
+            ReuseGuard::random(&test_cipher_suite_provider(TEST_CIPHER_SUITE)).unwrap();
+        // 1000 draws of a 4-byte value collide with the first with probability ~1000/2^32
+        (0..1000).for_each(|_| {
+            let next = ReuseGuard::random(&test_cipher_suite_provider(TEST_CIPHER_SUITE)).unwrap();
+            assert_ne!(next, test_guard);
+        })
+    }
+
+    #[derive(Debug, serde::Serialize, serde::Deserialize)]
+    struct TestCase {
+        nonce: Vec<u8>,
+        guard: [u8; REUSE_GUARD_SIZE],
+        result: Vec<u8>,
+    }
+
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn generate_reuse_guard_test_cases() -> Vec<TestCase> {
+        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        // One case per nonce length (16 and 32 bytes), each with a fresh random guard
+        [16, 32]
+            .into_iter()
+            .map(
+                #[cfg_attr(coverage_nightly, coverage(off))]
+                |len| {
+                    let nonce = provider.random_bytes_vec(len).unwrap();
+                    let guard = ReuseGuard::random(&provider).unwrap();
+
+                    let result = guard.apply(&nonce);
+
+                    TestCase {
+                        nonce,
+                        guard: guard.into(),
+                        result,
+                    }
+                },
+            )
+            .collect()
+    }
+
+    fn load_test_cases() -> Vec<TestCase> {
+        load_test_case_json!(reuse_guard, generate_reuse_guard_test_cases())
+    }
+
+    #[test]
+    fn test_reuse_guard() {
+        let test_cases = load_test_cases();
+        // Re-applying each stored guard to its nonce must reproduce the stored result
+        for case in test_cases {
+            let guard = ReuseGuard::from(case.guard);
+            let result = guard.apply(&case.nonce);
+            assert_eq!(result, case.result);
+        }
+    }
+}
diff --git a/src/group/ciphertext_processor/sender_data_key.rs b/src/group/ciphertext_processor/sender_data_key.rs
new file mode 100644
index 0000000..983920a
--- /dev/null
+++ b/src/group/ciphertext_processor/sender_data_key.rs
@@ -0,0 +1,360 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+use zeroize::Zeroizing;
+
+use crate::{
+ client::MlsError,
+ crypto::CipherSuiteProvider,
+ group::{epoch::SenderDataSecret, framing::ContentType, key_schedule::kdf_expand_with_label},
+ tree_kem::node::LeafIndex,
+};
+
+use super::ReuseGuard;
+
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+pub(crate) struct SenderData {
+    pub sender: LeafIndex, // leaf of the member that sealed the message
+    pub generation: u32, // secret-tree generation of the content key
+    pub reuse_guard: ReuseGuard, // guard XORed into the content nonce
+}
+
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+pub(crate) struct SenderDataAAD {
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    pub group_id: Vec<u8>, // id of the group the message belongs to
+    pub epoch: u64, // group epoch
+    pub content_type: ContentType, // public content type of the PrivateMessage
+}
+
+impl Debug for SenderDataAAD {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let Self { group_id, epoch, content_type } = self;
+        // group_id is rendered with the crate's pretty printer rather than as raw bytes
+        let pretty_id = mls_rs_core::debug::pretty_group_id(group_id);
+        let mut out = f.debug_struct("SenderDataAAD");
+        out.field("group_id", &pretty_id);
+        out.field("epoch", epoch);
+        out.field("content_type", content_type);
+        out.finish()
+    }
+}
+
+pub(crate) struct SenderDataKey<'a, CP: CipherSuiteProvider> {
+    pub(crate) key: Zeroizing<Vec<u8>>, // aead_key_size() bytes; zeroized on drop
+    pub(crate) nonce: Zeroizing<Vec<u8>>, // aead_nonce_size() bytes; zeroized on drop
+    cipher_suite_provider: &'a CP,
+}
+
+impl<CP: CipherSuiteProvider + Debug> Debug for SenderDataKey<'_, CP> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("SenderDataKey")
+            .field("key", &mls_rs_core::debug::pretty_bytes(&self.key))
+            .field("nonce", &mls_rs_core::debug::pretty_bytes(&self.nonce))
+            .field("cipher_suite_provider", self.cipher_suite_provider)
+            .finish() // NOTE(review): key/nonce bytes are formatted; confirm safe to log
+    }
+}
+
+impl<'a, CP: CipherSuiteProvider> SenderDataKey<'a, CP> {
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(super) async fn new(
+        sender_data_secret: &SenderDataSecret,
+        ciphertext: &[u8],
+        cipher_suite_provider: &'a CP,
+    ) -> Result<SenderDataKey<'a, CP>, MlsError> {
+        // Sample the first kdf_extract_size bytes of the ciphertext; a shorter ciphertext is
+        // used in full
+        let extract_size = cipher_suite_provider.kdf_extract_size();
+        let ciphertext_sample = ciphertext.get(0..extract_size).unwrap_or(ciphertext);
+
+        // Expand sender_data_secret over the sample with the "key" and "nonce" labels to get
+        // an AEAD key and nonce of the cipher suite's sizes
+        let key = kdf_expand_with_label(
+            cipher_suite_provider,
+            sender_data_secret,
+            b"key",
+            ciphertext_sample,
+            Some(cipher_suite_provider.aead_key_size()),
+        )
+        .await?;
+
+        let nonce = kdf_expand_with_label(
+            cipher_suite_provider,
+            sender_data_secret,
+            b"nonce",
+            ciphertext_sample,
+            Some(cipher_suite_provider.aead_nonce_size()),
+        )
+        .await?;
+
+        Ok(Self {
+            key,
+            nonce,
+            cipher_suite_provider,
+        })
+    }
+    /// AEAD-seals the encoded `sender_data`, authenticating the encoded `aad`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn seal(
+        &self,
+        sender_data: &SenderData,
+        aad: &SenderDataAAD,
+    ) -> Result<Vec<u8>, MlsError> {
+        self.cipher_suite_provider
+            .aead_seal(
+                &self.key,
+                &sender_data.mls_encode_to_vec()?,
+                Some(&aad.mls_encode_to_vec()?),
+                &self.nonce,
+            )
+            .await
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+    }
+    /// Opens and decodes sender data sealed by `seal` with the same key, nonce and AAD.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn open(
+        &self,
+        sender_data: &[u8],
+        aad: &SenderDataAAD,
+    ) -> Result<SenderData, MlsError> {
+        self.cipher_suite_provider
+            .aead_open(
+                &self.key,
+                sender_data,
+                Some(&aad.mls_encode_to_vec()?),
+                &self.nonce,
+            )
+            .await
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+            .and_then(|data| SenderData::mls_decode(&mut &**data).map_err(From::from))
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use alloc::vec::Vec;
+    use mls_rs_core::crypto::CipherSuiteProvider;
+
+    use super::SenderDataKey;
+    /// Interop vector: a sender_data_secret and ciphertext plus the key/nonce derived from them.
+    #[derive(Debug, serde::Serialize, serde::Deserialize)]
+    pub struct InteropSenderData {
+        #[serde(with = "hex::serde")]
+        pub sender_data_secret: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        pub ciphertext: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        pub key: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        pub nonce: Vec<u8>,
+    }
+
+    impl InteropSenderData {
+        #[cfg(not(mls_build_async))]
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        pub(crate) fn new<P: CipherSuiteProvider>(cs: &P) -> Self {
+            let secret = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap().into();
+            let ciphertext = cs.random_bytes_vec(77).unwrap();
+            let key = SenderDataKey::new(&secret, &ciphertext, cs).unwrap();
+            let secret = (*secret).clone();
+
+            Self {
+                ciphertext,
+                key: key.key.to_vec(),
+                nonce: key.nonce.to_vec(),
+                sender_data_secret: secret,
+            }
+        }
+        /// Re-derives key and nonce from the stored inputs and asserts they match the vector.
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
+            let secret = self.sender_data_secret.clone().into();
+
+            let key = SenderDataKey::new(&secret, &self.ciphertext, cs)
+                .await
+                .unwrap();
+
+            assert_eq!(key.key.to_vec(), self.key, "sender data key mismatch");
+            assert_eq!(key.nonce.to_vec(), self.nonce, "sender data nonce mismatch");
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use alloc::vec::Vec;
+    #[cfg(target_arch = "wasm32")]
+    use wasm_bindgen_test::wasm_bindgen_test as test;
+
+    use crate::{
+        crypto::test_utils::try_test_cipher_suite_provider,
+        group::{ciphertext_processor::reuse_guard::ReuseGuard, framing::ContentType},
+        tree_kem::node::LeafIndex,
+    };
+
+    use super::{SenderData, SenderDataAAD, SenderDataKey};
+
+    #[cfg(not(mls_build_async))]
+    use crate::{
+        cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider,
+        group::test_utils::random_bytes, CipherSuiteProvider,
+    };
+    /// One vector: inputs, the expected key/nonce, and the expected sealed sender data.
+    #[derive(serde::Deserialize, serde::Serialize)]
+    struct TestCase {
+        cipher_suite: u16,
+        #[serde(with = "hex::serde")]
+        secret: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        ciphertext_bytes: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        expected_key: Vec<u8>,
+        #[serde(with = "hex::serde")]
+        expected_nonce: Vec<u8>,
+        sender_data: TestSenderData,
+        sender_data_aad: TestSenderDataAAD,
+        #[serde(with = "hex::serde")]
+        expected_ciphertext: Vec<u8>,
+    }
+
+    #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
+    struct TestSenderData {
+        sender: u32,
+        generation: u32,
+        #[serde(with = "hex::serde")]
+        reuse_guard: Vec<u8>,
+    }
+
+    impl From<TestSenderData> for SenderData {
+        fn from(value: TestSenderData) -> Self {
+            let reuse_guard = ReuseGuard::new(value.reuse_guard);
+
+            Self {
+                sender: LeafIndex(value.sender),
+                generation: value.generation,
+                reuse_guard,
+            }
+        }
+    }
+
+    #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
+    struct TestSenderDataAAD {
+        epoch: u64,
+        #[serde(with = "hex::serde")]
+        group_id: Vec<u8>,
+    }
+
+    impl From<TestSenderDataAAD> for SenderDataAAD {
+        fn from(value: TestSenderDataAAD) -> Self {
+            Self {
+                epoch: value.epoch,
+                group_id: value.group_id,
+                content_type: ContentType::Application,
+            }
+        }
+    }
+    /// Generates vectors for every cipher suite; available in sync builds only.
+    #[cfg(not(mls_build_async))]
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn generate_test_vector() -> Vec<TestCase> {
+        let test_cases = CipherSuite::all().map(test_cipher_suite_provider).map(
+            #[cfg_attr(coverage_nightly, coverage(off))]
+            |provider| {
+                let ext_size = provider.kdf_extract_size();
+                let secret = random_bytes(ext_size).into();
+                let ciphertext_sizes = [ext_size - 5, ext_size, ext_size + 5];
+                // Sizes below, at and above kdf_extract_size exercise full and truncated sampling
+                let sender_data = TestSenderData {
+                    sender: 0,
+                    generation: 13,
+                    reuse_guard: random_bytes(4),
+                };
+
+                let sender_data_aad = TestSenderDataAAD {
+                    group_id: b"group".to_vec(),
+                    epoch: 42,
+                };
+
+                ciphertext_sizes.into_iter().map(
+                    #[cfg_attr(coverage_nightly, coverage(off))]
+                    move |ciphertext_size| {
+                        let ciphertext_bytes = random_bytes(ciphertext_size);
+
+                        let sender_data_key =
+                            SenderDataKey::new(&secret, &ciphertext_bytes, &provider).unwrap();
+
+                        let expected_ciphertext = sender_data_key
+                            .seal(&sender_data.clone().into(), &sender_data_aad.clone().into())
+                            .unwrap();
+
+                        TestCase {
+                            cipher_suite: provider.cipher_suite().into(),
+                            secret: secret.to_vec(),
+                            ciphertext_bytes,
+                            expected_key: sender_data_key.key.to_vec(),
+                            expected_nonce: sender_data_key.nonce.to_vec(),
+                            sender_data: sender_data.clone(),
+                            sender_data_aad: sender_data_aad.clone(),
+                            expected_ciphertext,
+                        }
+                    },
+                )
+            },
+        );
+
+        test_cases.flatten().collect()
+    }
+
+    #[cfg(mls_build_async)]
+    fn generate_test_vector() -> Vec<TestCase> {
+        panic!("Tests cannot be generated in async mode");
+    }
+
+    fn load_test_cases() -> Vec<TestCase> {
+        load_test_case_json!(sender_data_key_test_vector, generate_test_vector())
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sender_data_key_test_vector() {
+        for test_case in load_test_cases() {
+            let Some(provider) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
+                continue;
+            };
+
+            let sender_data_key = SenderDataKey::new(
+                &test_case.secret.into(),
+                &test_case.ciphertext_bytes,
+                &provider,
+            )
+            .await
+            .unwrap();
+
+            assert_eq!(sender_data_key.key.to_vec(), test_case.expected_key);
+            assert_eq!(sender_data_key.nonce.to_vec(), test_case.expected_nonce);
+            // Sealing with the same key, nonce and AAD must reproduce the vector's bytes
+            let sender_data = test_case.sender_data.into();
+            let sender_data_aad = test_case.sender_data_aad.into();
+
+            let ciphertext = sender_data_key
+                .seal(&sender_data, &sender_data_aad)
+                .await
+                .unwrap();
+
+            assert_eq!(ciphertext, test_case.expected_ciphertext);
+
+            let plaintext = sender_data_key
+                .open(&ciphertext, &sender_data_aad)
+                .await
+                .unwrap();
+
+            assert_eq!(plaintext, sender_data);
+        }
+    }
+}
diff --git a/src/group/commit.rs b/src/group/commit.rs
new file mode 100644
index 0000000..c201057
--- /dev/null
+++ b/src/group/commit.rs
@@ -0,0 +1,1601 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{
+ crypto::{CipherSuiteProvider, SignatureSecretKey},
+ error::IntoAnyError,
+};
+
+use crate::{
+ cipher_suite::CipherSuite,
+ client::MlsError,
+ client_config::ClientConfig,
+ extension::RatchetTreeExt,
+ identity::SigningIdentity,
+ protocol_version::ProtocolVersion,
+ signer::Signable,
+ tree_kem::{
+ kem::TreeKem, node::LeafIndex, path_secret::PathSecret, TreeKemPrivate, UpdatePath,
+ },
+ ExtensionList, MlsRules,
+};
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
+
+use crate::tree_kem::leaf_node::LeafNode;
+
+#[cfg(not(feature = "private_message"))]
+use crate::WireFormat;
+
+#[cfg(feature = "psk")]
+use crate::{
+ group::{JustPreSharedKeyID, PskGroupId, ResumptionPSKUsage, ResumptionPsk},
+ psk::ExternalPskId,
+};
+
+use super::{
+ confirmation_tag::ConfirmationTag,
+ framing::{Content, MlsMessage, MlsMessagePayload, Sender},
+ key_schedule::{KeySchedule, WelcomeSecret},
+ message_processor::{path_update_required, MessageProcessor},
+ message_signature::AuthenticatedContent,
+ mls_rules::CommitDirection,
+ proposal::{Proposal, ProposalOrRef},
+ ConfirmedTranscriptHash, EncryptedGroupSecrets, ExportedTree, Group, GroupContext, GroupInfo,
+ Welcome,
+};
+
+#[cfg(not(feature = "by_ref_proposal"))]
+use super::proposal_cache::prepare_commit;
+
+#[cfg(feature = "custom_proposal")]
+use super::proposal::CustomProposal;
+
+/// Content of an MLS commit: the proposals it applies and an optional update path.
+///
+/// NOTE(review): the derived `MlsEncode`/`MlsDecode` presumably encode fields in
+/// declaration order, so reordering fields would change the wire format; confirm
+/// before editing.
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(mls_rs_core::arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct Commit {
+ /// Proposals covered by this commit, each either inline (by value) or by reference.
+ pub proposals: Vec<ProposalOrRef>,
+ /// Update path; `None` when the committer did not perform a path update.
+ pub path: Option<UpdatePath>,
+}
+
+/// State the committer keeps while its own commit is pending. `commit_internal`
+/// builds one of these and stores it in the group's `pending_commit` slot.
+#[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(super) struct CommitGeneration {
+ /// The signed commit content, with its confirmation tag already attached.
+ pub content: AuthenticatedContent,
+ /// Private tree state after applying this commit's proposals and path update.
+ pub pending_private_tree: TreeKemPrivate,
+ /// Commit secret for the next epoch; the empty secret when no path update was sent.
+ pub pending_commit_secret: PathSecret,
+ /// Hash of the encoded commit message, as computed by `CommitHash::compute`.
+ pub commit_message_hash: CommitHash,
+}
+
+/// Digest of an MLS-encoded commit message, produced by `CommitHash::compute`
+/// with the cipher suite's hash function.
+#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct CommitHash(
+ // Raw digest bytes, encoded as an MLS byte vector / serde byte sequence.
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ Vec<u8>,
+);
+
+/// Debug output shows the digest via the shared byte pretty-printer, labelled
+/// with the type name instead of the default tuple-struct formatting.
+impl Debug for CommitHash {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        Debug::fmt(
+            &mls_rs_core::debug::pretty_bytes(&self.0).named("CommitHash"),
+            f,
+        )
+    }
+}
+
+impl CommitHash {
+    /// Hash the MLS encoding of `commit` with the hash function of `cs`.
+    ///
+    /// # Errors
+    ///
+    /// Fails if `commit` cannot be encoded, or with
+    /// `MlsError::CryptoProviderError` if the provider's hash fails.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn compute<CS: CipherSuiteProvider>(
+        cs: &CS,
+        commit: &MlsMessage,
+    ) -> Result<Self, MlsError> {
+        let encoded = commit.mls_encode_to_vec()?;
+
+        let digest = cs
+            .hash(&encoded)
+            .await
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+        Ok(Self(digest))
+    }
+}
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Debug)]
+#[non_exhaustive]
+/// Result of an MLS commit operation created with
+/// [`Group::commit`](crate::group::Group::commit) or
+/// [`CommitBuilder::build`](CommitBuilder::build).
+///
+/// Producing this output does not apply the commit: it is stored as the
+/// group's pending commit until it is applied or replaced.
+pub struct CommitOutput {
+ /// Commit message to send to other group members.
+ pub commit_message: MlsMessage,
+ /// Welcome messages to send to new group members. If the commit does not add members,
+ /// this list is empty. Otherwise, if [`MlsRules::commit_options`] returns `single_welcome_message`
+ /// set to true, then this list contains a single message sent to all members. Else, the list
+ /// contains one message for each added member. Recipients of each message can be identified using
+ /// [`MlsMessage::key_package_reference`] of their key packages and
+ /// [`MlsMessage::welcome_key_package_references`].
+ pub welcome_messages: Vec<MlsMessage>,
+ /// Ratchet tree that can be sent out of band if
+ /// `ratchet_tree_extension` is not used according to
+ /// [`MlsRules::commit_options`]. `None` when the tree is embedded through
+ /// the ratchet tree extension instead.
+ pub ratchet_tree: Option<ExportedTree<'static>>,
+ /// A group info that can be provided to new members in order to enable external commit
+ /// functionality. This value is set if [`MlsRules::commit_options`] returns
+ /// `allow_external_commit` set to true.
+ pub external_commit_group_info: Option<MlsMessage>,
+ /// Proposals that were received in the prior epoch but not included in the following commit.
+ #[cfg(feature = "by_ref_proposal")]
+ pub unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl CommitOutput {
+    /// Borrow the commit message that must be delivered to the existing members.
+    #[cfg(feature = "ffi")]
+    pub fn commit_message(&self) -> &MlsMessage {
+        &self.commit_message
+    }
+
+    /// Borrow the welcome messages destined for newly added members.
+    #[cfg(feature = "ffi")]
+    pub fn welcome_messages(&self) -> &[MlsMessage] {
+        self.welcome_messages.as_slice()
+    }
+
+    /// Borrow the ratchet tree to distribute out of band. It is absent when
+    /// the tree travels inside the `ratchet_tree_extension`, as decided by
+    /// [`MlsRules::commit_options`].
+    #[cfg(feature = "ffi")]
+    pub fn ratchet_tree(&self) -> Option<&ExportedTree<'static>> {
+        Option::as_ref(&self.ratchet_tree)
+    }
+
+    /// Borrow the group info that enables external commits. It is present
+    /// only when [`MlsRules::commit_options`] sets `allow_external_commit`.
+    #[cfg(feature = "ffi")]
+    pub fn external_commit_group_info(&self) -> Option<&MlsMessage> {
+        Option::as_ref(&self.external_commit_group_info)
+    }
+
+    /// Borrow the proposals from the previous epoch that this commit did not include.
+    #[cfg(all(feature = "ffi", feature = "by_ref_proposal"))]
+    pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
+        self.unused_proposals.as_slice()
+    }
+}
+
+/// Build a commit with multiple proposals by-value.
+///
+/// Proposals within a commit can be by-value or by-reference.
+/// Proposals received during the current epoch will be added to the resulting
+/// commit by-reference automatically so long as they pass the rules defined
+/// in the current
+/// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
+///
+/// Obtain a builder from [`Group::commit_builder`], chain the setters, then
+/// call [`CommitBuilder::build`] to produce the [`CommitOutput`].
+pub struct CommitBuilder<'a, C>
+where
+ C: ClientConfig + Clone,
+{
+ /// Group the commit is created for, borrowed mutably for the builder's lifetime.
+ group: &'a mut Group<C>,
+ /// Proposals to include by value, in insertion order.
+ pub(super) proposals: Vec<Proposal>,
+ /// Additional authenticated data; sent unencrypted with the commit.
+ authenticated_data: Vec<u8>,
+ /// Extensions for the GroupInfo carried by the welcome messages.
+ group_info_extensions: ExtensionList,
+ /// Replacement signing key for the committer, if any.
+ new_signer: Option<SignatureSecretKey>,
+ /// Replacement signing identity for the committer, if any.
+ new_signing_identity: Option<SigningIdentity>,
+}
+
+impl<'a, C> CommitBuilder<'a, C>
+where
+    C: ClientConfig + Clone,
+{
+    /// Append an already-built proposal to the by-value list and return the
+    /// builder, so each setter below can finish with a single call.
+    fn push_by_value_proposal(mut self, proposal: Proposal) -> Self {
+        self.proposals.push(proposal);
+        self
+    }
+
+    /// Queue an [`AddProposal`](crate::group::proposal::AddProposal) for
+    /// `key_package` in the commit under construction.
+    ///
+    /// # Errors
+    ///
+    /// Propagates any error raised while building the add proposal.
+    pub fn add_member(self, key_package: MlsMessage) -> Result<CommitBuilder<'a, C>, MlsError> {
+        let add = self.group.add_proposal(key_package)?;
+        Ok(self.push_by_value_proposal(add))
+    }
+
+    /// Replace the extensions placed in the group info carried by the
+    /// resulting [welcome messages](CommitOutput::welcome_messages).
+    ///
+    /// That group info is encrypted together with the other welcome secrets.
+    /// New members can read these extensions from
+    /// [`NewMemberInfo`](crate::group::NewMemberInfo), which is returned when
+    /// joining via [`Client::join_group`](crate::Client::join_group).
+    pub fn set_group_info_ext(mut self, extensions: ExtensionList) -> Self {
+        self.group_info_extensions = extensions;
+        self
+    }
+
+    /// Queue a [`RemoveProposal`](crate::group::proposal::RemoveProposal)
+    /// for the member at leaf `index`.
+    ///
+    /// # Errors
+    ///
+    /// Propagates any error raised while building the remove proposal.
+    pub fn remove_member(self, index: u32) -> Result<Self, MlsError> {
+        let remove = self.group.remove_proposal(index)?;
+        Ok(self.push_by_value_proposal(remove))
+    }
+
+    /// Queue a
+    /// [`GroupContextExtensions`](crate::group::proposal::Proposal::GroupContextExtensions)
+    /// proposal carrying `extensions`.
+    pub fn set_group_context_ext(self, extensions: ExtensionList) -> Result<Self, MlsError> {
+        let context_extensions = self.group.group_context_extensions_proposal(extensions);
+        Ok(self.push_by_value_proposal(context_extensions))
+    }
+
+    /// Queue a
+    /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal)
+    /// referencing the external PSK `psk_id`.
+    ///
+    /// # Errors
+    ///
+    /// Propagates any error raised while building the PSK proposal.
+    #[cfg(feature = "psk")]
+    pub fn add_external_psk(self, psk_id: ExternalPskId) -> Result<Self, MlsError> {
+        let psk = self
+            .group
+            .psk_proposal(JustPreSharedKeyID::External(psk_id))?;
+
+        Ok(self.push_by_value_proposal(psk))
+    }
+
+    /// Queue a
+    /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal)
+    /// for an application-usage resumption PSK from `psk_epoch` of this group.
+    ///
+    /// # Errors
+    ///
+    /// Propagates any error raised while building the PSK proposal.
+    #[cfg(feature = "psk")]
+    pub fn add_resumption_psk(self, psk_epoch: u64) -> Result<Self, MlsError> {
+        let resumption = ResumptionPsk {
+            psk_epoch,
+            usage: ResumptionPSKUsage::Application,
+            psk_group_id: PskGroupId(self.group.group_id().to_vec()),
+        };
+
+        let psk = self
+            .group
+            .psk_proposal(JustPreSharedKeyID::Resumption(resumption))?;
+
+        Ok(self.push_by_value_proposal(psk))
+    }
+
+    /// Queue a [`ReInitProposal`](crate::group::proposal::ReInitProposal)
+    /// with the given parameters for the re-initialized group.
+    ///
+    /// # Errors
+    ///
+    /// Propagates any error raised while building the re-init proposal.
+    pub fn reinit(
+        self,
+        group_id: Option<Vec<u8>>,
+        version: ProtocolVersion,
+        cipher_suite: CipherSuite,
+        extensions: ExtensionList,
+    ) -> Result<Self, MlsError> {
+        let reinit = self
+            .group
+            .reinit_proposal(group_id, version, cipher_suite, extensions)?;
+
+        Ok(self.push_by_value_proposal(reinit))
+    }
+
+    /// Queue a [`CustomProposal`](crate::group::proposal::CustomProposal).
+    #[cfg(feature = "custom_proposal")]
+    pub fn custom_proposal(self, proposal: CustomProposal) -> Self {
+        self.push_by_value_proposal(Proposal::Custom(proposal))
+    }
+
+    /// Queue a previously constructed proposal, for example one returned by
+    /// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals).
+    pub fn raw_proposal(self, proposal: Proposal) -> Self {
+        self.push_by_value_proposal(proposal)
+    }
+
+    /// Queue several previously constructed proposals in order, for example
+    /// those returned by
+    /// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals).
+    pub fn raw_proposals(mut self, proposals: Vec<Proposal>) -> Self {
+        self.proposals.extend(proposals);
+        self
+    }
+
+    /// Replace the additional authenticated data attached to the commit.
+    ///
+    /// # Warning
+    ///
+    /// The data provided here is always sent unencrypted.
+    pub fn authenticated_data(mut self, authenticated_data: Vec<u8>) -> Self {
+        self.authenticated_data = authenticated_data;
+        self
+    }
+
+    /// Switch the committer to `signer` / `signing_identity` as part of this
+    /// commit.
+    ///
+    /// This only succeeds if the group's
+    /// [`IdentityProvider`](crate::IdentityProvider) considers the new
+    /// credential [valid](crate::IdentityProvider::validate_member) and it
+    /// maps to the same [identity](crate::IdentityProvider::identity).
+    pub fn set_new_signing_identity(
+        mut self,
+        signer: SignatureSecretKey,
+        signing_identity: SigningIdentity,
+    ) -> Self {
+        self.new_signer = Some(signer);
+        self.new_signing_identity = Some(signing_identity);
+        self
+    }
+
+    /// Produce the commit (and any welcome messages) from everything queued
+    /// on this builder.
+    ///
+    /// # Errors
+    ///
+    /// Fails if any queued proposal is not contextually valid under the MLS
+    /// RFC, or is rejected by the custom rules of the current
+    /// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn build(self) -> Result<CommitOutput, MlsError> {
+        let CommitBuilder {
+            group,
+            proposals,
+            authenticated_data,
+            group_info_extensions,
+            new_signer,
+            new_signing_identity,
+        } = self;
+
+        group
+            .commit_internal(
+                proposals,
+                None,
+                authenticated_data,
+                group_info_extensions,
+                new_signer,
+                new_signing_identity,
+            )
+            .await
+    }
+}
+
+// Commit creation: public entry points plus the shared `commit_internal`
+// pipeline and the GroupInfo/Welcome construction helpers it uses.
+impl<C> Group<C>
+where
+ C: ClientConfig + Clone,
+{
+ /// Perform a commit of received proposals.
+ ///
+ /// This function is the equivalent of [`Group::commit_builder`] immediately
+ /// followed by [`CommitBuilder::build`]. Any received proposals since the
+ /// last commit will be included in the resulting message by-reference.
+ ///
+ /// Data provided in the `authenticated_data` field will be placed into
+ /// the resulting commit message unencrypted.
+ ///
+ /// # Pending Commits
+ ///
+ /// When a commit is created, it is not applied immediately in order to
+ /// allow for the resolution of conflicts when multiple members of a group
+ /// attempt to make commits at the same time. For example, a central relay
+ /// can be used to decide which commit should be accepted by the group by
+ /// determining a consistent view of commit packet order for all clients.
+ ///
+ /// Pending commits are stored internally as part of the group's state
+ /// so they do not need to be tracked outside of this library. Any commit
+ /// message that is processed before calling [Group::apply_pending_commit]
+ /// will clear the currently pending commit.
+ ///
+ /// # Empty Commits
+ ///
+ /// Sending a commit that contains no proposals is a valid operation
+ /// within the MLS protocol. It is useful for providing stronger forward
+ /// secrecy and post-compromise security, especially for long running
+ /// groups when group membership does not change often.
+ ///
+ /// # Path Updates
+ ///
+ /// Path updates provide forward secrecy and post-compromise security
+ /// within the MLS protocol.
+ /// The `path_required` option returned by [`MlsRules::commit_options`](`crate::MlsRules::commit_options`)
+ /// controls the ability of a group to send a commit without a path update.
+ /// An update path will automatically be sent if there are no proposals
+ /// in the commit, or if any proposal other than
+ /// [`Add`](crate::group::proposal::Proposal::Add),
+ /// [`Psk`](crate::group::proposal::Proposal::Psk),
+ /// or [`ReInit`](crate::group::proposal::Proposal::ReInit) are part of the commit.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError> {
+ // No by-value proposals, no external leaf, empty welcome group info
+ // extensions and no signer change.
+ self.commit_internal(
+ vec![],
+ None,
+ authenticated_data,
+ Default::default(),
+ None,
+ None,
+ )
+ .await
+ }
+
+ /// Create a new commit builder that can include proposals
+ /// by-value.
+ pub fn commit_builder(&mut self) -> CommitBuilder<C> {
+ CommitBuilder {
+ group: self,
+ proposals: Default::default(),
+ authenticated_data: Default::default(),
+ group_info_extensions: Default::default(),
+ new_signer: Default::default(),
+ new_signing_identity: Default::default(),
+ }
+ }
+
+ /// Shared implementation of [`Group::commit`] and [`CommitBuilder::build`].
+ ///
+ /// Applies `proposals` to a provisional copy of the group state. With the
+ /// `by_ref_proposal` feature they are first combined with cached proposals by
+ /// `prepare_commit`. The function optionally performs a path update, then
+ /// produces the commit message, welcome messages for added members and, if the
+ /// MLS rules allow it, a group info for external commits. The commit is stored
+ /// as the pending commit rather than applied.
+ ///
+ /// `external_leaf` is `Some` for external commits, in which case the sender is
+ /// `Sender::NewMemberCommit`.
+ ///
+ /// # Errors
+ ///
+ /// - `MlsError::ExistingPendingCommit` if a commit is already pending.
+ /// - `MlsError::GroupUsedAfterReInit` if a re-init is pending.
+ /// - `MlsError::ExternalCommitMissingExternalInit` for an external commit whose
+ ///   applied proposals yield no external-init index.
+ /// - Other errors propagated from proposal validation, cryptography and
+ ///   encoding.
+ #[allow(clippy::too_many_arguments)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn commit_internal(
+ &mut self,
+ proposals: Vec<Proposal>,
+ external_leaf: Option<&LeafNode>,
+ authenticated_data: Vec<u8>,
+ mut welcome_group_info_extensions: ExtensionList,
+ new_signer: Option<SignatureSecretKey>,
+ new_signing_identity: Option<SigningIdentity>,
+ ) -> Result<CommitOutput, MlsError> {
+ if self.pending_commit.is_some() {
+ return Err(MlsError::ExistingPendingCommit);
+ }
+
+ if self.state.pending_reinit.is_some() {
+ return Err(MlsError::GroupUsedAfterReInit);
+ }
+
+ let mls_rules = self.config.mls_rules();
+
+ let is_external = external_leaf.is_some();
+
+ // Construct an initial Commit object with the proposals field populated from Proposals
+ // received during the current epoch, and an empty path field. Add passed in proposals
+ // by value
+ let sender = if is_external {
+ Sender::NewMemberCommit
+ } else {
+ Sender::Member(*self.private_tree.self_index)
+ };
+
+ // The commit content is signed with the current key (`old_signer`). The new
+ // key, if any, is used for the new leaf (encap) and both GroupInfo objects.
+ let new_signer_ref = new_signer.as_ref().unwrap_or(&self.signer);
+ let old_signer = &self.signer;
+
+ // Without `std` there is no clock, so no timestamp is passed to validation.
+ #[cfg(feature = "std")]
+ let time = Some(crate::time::MlsTime::now());
+
+ #[cfg(not(feature = "std"))]
+ let time = None;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let proposals = self.state.proposals.prepare_commit(sender, proposals);
+
+ #[cfg(not(feature = "by_ref_proposal"))]
+ let proposals = prepare_commit(sender, proposals);
+
+ let mut provisional_state = self
+ .state
+ .apply_resolved(
+ sender,
+ proposals,
+ external_leaf,
+ &self.config.identity_provider(),
+ &self.cipher_suite_provider,
+ &self.config.secret_store(),
+ &mls_rules,
+ time,
+ CommitDirection::Send,
+ )
+ .await?;
+
+ let (mut provisional_private_tree, _) =
+ self.provisional_private_tree(&provisional_state)?;
+
+ if is_external {
+ provisional_private_tree.self_index = provisional_state
+ .external_init_index
+ .ok_or(MlsError::ExternalCommitMissingExternalInit)?;
+
+ // NOTE(review): this mutates `self` before the remaining fallible steps, so
+ // an error below leaves the updated index in place. It also appears to feed
+ // `current_member_index()` in `make_group_info`; confirm both are intended.
+ self.private_tree.self_index = provisional_private_tree.self_index;
+ }
+
+ let mut provisional_group_context = provisional_state.group_context;
+
+ // Decide whether to populate the path field: If the path field is required based on the
+ // proposals that are in the commit (see above), then it MUST be populated. Otherwise, the
+ // sender MAY omit the path field at its discretion.
+ let commit_options = mls_rules
+ .commit_options(
+ &provisional_state.public_tree.roster(),
+ &provisional_group_context.extensions,
+ &provisional_state.applied_proposals,
+ )
+ .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
+
+ let perform_path_update = commit_options.path_required
+ || path_update_required(&provisional_state.applied_proposals);
+
+ let (update_path, path_secrets, commit_secret) = if perform_path_update {
+ // If populating the path field: Create an UpdatePath using the new tree. Any new
+ // member (from an add proposal) MUST be excluded from the resolution during the
+ // computation of the UpdatePath. The GroupContext for this operation uses the
+ // group_id, epoch, tree_hash, and confirmed_transcript_hash values in the initial
+ // GroupContext object. The leaf_key_package for this UpdatePath must have a
+ // parent_hash extension.
+ let encap_gen = TreeKem::new(
+ &mut provisional_state.public_tree,
+ &mut provisional_private_tree,
+ )
+ .encap(
+ &mut provisional_group_context,
+ &provisional_state.indexes_of_added_kpkgs,
+ new_signer_ref,
+ self.config.leaf_properties(),
+ new_signing_identity,
+ &self.cipher_suite_provider,
+ #[cfg(test)]
+ &self.commit_modifiers,
+ )
+ .await?;
+
+ (
+ Some(encap_gen.update_path),
+ Some(encap_gen.path_secrets),
+ encap_gen.commit_secret,
+ )
+ } else {
+ // Update the tree hash, since it was not updated by encap.
+ provisional_state
+ .public_tree
+ .update_hashes(
+ &[provisional_private_tree.self_index],
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ provisional_group_context.tree_hash = provisional_state
+ .public_tree
+ .tree_hash(&self.cipher_suite_provider)
+ .await?;
+
+ // Without a path update the commit secret is the empty (all-zero) secret.
+ (None, None, PathSecret::empty(&self.cipher_suite_provider))
+ };
+
+ #[cfg(feature = "psk")]
+ let (psk_secret, psks) = self
+ .get_psk(&provisional_state.applied_proposals.psks)
+ .await?;
+
+ #[cfg(not(feature = "psk"))]
+ let psk_secret = self.get_psk();
+
+ // Key packages of new members, captured before `applied_proposals` is consumed below.
+ let added_key_pkgs: Vec<_> = provisional_state
+ .applied_proposals
+ .additions
+ .iter()
+ .map(|info| info.proposal.key_package.clone())
+ .collect();
+
+ let commit = Commit {
+ proposals: provisional_state.applied_proposals.into_proposals_or_refs(),
+ path: update_path,
+ };
+
+ // Signed against the current (pre-commit) context with the current key.
+ let mut auth_content = AuthenticatedContent::new_signed(
+ &self.cipher_suite_provider,
+ self.context(),
+ sender,
+ Content::Commit(alloc::boxed::Box::new(commit)),
+ old_signer,
+ #[cfg(feature = "private_message")]
+ self.encryption_options()?.control_wire_format(sender),
+ #[cfg(not(feature = "private_message"))]
+ WireFormat::PublicMessage,
+ authenticated_data,
+ )
+ .await?;
+
+ // Use the signature, the commit_secret and the psk_secret to advance the key schedule and
+ // compute the confirmation_tag value in the MlsPlaintext.
+ let confirmed_transcript_hash = ConfirmedTranscriptHash::create(
+ self.cipher_suite_provider(),
+ &self.state.interim_transcript_hash,
+ &auth_content,
+ )
+ .await?;
+
+ provisional_group_context.confirmed_transcript_hash = confirmed_transcript_hash;
+
+ // NOTE(review): the leaf count comes from the current (pre-commit) tree, not
+ // `provisional_state.public_tree`. Below, only the confirmation key, the
+ // external key pair and the joiner secret are read from this result. Confirm
+ // that the secret tree derived here is never relied upon.
+ let key_schedule_result = KeySchedule::from_key_schedule(
+ &self.key_schedule,
+ &commit_secret,
+ &provisional_group_context,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ self.state.public_tree.total_leaf_count(),
+ &psk_secret,
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ let confirmation_tag = ConfirmationTag::create(
+ &key_schedule_result.confirmation_key,
+ &provisional_group_context.confirmed_transcript_hash,
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ // The tag is attached after signing; the signature does not cover it.
+ auth_content.auth.confirmation_tag = Some(confirmation_tag.clone());
+
+ let ratchet_tree_ext = commit_options
+ .ratchet_tree_extension
+ .then(|| RatchetTreeExt {
+ tree_data: ExportedTree::new(provisional_state.public_tree.nodes.clone()),
+ });
+
+ // Generate external commit group info if required by commit_options
+ let external_commit_group_info = match commit_options.allow_external_commit {
+ true => {
+ let mut extensions = ExtensionList::new();
+
+ extensions.set_from({
+ key_schedule_result
+ .key_schedule
+ .get_external_key_pair_ext(&self.cipher_suite_provider)
+ .await?
+ })?;
+
+ if let Some(ref ratchet_tree_ext) = ratchet_tree_ext {
+ extensions.set_from(ratchet_tree_ext.clone())?;
+ }
+
+ let info = self
+ .make_group_info(
+ &provisional_group_context,
+ extensions,
+ &confirmation_tag,
+ new_signer_ref,
+ )
+ .await?;
+
+ let msg =
+ MlsMessage::new(self.protocol_version(), MlsMessagePayload::GroupInfo(info));
+
+ Some(msg)
+ }
+ false => None,
+ };
+
+ // Build the group info that will be placed into the welcome messages.
+ // Add the ratchet tree extension if necessary
+ if let Some(ratchet_tree_ext) = ratchet_tree_ext {
+ welcome_group_info_extensions.set_from(ratchet_tree_ext)?;
+ }
+
+ let welcome_group_info = self
+ .make_group_info(
+ &provisional_group_context,
+ welcome_group_info_extensions,
+ &confirmation_tag,
+ new_signer_ref,
+ )
+ .await?;
+
+ // Encrypt the GroupInfo using the key and nonce derived from the joiner_secret for
+ // the new epoch
+ let welcome_secret = WelcomeSecret::from_joiner_secret(
+ &self.cipher_suite_provider,
+ &key_schedule_result.joiner_secret,
+ &psk_secret,
+ )
+ .await?;
+
+ let encrypted_group_info = welcome_secret
+ .encrypt(&welcome_group_info.mls_encode_to_vec()?)
+ .await?;
+
+ // Encrypt path secrets and joiner secret to new members
+ let path_secrets = path_secrets.as_ref();
+
+ // Sync builds with the "rayon" feature encrypt each joiner's group secrets in
+ // parallel; every other configuration uses the sequential loop that follows.
+ #[cfg(not(any(mls_build_async, not(feature = "rayon"))))]
+ let encrypted_path_secrets: Vec<_> = added_key_pkgs
+ .into_par_iter()
+ .zip(provisional_state.indexes_of_added_kpkgs)
+ .map(|(key_package, leaf_index)| {
+ self.encrypt_group_secrets(
+ &key_package,
+ leaf_index,
+ &key_schedule_result.joiner_secret,
+ path_secrets,
+ #[cfg(feature = "psk")]
+ psks.clone(),
+ &encrypted_group_info,
+ )
+ })
+ .try_collect()?;
+
+ #[cfg(any(mls_build_async, not(feature = "rayon")))]
+ let encrypted_path_secrets = {
+ let mut secrets = Vec::new();
+
+ for (key_package, leaf_index) in added_key_pkgs
+ .into_iter()
+ .zip(provisional_state.indexes_of_added_kpkgs)
+ {
+ secrets.push(
+ self.encrypt_group_secrets(
+ &key_package,
+ leaf_index,
+ &key_schedule_result.joiner_secret,
+ path_secrets,
+ #[cfg(feature = "psk")]
+ psks.clone(),
+ &encrypted_group_info,
+ )
+ .await?,
+ );
+ }
+
+ secrets
+ };
+
+ // Either one Welcome carrying every joiner's secrets, or one Welcome per
+ // joiner; no joiners yields an empty list in both cases.
+ let welcome_messages =
+ if commit_options.single_welcome_message && !encrypted_path_secrets.is_empty() {
+ vec![self.make_welcome_message(encrypted_path_secrets, encrypted_group_info)]
+ } else {
+ encrypted_path_secrets
+ .into_iter()
+ .map(|s| self.make_welcome_message(vec![s], encrypted_group_info.clone()))
+ .collect()
+ };
+
+ // Presumably frames (and, for private messages, encrypts) the content for
+ // its wire format; see `format_for_wire`.
+ let commit_message = self.format_for_wire(auth_content.clone()).await?;
+
+ let pending_commit = CommitGeneration {
+ content: auth_content,
+ pending_private_tree: provisional_private_tree,
+ pending_commit_secret: commit_secret,
+ commit_message_hash: CommitHash::compute(&self.cipher_suite_provider, &commit_message)
+ .await?,
+ };
+
+ self.pending_commit = Some(pending_commit);
+
+ let ratchet_tree = (!commit_options.ratchet_tree_extension)
+ .then(|| ExportedTree::new(provisional_state.public_tree.nodes));
+
+ // NOTE(review): the new signer replaces `self.signer` as soon as the commit
+ // is generated, before the pending commit is applied. Confirm what happens to
+ // the signer if the pending commit is later discarded.
+ if let Some(signer) = new_signer {
+ self.signer = signer;
+ }
+
+ Ok(CommitOutput {
+ commit_message,
+ welcome_messages,
+ ratchet_tree,
+ external_commit_group_info,
+ #[cfg(feature = "by_ref_proposal")]
+ unused_proposals: provisional_state.unused_proposals,
+ })
+ }
+
+ // Construct a GroupInfo reflecting the new state
+ // Group ID, epoch, tree, and confirmed transcript hash from the new state
+ // The signer index is `current_member_index()`; for external commits this
+ // presumably reflects the `self_index` assigned in `commit_internal` (confirm).
+ // `grease` presumably inserts GREASE values before signing (confirm).
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn make_group_info(
+ &self,
+ group_context: &GroupContext,
+ extensions: ExtensionList,
+ confirmation_tag: &ConfirmationTag,
+ signer: &SignatureSecretKey,
+ ) -> Result<GroupInfo, MlsError> {
+ let mut group_info = GroupInfo {
+ group_context: group_context.clone(),
+ extensions,
+ confirmation_tag: confirmation_tag.clone(), // The confirmation_tag from the MlsPlaintext object
+ signer: LeafIndex(self.current_member_index()),
+ signature: vec![],
+ };
+
+ group_info.grease(self.cipher_suite_provider())?;
+
+ // Sign the GroupInfo using the member's private signing key
+ group_info
+ .sign(&self.cipher_suite_provider, signer, &())
+ .await?;
+
+ Ok(group_info)
+ }
+
+ // Wrap the encrypted group secrets and encrypted group info in a Welcome,
+ // tagged with the current context's protocol version and cipher suite.
+ fn make_welcome_message(
+ &self,
+ secrets: Vec<EncryptedGroupSecrets>,
+ encrypted_group_info: Vec<u8>,
+ ) -> MlsMessage {
+ MlsMessage::new(
+ self.context().protocol_version,
+ MlsMessagePayload::Welcome(Welcome {
+ cipher_suite: self.context().cipher_suite,
+ secrets,
+ encrypted_group_info,
+ }),
+ )
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use alloc::vec::Vec;
+
+    use crate::{
+        crypto::SignatureSecretKey,
+        tree_kem::{leaf_node::LeafNode, TreeKemPublic, UpdatePathNode},
+    };
+
+    /// Test-only hooks handed to `TreeKem::encap` while a commit is generated.
+    /// They let a test tamper with the committer's leaf, the public tree and
+    /// the update path.
+    #[derive(Copy, Clone, Debug)]
+    pub struct CommitModifiers {
+        pub modify_leaf: fn(&mut LeafNode, &SignatureSecretKey) -> Option<SignatureSecretKey>,
+        pub modify_tree: fn(&mut TreeKemPublic),
+        pub modify_path: fn(Vec<UpdatePathNode>) -> Vec<UpdatePathNode>,
+    }
+
+    impl Default for CommitModifiers {
+        /// Hooks that change nothing: no replacement signer, untouched tree,
+        /// and the update path returned as-is.
+        fn default() -> Self {
+            fn keep_leaf(_: &mut LeafNode, _: &SignatureSecretKey) -> Option<SignatureSecretKey> {
+                None
+            }
+
+            fn keep_tree(_: &mut TreeKemPublic) {}
+
+            fn keep_path(path: Vec<UpdatePathNode>) -> Vec<UpdatePathNode> {
+                path
+            }
+
+            Self {
+                modify_leaf: keep_leaf,
+                modify_tree: keep_tree,
+                modify_path: keep_path,
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use alloc::boxed::Box;
+
+ use mls_rs_core::{
+ error::IntoAnyError,
+ extension::ExtensionType,
+ identity::{CredentialType, IdentityProvider},
+ time::MlsTime,
+ };
+
+ use crate::{
+ crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
+ group::{mls_rules::DefaultMlsRules, test_utils::test_group_custom},
+ mls_rules::CommitOptions,
+ Client,
+ };
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::extension::ExternalSendersExt;
+
+ use crate::{
+ client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ client_builder::{
+ test_utils::TestClientConfig, BaseConfig, ClientBuilder, WithCryptoProvider,
+ WithIdentityProvider,
+ },
+ client_config::ClientConfig,
+ extension::test_utils::{TestExtension, TEST_EXTENSION_TYPE},
+ group::{
+ proposal::ProposalType,
+ test_utils::{test_group_custom_config, test_n_member_group},
+ },
+ identity::test_utils::get_test_signing_identity,
+ identity::{basic::BasicIdentityProvider, test_utils::get_test_basic_credential},
+ key_package::test_utils::test_key_package_message,
+ };
+
+ use crate::extension::RequiredCapabilitiesExt;
+
+ #[cfg(feature = "psk")]
+ use crate::{
+ group::proposal::PreSharedKeyProposal,
+ psk::{JustPreSharedKeyID, PreSharedKey, PreSharedKeyID},
+ };
+
+ use super::*;
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_commit_builder_group() -> Group<TestClientConfig> {
+ test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
+ b.custom_proposal_type(ProposalType::from(42))
+ .extension_type(TEST_EXTENSION_TYPE.into())
+ })
+ .await
+ .group
+ }
+
+ fn assert_commit_builder_output<C: ClientConfig>(
+ group: Group<C>,
+ mut commit_output: CommitOutput,
+ expected: Vec<Proposal>,
+ welcome_count: usize,
+ ) {
+ let plaintext = commit_output.commit_message.into_plaintext().unwrap();
+
+ let commit_data = match plaintext.content.content {
+ Content::Commit(commit) => commit,
+ #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
+ _ => panic!("Found non-commit data"),
+ };
+
+ assert_eq!(commit_data.proposals.len(), expected.len());
+
+ commit_data.proposals.into_iter().for_each(|proposal| {
+ let proposal = match proposal {
+ ProposalOrRef::Proposal(p) => p,
+ #[cfg(feature = "by_ref_proposal")]
+ ProposalOrRef::Reference(_) => panic!("found proposal reference"),
+ };
+
+ #[cfg(feature = "psk")]
+ if let Some(psk_id) = match proposal.as_ref() {
+ Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(psk_id), .. },}) => Some(psk_id),
+ _ => None,
+ } {
+ let found = expected.iter().any(|item| matches!(item, Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(id), .. }}) if id == psk_id));
+
+ assert!(found)
+ } else {
+ assert!(expected.contains(&proposal));
+ }
+
+ #[cfg(not(feature = "psk"))]
+ assert!(expected.contains(&proposal));
+ });
+
+ if welcome_count > 0 {
+ let welcome_msg = commit_output.welcome_messages.pop().unwrap();
+
+ assert_eq!(welcome_msg.version, group.state.context.protocol_version);
+
+ let welcome_msg = welcome_msg.into_welcome().unwrap();
+
+ assert_eq!(welcome_msg.cipher_suite, group.state.context.cipher_suite);
+ assert_eq!(welcome_msg.secrets.len(), welcome_count);
+ } else {
+ assert!(commit_output.welcome_messages.is_empty());
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_add() {
+ let mut group = test_commit_builder_group().await;
+
+ let test_key_package =
+ test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
+
+ let commit_output = group
+ .commit_builder()
+ .add_member(test_key_package.clone())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ let expected_add = group.add_proposal(test_key_package).unwrap();
+
+ assert_commit_builder_output(group, commit_output, vec![expected_add], 1)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_add_with_ext() {
+ let mut group = test_commit_builder_group().await;
+
+ let (bob_client, bob_key_package) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let ext = TestExtension { foo: 42 };
+ let mut extension_list = ExtensionList::default();
+ extension_list.set_from(ext.clone()).unwrap();
+
+ let welcome_message = group
+ .commit_builder()
+ .add_member(bob_key_package)
+ .unwrap()
+ .set_group_info_ext(extension_list)
+ .build()
+ .await
+ .unwrap()
+ .welcome_messages
+ .remove(0);
+
+ let (_, context) = bob_client.join_group(None, &welcome_message).await.unwrap();
+
+ assert_eq!(
+ context
+ .group_info_extensions
+ .get_as::<TestExtension>()
+ .unwrap()
+ .unwrap(),
+ ext
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_remove() {
+ let mut group = test_commit_builder_group().await;
+
+ // Grow the group first so that leaf index 1 can be removed below.
+ let alice_kp =
+ test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
+
+ let add_builder = group.commit_builder().add_member(alice_kp).unwrap();
+ add_builder.build().await.unwrap();
+
+ // `build` leaves the add as a pending commit; apply it before removing.
+ group.apply_pending_commit().await.unwrap();
+
+ // Commit the removal of leaf 1.
+ let remove_builder = group
+ .commit_builder()
+ .remove_member(1)
+ .unwrap();
+
+ let commit_output = remove_builder.build().await.unwrap();
+
+ // Independently construct the Remove proposal the commit should contain.
+ let expected_remove = group.remove_proposal(1).unwrap();
+
+ // A removal-only commit is expected to produce no Welcome messages.
+ assert_commit_builder_output(group, commit_output, vec![expected_remove], 0);
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_psk() {
+ let mut group = test_commit_builder_group().await;
+ let psk_id = ExternalPskId::new(vec![1]);
+ let psk_secret = PreSharedKey::from(vec![1]);
+
+ // Store the PSK value under `psk_id` in the group's secret store; the commit
+ // below only references it by id.
+ group.config.secret_store().insert(psk_id.clone(), psk_secret);
+
+ let commit_output = group
+ .commit_builder()
+ .add_external_psk(psk_id.clone())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ // The PreSharedKey proposal the commit is expected to contain.
+ let expected_psk = group.psk_proposal(JustPreSharedKeyID::External(psk_id)).unwrap();
+
+ assert_commit_builder_output(group, commit_output, vec![expected_psk], 0)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_group_context_ext() {
+ let mut group = test_commit_builder_group().await;
+
+ // An extension list holding a default RequiredCapabilities extension.
+ let mut new_extensions = ExtensionList::default();
+ new_extensions.set_from(RequiredCapabilitiesExt::default()).unwrap();
+
+ // Commit only the extension change; no members are added.
+ let builder = group
+ .commit_builder()
+ .set_group_context_ext(new_extensions.clone())
+ .unwrap();
+ let commit_output = builder.build().await.unwrap();
+
+ // Expected: a GroupContextExtensions proposal carrying the same list.
+ let expected_ext = group.group_context_extensions_proposal(new_extensions);
+
+ assert_commit_builder_output(group, commit_output, vec![expected_ext], 0);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_reinit() {
+ let mut group = test_commit_builder_group().await;
+
+ // ReInit parameters: a new group id plus the test cipher suite and version.
+ let new_group_id = b"foo".to_vec();
+ let new_cipher_suite = TEST_CIPHER_SUITE;
+ let new_version = TEST_PROTOCOL_VERSION;
+ let mut new_extensions = ExtensionList::default();
+ new_extensions.set_from(RequiredCapabilitiesExt::default()).unwrap();
+
+ let commit_output = group
+ .commit_builder()
+ .reinit(
+ Some(new_group_id.clone()),
+ new_version,
+ new_cipher_suite,
+ new_extensions.clone(),
+ )
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ // The commit should carry the same ReInit proposal as one built directly.
+ let expected_reinit = group
+ .reinit_proposal(
+ Some(new_group_id),
+ new_version,
+ new_cipher_suite,
+ new_extensions,
+ )
+ .unwrap();
+
+ assert_commit_builder_output(group, commit_output, vec![expected_reinit], 0);
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_custom_proposal() {
+ let mut group = test_commit_builder_group().await;
+
+ // Arbitrary custom proposal: proposal type 42 with a two-byte payload.
+ let custom = CustomProposal::new(42.into(), vec![0, 1]);
+
+ // Unlike `add_member`, `custom_proposal` returns the builder directly.
+ let builder = group.commit_builder().custom_proposal(custom.clone());
+ let commit_output = builder.build().await.unwrap();
+
+ let expected = vec![Proposal::Custom(custom)];
+
+ assert_commit_builder_output(group, commit_output, expected, 0);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_chaining() {
+ let mut group = test_commit_builder_group().await;
+ let alice_kp = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
+ let bob_kp = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ // Expected proposals, in the same order the members are added below.
+ let expected_adds: Vec<_> = [alice_kp.clone(), bob_kp.clone()]
+ .into_iter().map(|kp| group.add_proposal(kp).unwrap()).collect();
+
+ // Chain two Adds on one builder; both should land in a single commit.
+ let commit_output = group
+ .commit_builder()
+ .add_member(alice_kp)
+ .unwrap()
+ .add_member(bob_kp)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ assert_commit_builder_output(group, commit_output, expected_adds, 2);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_empty_commit() {
+ // Building with nothing staged should still succeed: no proposals, no Welcome.
+ let mut group = test_commit_builder_group().await;
+ let builder = group.commit_builder();
+ let commit_output = builder.build().await.unwrap();
+ assert_commit_builder_output(group, commit_output, Vec::new(), 0);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_authenticated_data() {
+ let mut group = test_commit_builder_group().await;
+
+ // Arbitrary bytes to attach as the commit's authenticated data.
+ let aad = b"test".to_vec();
+
+ let commit_output = group
+ .commit_builder()
+ .authenticated_data(aad.clone())
+ .build()
+ .await
+ .unwrap();
+
+ // `into_plaintext` is unwrapped, so this relies on the commit being sent as a
+ // plaintext message in this configuration.
+ let plaintext = commit_output.commit_message.into_plaintext().unwrap();
+
+ // The supplied bytes must appear verbatim as the framed content's
+ // authenticated data.
+ assert_eq!(plaintext.content.authenticated_data, aad);
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_multiple_welcome_messages() {
+ // Turn off the single-Welcome option so each new member should get its own Welcome.
+ let mut group = test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |builder| {
+ let commit_options = CommitOptions::new().with_single_welcome_message(false);
+ builder.mls_rules(DefaultMlsRules::new().with_commit_options(commit_options))
+ })
+ .await;
+
+ let (alice, alice_kp) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "a").await;
+
+ let (bob, bob_kp) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "b").await;
+
+ // Propose both adds by reference, then commit them together.
+ for key_package in [&alice_kp, &bob_kp] {
+ group
+ .group
+ .propose_add(key_package.clone(), vec![])
+ .await
+ .unwrap();
+ }
+
+ let welcomes = group.group.commit(Vec::new()).await.unwrap().welcome_messages;
+
+ let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ for (client, kp) in [(alice, alice_kp), (bob, bob_kp)] {
+ // Locate the Welcome addressed to this member's key package.
+ let kp_ref = kp.key_package_reference(&cs).await.unwrap().unwrap();
+
+ let welcome = welcomes
+ .iter()
+ .find(|w| w.welcome_key_package_references().contains(&&kp_ref))
+ .unwrap();
+
+ client.join_group(None, welcome).await.unwrap();
+
+ // Each Welcome should carry the secrets of exactly one joiner.
+ let secret_count = welcome.clone().into_welcome().unwrap().secrets.len();
+ assert_eq!(secret_count, 1);
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_can_change_credential() {
+ let cs = TEST_CIPHER_SUITE;
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, cs, 3).await;
+
+ // The replacement identity, and the credential it is expected to carry.
+ let (identity, secret_key) = get_test_signing_identity(cs, b"member").await;
+ let expected_credential = get_test_basic_credential(b"member".to_vec());
+
+ // Member 0 commits with the new signing identity.
+ let commit_output = groups[0]
+ .group
+ .commit_builder()
+ .set_new_signing_identity(secret_key, identity.clone())
+ .build()
+ .await
+ .unwrap();
+
+ // Check that the credential was updated in the committer's state once its
+ // pending commit is processed.
+ groups[0].process_pending_commit().await.unwrap();
+ let committer_view = groups[0].group.roster().member_with_index(0).unwrap();
+
+ assert_eq!(committer_view.signing_identity.credential, expected_credential);
+
+ assert_eq!(
+ committer_view.signing_identity.signature_key,
+ identity.signature_key
+ );
+
+ // Check that the credential was updated in another member's state after it
+ // processes the commit message.
+ groups[1]
+ .process_message(commit_output.commit_message)
+ .await
+ .unwrap();
+
+ let receiver_view = groups[1].group.roster().member_with_index(0).unwrap();
+
+ assert_eq!(receiver_view.signing_identity.credential, expected_credential);
+
+ assert_eq!(
+ receiver_view.signing_identity.signature_key,
+ identity.signature_key
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_includes_tree_if_no_ratchet_tree_ext() {
+ // With the ratchet tree extension disabled, the commit output must carry the tree itself.
+ let options = CommitOptions::new().with_ratchet_tree_extension(false);
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(options),
+ )
+ .await
+ .group;
+
+ let commit = group.commit(vec![]).await.unwrap();
+ group.apply_pending_commit().await.unwrap();
+
+ // The shipped tree must match the committer's tree after applying the commit.
+ assert_eq!(group.export_tree(), commit.ratchet_tree.unwrap());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_does_not_include_tree_if_ratchet_tree_ext() {
+ // With the ratchet tree extension enabled, no separate tree should be returned.
+ let options = CommitOptions::new().with_ratchet_tree_extension(true);
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(options),
+ )
+ .await
+ .group;
+ let commit = group.commit(vec![]).await.unwrap();
+ assert!(commit.ratchet_tree.is_none());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_includes_external_commit_group_info_if_requested() {
+ // External commits allowed, ratchet tree extension off.
+ let options = CommitOptions::new()
+ .with_allow_external_commit(true)
+ .with_ratchet_tree_extension(false);
+
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(options),
+ )
+ .await
+ .group;
+
+ let commit = group.commit(vec![]).await.unwrap();
+
+ // A group info for external joiners must be returned alongside the commit.
+ let group_info = commit.external_commit_group_info.unwrap().into_group_info().unwrap();
+ let extensions = &group_info.extensions;
+
+ // It advertises the external public key but does not embed the tree.
+ assert!(!extensions.has_extension(ExtensionType::RATCHET_TREE));
+ assert!(extensions.has_extension(ExtensionType::EXTERNAL_PUB));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_includes_external_commit_and_tree_if_requested() {
+ // External commits allowed, ratchet tree extension on.
+ let options = CommitOptions::new()
+ .with_allow_external_commit(true)
+ .with_ratchet_tree_extension(true);
+
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(options),
+ )
+ .await
+ .group;
+
+ let commit = group.commit(vec![]).await.unwrap();
+
+ // A group info for external joiners must be returned alongside the commit.
+ let group_info = commit.external_commit_group_info.unwrap().into_group_info().unwrap();
+ let extensions = &group_info.extensions;
+
+ // With the extension enabled, the tree is embedded next to the external public key.
+ assert!(extensions.has_extension(ExtensionType::RATCHET_TREE));
+ assert!(extensions.has_extension(ExtensionType::EXTERNAL_PUB));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_does_not_include_external_commit_group_info_if_not_requested() {
+ // External commits are disallowed, so no group info for external joiners is expected.
+ let options = CommitOptions::new().with_allow_external_commit(false);
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(options),
+ )
+ .await
+ .group;
+ let commit = group.commit(vec![]).await.unwrap();
+ assert!(commit.external_commit_group_info.is_none());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn member_identity_is_validated_against_new_extensions() {
+ // Changing the group context extensions must validate member identities against
+ // the *new* extensions.
+ let alice_client = client_with_test_extension(b"alice").await;
+ let mut alice = alice_client.create_group(ExtensionList::new()).await.unwrap();
+
+ // `TestExtension { foo: b'a' }`: only identities starting with 'a' are accepted.
+ let mut a_only = ExtensionList::new();
+ a_only.set_from(TestExtension { foo: b'a' }).unwrap();
+
+ // Bob's identity does not start with 'a', so committing his addition together
+ // with the new extensions is expected to fail.
+ let bob = client_with_test_extension(b"bob").await;
+ let bob_kp = bob.generate_key_package_message().await.unwrap();
+
+ let rejected = alice
+ .commit_builder()
+ .add_member(bob_kp)
+ .unwrap()
+ .set_group_context_ext(a_only.clone())
+ .unwrap()
+ .build()
+ .await;
+
+ assert!(rejected.is_err());
+
+ // Alex does start with 'a', so the same change succeeds when adding alex.
+ let alex = client_with_test_extension(b"alex").await;
+ let alex_kp = alex.generate_key_package_message().await.unwrap();
+
+ let builder = alice.commit_builder().add_member(alex_kp).unwrap();
+
+ builder.set_group_context_ext(a_only).unwrap().build().await.unwrap();
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn server_identity_is_validated_against_new_extensions() {
+ let alice_client = client_with_test_extension(b"alice").await;
+ let mut alice = alice_client.create_group(ExtensionList::new()).await.unwrap();
+
+ // Base context extensions: `TestExtension { foo: b'a' }`.
+ let mut base_extensions = ExtensionList::new();
+ base_extensions.set_from(TestExtension { foo: b'a' }).unwrap();
+
+ // Base extensions plus an ExternalSenders extension allowing only `server`.
+ let extensions_with_sender = |server| {
+ let mut extensions = base_extensions.clone();
+
+ extensions
+ .set_from(ExternalSendersExt {
+ allowed_senders: vec![server],
+ })
+ .unwrap();
+
+ extensions
+ };
+
+ // The test identity provider rejects external senders whose identity starts
+ // with `foo`, so naming "alex" as a server must make the commit fail...
+ let (alex_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"alex").await;
+
+ let rejected = alice
+ .commit_builder()
+ .set_group_context_ext(extensions_with_sender(alex_server))
+ .unwrap()
+ .build()
+ .await;
+
+ assert!(rejected.is_err());
+
+ // ...while "bob" is accepted.
+ let (bob_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
+
+ alice
+ .commit_builder()
+ .set_group_context_ext(extensions_with_sender(bob_server))
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+ }
+
+ #[derive(Debug, Clone)]
+ struct IdentityProviderWithExtension(BasicIdentityProvider); // identity extraction and supported types delegate to the wrapped provider
+
+ #[derive(Clone, Debug)]
+ #[cfg_attr(feature = "std", derive(thiserror::Error))]
+ #[cfg_attr(feature = "std", error("test error"))]
+ struct IdentityProviderWithExtensionError {} // carries no detail; every failure of the test provider maps to this
+
+ impl IntoAnyError for IdentityProviderWithExtensionError {
+ #[cfg(feature = "std")] // `self.into()` relies on the std::error::Error impl derived above under "std"
+ fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
+ Ok(self.into())
+ }
+ }
+
+ impl IdentityProviderWithExtension {
+ // True if the identity starts with the byte `foo` from `TestExtension`, or if
+ // `TestExtension` (or the whole extension list) is not set. An empty identity never matches.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn starts_with_foo(
+ &self,
+ identity: &SigningIdentity,
+ _timestamp: Option<MlsTime>,
+ extensions: Option<&ExtensionList>,
+ ) -> bool {
+ if let Some(extensions) = extensions {
+ if let Some(ext) = extensions.get_as::<TestExtension>().unwrap() {
+ let id = self.identity(identity, extensions).await.unwrap();
+ return id.first() == Some(&ext.foo); // `first()`: an empty identity is a non-match, not an index panic
+ }
+ }
+
+ // No extension list, or no `TestExtension` in it: nothing to enforce.
+ true
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl IdentityProvider for IdentityProviderWithExtension {
+ type Error = IdentityProviderWithExtensionError;
+
+ async fn validate_member( // members are accepted exactly when `starts_with_foo` holds
+ &self,
+ identity: &SigningIdentity,
+ timestamp: Option<MlsTime>,
+ extensions: Option<&ExtensionList>,
+ ) -> Result<(), Self::Error> {
+ if self.starts_with_foo(identity, timestamp, extensions).await {
+ return Ok(());
+ }
+ Err(IdentityProviderWithExtensionError {})
+ }
+
+ // External senders follow the inverse rule: rejected when `starts_with_foo` holds.
+ async fn validate_external_sender(
+ &self,
+ identity: &SigningIdentity,
+ timestamp: Option<MlsTime>,
+ extensions: Option<&ExtensionList>,
+ ) -> Result<(), Self::Error> {
+ if self.starts_with_foo(identity, timestamp, extensions).await {
+ return Err(IdentityProviderWithExtensionError {});
+ }
+ Ok(())
+ }
+
+ async fn identity(
+ &self,
+ signing_identity: &SigningIdentity,
+ extensions: &ExtensionList,
+ ) -> Result<Vec<u8>, Self::Error> {
+ let result = self.0.identity(signing_identity, extensions).await;
+ result.map_err(|_| IdentityProviderWithExtensionError {}) // the inner error's detail is dropped
+ }
+
+ async fn valid_successor(
+ &self,
+ _predecessor: &SigningIdentity,
+ _successor: &SigningIdentity,
+ _extensions: &ExtensionList,
+ ) -> Result<bool, Self::Error> {
+ Ok(true) // any identity may replace any other
+ }
+
+ fn supported_types(&self) -> Vec<CredentialType> {
+ self.0.supported_types()
+ }
+ }
+
+ type ExtensionClientConfig = WithIdentityProvider<
+ IdentityProviderWithExtension,
+ WithCryptoProvider<TestCryptoProvider, BaseConfig>,
+ >;
+ /// Test client for `name` that supports `TEST_EXTENSION_TYPE` and uses `IdentityProviderWithExtension`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn client_with_test_extension(name: &[u8]) -> Client<ExtensionClientConfig> {
+ let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, name).await;
+ let identity_provider = IdentityProviderWithExtension(BasicIdentityProvider::new());
+ ClientBuilder::new() // step order presumably fixes the nested config type; keep it matching the alias
+ .crypto_provider(TestCryptoProvider::new())
+ .extension_types(vec![TEST_EXTENSION_TYPE.into()])
+ .identity_provider(identity_provider)
+ .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
+ .build()
+ }
+}
diff --git a/src/group/confirmation_tag.rs b/src/group/confirmation_tag.rs
new file mode 100644
index 0000000..409b382
--- /dev/null
+++ b/src/group/confirmation_tag.rs
@@ -0,0 +1,150 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::CipherSuiteProvider;
+use crate::{client::MlsError, group::transcript_hash::ConfirmedTranscriptHash};
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)] // NOTE(review): derived PartialEq is variable-time; prefer a constant-time compare when verifying tags
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct ConfirmationTag( // MAC output over the confirmed transcript hash (see `create`)
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ Vec<u8>,
+);
+
+impl Debug for ConfirmationTag {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let bytes = mls_rs_core::debug::pretty_bytes(&self.0);
+ bytes.named("ConfirmationTag").fmt(f)
+ }
+}
+
+// Exposes the raw tag bytes read-only.
+impl Deref for ConfirmationTag {
+ type Target = Vec<u8>;
+ fn deref(&self) -> &Vec<u8> {
+ &self.0
+ }
+}
+
+impl ConfirmationTag {
+ /// Computes the tag as the cipher suite's MAC of `confirmed_transcript_hash` under `confirmation_key`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn create<P: CipherSuiteProvider>(
+ confirmation_key: &[u8],
+ confirmed_transcript_hash: &ConfirmedTranscriptHash,
+ cipher_suite_provider: &P,
+ ) -> Result<Self, MlsError> {
+ cipher_suite_provider
+ .mac(confirmation_key, confirmed_transcript_hash)
+ .await
+ .map(ConfirmationTag)
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+
+ /// Recomputes the expected tag and reports whether `self` equals it.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn matches<P: CipherSuiteProvider>(
+ &self,
+ confirmation_key: &[u8],
+ confirmed_transcript_hash: &ConfirmedTranscriptHash,
+ cipher_suite_provider: &P,
+ ) -> Result<bool, MlsError> {
+ let tag = Self::create(confirmation_key, confirmed_transcript_hash, cipher_suite_provider).await?;
+
+ // Compare MAC bytes without a data-dependent early exit (best effort, not a formal
+ // constant-time guarantee). The expected length is public (fixed by the cipher suite).
+ let diff = tag.0.iter().zip(self.0.iter()).fold(0u8, |acc, (a, b)| acc | (a ^ b));
+ Ok(tag.0.len() == self.0.len() && diff == 0)
+ }
+}
+
+#[cfg(test)]
+impl ConfirmationTag {
+ /// Test-only tag: the MAC of empty input under an all-zero key of `kdf_extract_size` bytes.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn empty<P: CipherSuiteProvider>(cipher_suite_provider: &P) -> Self {
+ let zero_key = alloc::vec![0u8; cipher_suite_provider.kdf_extract_size()];
+
+ let mac = cipher_suite_provider
+ .mac(&zero_key, &[])
+ .await
+ .unwrap();
+
+ Self(mac)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider};
+
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+
+ // A tag created from (key_a, hash_a) must match only that exact pair, for every
+ // supported cipher suite. Altering either input, or the tag bytes themselves,
+ // must make `matches` return false.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_confirmation_tag_matching() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let key_a = b"bar_a".to_vec();
+ let key_b = b"bar_b".to_vec();
+
+ let hash_a = ConfirmedTranscriptHash::from(b"foo_a".to_vec());
+ let hash_b = ConfirmedTranscriptHash::from(b"foo_b".to_vec());
+
+ let tag = ConfirmationTag::create(&key_a, &hash_a, &provider)
+ .await
+ .unwrap();
+
+ // Same key and transcript hash: must match.
+ let same_inputs = tag.matches(&key_a, &hash_a, &provider).await.unwrap();
+ assert!(same_inputs);
+
+ // Different key: must not match.
+ let other_key = tag.matches(&key_b, &hash_a, &provider).await.unwrap();
+ assert!(!other_key);
+
+ // Different transcript hash: must not match.
+ let other_hash = tag.matches(&key_a, &hash_b, &provider).await.unwrap();
+ assert!(!other_hash);
+
+ // Both inputs different: must not match.
+ let both_differ = tag.matches(&key_b, &hash_b, &provider).await.unwrap();
+ assert!(!both_differ);
+
+ // A truncated copy of the correct tag (a length mismatch) must not match.
+ // Dropping the last byte assumes the MAC output is non-empty.
+ let truncated = ConfirmationTag(tag.0[..tag.0.len() - 1].to_vec());
+ let truncated_matches = truncated
+ .matches(&key_a, &hash_a, &provider)
+ .await
+ .unwrap();
+ assert!(!truncated_matches);
+
+ // Nor may a same-length tag that differs in a single bit.
+ let mut flipped_bytes = tag.0.clone();
+ flipped_bytes[0] ^= 1;
+
+ let flipped = ConfirmationTag(flipped_bytes);
+ let flipped_matches = flipped
+ .matches(&key_a, &hash_a, &provider)
+ .await
+ .unwrap();
+ assert!(!flipped_matches);
+ }
+ }
+}
diff --git a/src/group/context.rs b/src/group/context.rs
new file mode 100644
index 0000000..4ec23a9
--- /dev/null
+++ b/src/group/context.rs
@@ -0,0 +1,98 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+use crate::{cipher_suite::CipherSuite, protocol_version::ProtocolVersion, ExtensionList};
+
+use super::ConfirmedTranscriptHash;
+
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct GroupContext { // NOTE(review): with derive(MlsEncode), field order presumably fixes the wire encoding — do not reorder
+ pub(crate) protocol_version: ProtocolVersion,
+ pub(crate) cipher_suite: CipherSuite,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub(crate) group_id: Vec<u8>,
+ pub(crate) epoch: u64, // 0 for a newly created group (see `new_group`)
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub(crate) tree_hash: Vec<u8>,
+ pub(crate) confirmed_transcript_hash: ConfirmedTranscriptHash, // empty for a newly created group
+ pub(crate) extensions: ExtensionList,
+}
+
+impl Debug for GroupContext {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ // Byte fields are rendered with mls_rs_core's debug helpers; field order
+ // matches the struct declaration.
+ let group_id = mls_rs_core::debug::pretty_group_id(&self.group_id);
+ let tree_hash = mls_rs_core::debug::pretty_bytes(&self.tree_hash);
+
+ let mut out = f.debug_struct("GroupContext");
+ out.field("protocol_version", &self.protocol_version);
+ out.field("cipher_suite", &self.cipher_suite);
+ out.field("group_id", &group_id);
+ out.field("epoch", &self.epoch);
+ out.field("tree_hash", &tree_hash);
+ out.field("confirmed_transcript_hash", &self.confirmed_transcript_hash);
+ out.field("extensions", &self.extensions);
+
+ out.finish()
+ }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl GroupContext {
+ pub(crate) fn new_group( // context for a brand-new group: epoch 0 and an empty confirmed transcript hash
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ group_id: Vec<u8>,
+ tree_hash: Vec<u8>,
+ extensions: ExtensionList,
+ ) -> Self {
+ GroupContext {
+ protocol_version,
+ cipher_suite,
+ group_id,
+ epoch: 0,
+ tree_hash,
+ confirmed_transcript_hash: ConfirmedTranscriptHash::from(vec![]),
+ extensions,
+ }
+ }
+
+ /// Get the current protocol version in use by the group.
+ pub fn version(&self) -> ProtocolVersion {
+ self.protocol_version
+ }
+
+ /// Get the current cipher suite in use by the group.
+ pub fn cipher_suite(&self) -> CipherSuite {
+ self.cipher_suite
+ }
+
+ /// Get the unique identifier of this group.
+ pub fn group_id(&self) -> &[u8] {
+ &self.group_id
+ }
+
+ /// Get the current epoch number of the group's state.
+ pub fn epoch(&self) -> u64 {
+ self.epoch
+ }
+ /// Get the extensions carried by this group context.
+ pub fn extensions(&self) -> &ExtensionList {
+ &self.extensions
+ }
+}
diff --git a/src/group/epoch.rs b/src/group/epoch.rs
new file mode 100644
index 0000000..58352d6
--- /dev/null
+++ b/src/group/epoch.rs
@@ -0,0 +1,165 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+#[cfg(feature = "psk")]
+use crate::psk::PreSharedKey;
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+use crate::tree_kem::node::NodeIndex;
+#[cfg(feature = "prior_epoch")]
+use crate::{crypto::SignaturePublicKey, group::GroupContext, tree_kem::node::LeafIndex};
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use zeroize::Zeroizing;
+
+#[cfg(all(feature = "prior_epoch", feature = "private_message"))]
+use super::ciphertext_processor::GroupStateProvider;
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+use crate::group::secret_tree::SecretTree;
+
+#[cfg(feature = "prior_epoch")]
+#[derive(Debug, Clone, MlsEncode, MlsDecode, MlsSize, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct PriorEpoch { // snapshot of an earlier epoch: its context, our own leaf index, and its secrets
+ pub(crate) context: GroupContext,
+ pub(crate) self_index: LeafIndex,
+ pub(crate) secrets: EpochSecrets,
+ pub(crate) signature_public_keys: Vec<Option<SignaturePublicKey>>, // NOTE(review): presumably one entry per leaf, None for blank leaves — confirm
+}
+
+#[cfg(feature = "prior_epoch")]
+impl PriorEpoch { // accessors delegating to the stored group context's public getters
+ #[inline(always)]
+ pub(crate) fn epoch_id(&self) -> u64 {
+ self.context.epoch()
+ }
+
+ #[inline(always)]
+ pub(crate) fn group_id(&self) -> &[u8] {
+ self.context.group_id()
+ }
+}
+
+#[cfg(all(feature = "private_message", feature = "prior_epoch"))]
+impl GroupStateProvider for PriorEpoch { // lets CiphertextProcessor operate on a stored prior epoch
+ fn group_context(&self) -> &GroupContext {
+ &self.context
+ }
+
+ fn self_index(&self) -> LeafIndex {
+ self.self_index
+ }
+
+ fn epoch_secrets(&self) -> &EpochSecrets {
+ &self.secrets
+ }
+
+ fn epoch_secrets_mut(&mut self) -> &mut EpochSecrets {
+ &mut self.secrets
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)] // NOTE(review): Debug over secret material — confirm the field Debug impls don't print raw secrets
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct EpochSecrets { // secrets of a single epoch (also stored inside each PriorEpoch)
+ #[cfg(feature = "psk")]
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub(crate) resumption_secret: PreSharedKey,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub(crate) sender_data_secret: SenderDataSecret,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ pub(crate) secret_tree: SecretTree<NodeIndex>,
+}
+
+#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct SenderDataSecret( // Zeroizing wipes the underlying buffer when it is dropped
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ Zeroizing<Vec<u8>>,
+);
+
+impl Debug for SenderDataSecret {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let pretty = mls_rs_core::debug::pretty_bytes(&self.0);
+ pretty.named("SenderDataSecret").fmt(f)
+ }
+}
+
+impl AsRef<[u8]> for SenderDataSecret {
+ fn as_ref(&self) -> &[u8] {
+ self.0.as_slice()
+ }
+}
+
+impl Deref for SenderDataSecret {
+ type Target = Vec<u8>;
+
+ fn deref(&self) -> &Vec<u8> {
+ &self.0
+ }
+}
+
+// Takes ownership of plain bytes and wraps them so they are zeroed on drop.
+impl From<Vec<u8>> for SenderDataSecret {
+ fn from(bytes: Vec<u8>) -> Self {
+ Zeroizing::new(bytes).into()
+ }
+}
+
+impl From<Zeroizing<Vec<u8>>> for SenderDataSecret {
+ fn from(bytes: Zeroizing<Vec<u8>>) -> Self {
+ Self(bytes)
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use mls_rs_core::crypto::CipherSuiteProvider;
+
+ use super::*;
+ use crate::cipher_suite::CipherSuite;
+ use crate::crypto::test_utils::test_cipher_suite_provider;
+
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ use crate::group::secret_tree::test_utils::get_test_tree;
+ #[cfg(feature = "prior_epoch")]
+ use crate::group::test_utils::get_test_group_context_with_id;
+
+ use crate::group::test_utils::random_bytes;
+
+ /// Epoch secrets built from random `kdf_extract_size`-byte values for `cipher_suite`.
+ pub(crate) fn get_test_epoch_secrets(cipher_suite: CipherSuite) -> EpochSecrets {
+ let secret_len = test_cipher_suite_provider(cipher_suite).kdf_extract_size();
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ let secret_tree = get_test_tree(random_bytes(secret_len), 2);
+
+ EpochSecrets {
+ #[cfg(feature = "psk")]
+ resumption_secret: random_bytes(secret_len).into(),
+ sender_data_secret: random_bytes(secret_len).into(),
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree,
+ }
+ }
+
+ /// Prior epoch `id` of `group_id` with `self_index` 0 and no signature public keys.
+ #[cfg(feature = "prior_epoch")]
+ pub(crate) fn get_test_epoch_with_id(
+ group_id: Vec<u8>,
+ cipher_suite: CipherSuite,
+ id: u64,
+ ) -> PriorEpoch {
+ PriorEpoch {
+ context: get_test_group_context_with_id(group_id, id, cipher_suite),
+ self_index: LeafIndex(0),
+ secrets: get_test_epoch_secrets(cipher_suite),
+ signature_public_keys: Vec::new(),
+ }
+ }
+}
diff --git a/src/group/exported_tree.rs b/src/group/exported_tree.rs
new file mode 100644
index 0000000..acf507f
--- /dev/null
+++ b/src/group/exported_tree.rs
@@ -0,0 +1,51 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::{borrow::Cow, vec::Vec};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+use crate::{client::MlsError, tree_kem::node::NodeVec};
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Debug, MlsSize, MlsEncode, MlsDecode, PartialEq, Clone)]
+pub struct ExportedTree<'a>(pub(crate) Cow<'a, NodeVec>); // tree nodes, either borrowed (`new_borrowed`) or owned (`new`)
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl<'a> ExportedTree<'a> {
+ pub(crate) fn new(node_data: NodeVec) -> Self {
+ ExportedTree(Cow::Owned(node_data))
+ }
+
+ pub(crate) fn new_borrowed(node_data: &'a NodeVec) -> Self {
+ ExportedTree(Cow::Borrowed(node_data))
+ }
+ /// Serializes the tree with the MLS codec.
+ pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
+ self.mls_encode_to_vec().map_err(MlsError::from)
+ }
+ /// Size in bytes of the encoding that `to_bytes` produces.
+ pub fn byte_size(&self) -> usize {
+ self.mls_encoded_len()
+ }
+ /// Converts into an owned tree, cloning the nodes only if they were borrowed.
+ pub fn into_owned(self) -> ExportedTree<'static> {
+ ExportedTree(Cow::Owned(self.0.into_owned()))
+ }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl ExportedTree<'static> {
+ pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> { // decodes the MLS encoding produced by `to_bytes`
+ Self::mls_decode(&mut &*bytes).map_err(Into::into) // NOTE(review): input left over after the tree does not appear to be rejected — confirm intended
+ }
+}
+
+impl From<ExportedTree<'_>> for NodeVec { // moves owned nodes out; clones only when the tree was borrowed
+ fn from(tree: ExportedTree<'_>) -> Self {
+ tree.0.into_owned()
+ }
+}
diff --git a/src/group/external_commit.rs b/src/group/external_commit.rs
new file mode 100644
index 0000000..34b1042
--- /dev/null
+++ b/src/group/external_commit.rs
@@ -0,0 +1,266 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs_core::{crypto::SignatureSecretKey, identity::SigningIdentity};
+
+use crate::{
+ client_config::ClientConfig,
+ group::{
+ cipher_suite_provider,
+ epoch::SenderDataSecret,
+ key_schedule::{InitSecret, KeySchedule},
+ proposal::{ExternalInit, Proposal, RemoveProposal},
+ EpochSecrets, ExternalPubExt, LeafIndex, LeafNode, MlsError, TreeKemPrivate,
+ },
+ Group, MlsMessage,
+};
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+use crate::group::secret_tree::SecretTree;
+
+#[cfg(feature = "custom_proposal")]
+use crate::group::{
+ framing::MlsMessagePayload,
+ message_processor::{EventOrContent, MessageProcessor},
+ message_signature::AuthenticatedContent,
+ message_verifier::verify_plaintext_authentication,
+ CustomProposal,
+};
+
+use alloc::vec;
+use alloc::vec::Vec;
+
+#[cfg(feature = "psk")]
+use mls_rs_core::psk::{ExternalPskId, PreSharedKey};
+
+#[cfg(feature = "psk")]
+use crate::group::{
+ PreSharedKeyProposal, {JustPreSharedKeyID, PreSharedKeyID},
+};
+
+use super::{validate_group_info_joiner, ExportedTree};
+
+/// A builder that aids with the construction of an external commit.
+///
+/// Start from a signer, signing identity and client config, add optional
+/// settings through the `with_*` methods, then call
+/// [`ExternalCommitBuilder::build`] with a GroupInfo message.
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))]
+pub struct ExternalCommitBuilder<C: ClientConfig> {
+    // Used to generate the new leaf node and handed to the joined group.
+    signer: SignatureSecretKey,
+    // Identity placed in the newly generated leaf node.
+    signing_identity: SigningIdentity,
+    config: C,
+    // Tree to use when the GroupInfo does not carry a ratchet tree extension.
+    tree_data: Option<ExportedTree<'static>>,
+    // Leaf index of an old copy of this client to remove in the same commit.
+    to_remove: Option<u32>,
+    #[cfg(feature = "psk")]
+    external_psks: Vec<ExternalPskId>,
+    // Plaintext authenticated data attached to the commit message.
+    authenticated_data: Vec<u8>,
+    // Custom proposals included in the commit by value.
+    #[cfg(feature = "custom_proposal")]
+    custom_proposals: Vec<Proposal>,
+    // Custom proposals received from existing members; only consumed when
+    // `by_ref_proposal` is also enabled.
+    #[cfg(feature = "custom_proposal")]
+    received_custom_proposals: Vec<MlsMessage>,
+}
+
+impl<C: ClientConfig> ExternalCommitBuilder<C> {
+    /// Create a builder with no tree data, removal, authenticated data or
+    /// additional proposals.
+    pub(crate) fn new(
+        signer: SignatureSecretKey,
+        signing_identity: SigningIdentity,
+        config: C,
+    ) -> Self {
+        Self {
+            tree_data: None,
+            to_remove: None,
+            authenticated_data: Vec::new(),
+            signer,
+            signing_identity,
+            config,
+            #[cfg(feature = "psk")]
+            external_psks: Vec::new(),
+            #[cfg(feature = "custom_proposal")]
+            custom_proposals: Vec::new(),
+            #[cfg(feature = "custom_proposal")]
+            received_custom_proposals: Vec::new(),
+        }
+    }
+
+    #[must_use]
+    /// Use external tree data if the GroupInfo message does not contain a
+    /// [`RatchetTreeExt`](crate::extension::built_in::RatchetTreeExt).
+    ///
+    /// Calling this again replaces the previously supplied tree.
+    pub fn with_tree_data(self, tree_data: ExportedTree<'static>) -> Self {
+        Self {
+            tree_data: Some(tree_data),
+            ..self
+        }
+    }
+
+    #[must_use]
+    /// Propose the removal of an old version of the client, identified by its
+    /// leaf index, as part of the external commit.
+    ///
+    /// Only one such proposal is allowed; calling this again replaces the index.
+    pub fn with_removal(self, to_remove: u32) -> Self {
+        Self {
+            to_remove: Some(to_remove),
+            ..self
+        }
+    }
+
+    #[must_use]
+    /// Add plaintext authenticated data to the resulting commit message.
+    ///
+    /// Replaces any previously set authenticated data.
+    pub fn with_authenticated_data(self, data: Vec<u8>) -> Self {
+        Self {
+            authenticated_data: data,
+            ..self
+        }
+    }
+
+    #[cfg(feature = "psk")]
+    #[must_use]
+    /// Add an external psk to the group as part of the external commit.
+    ///
+    /// May be called repeatedly; each call adds one more PSK proposal.
+    pub fn with_external_psk(mut self, psk: ExternalPskId) -> Self {
+        self.external_psks.push(psk);
+        self
+    }
+
+    #[cfg(feature = "custom_proposal")]
+    #[must_use]
+    /// Insert a [`CustomProposal`] into the current commit that is being built.
+    ///
+    /// The proposal is committed by value.
+    pub fn with_custom_proposal(mut self, proposal: CustomProposal) -> Self {
+        self.custom_proposals.push(Proposal::Custom(proposal));
+        self
+    }
+
+    #[cfg(all(feature = "custom_proposal", feature = "by_ref_proposal"))]
+    #[must_use]
+    /// Insert a [`CustomProposal`] received from a current group member into the current
+    /// commit that is being built.
+    ///
+    /// # Warning
+    ///
+    /// The authenticity of the proposal is NOT fully verified. It is only verified the
+    /// same way as by [`ExternalGroup`](`crate::external_client::ExternalGroup`).
+    /// The proposal MUST be an MlsPlaintext, else the [`Self::build`] function will fail.
+    pub fn with_received_custom_proposal(mut self, proposal: MlsMessage) -> Self {
+        self.received_custom_proposals.push(proposal);
+        self
+    }
+
+    /// Build the external commit using a GroupInfo message provided by an existing group member.
+    ///
+    /// On success, returns the joined [`Group`] (with the external commit already
+    /// applied) together with the resulting commit message.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MlsError::UnsupportedProtocolVersion`] if the config does not support
+    /// the message's protocol version, [`MlsError::UnexpectedMessageType`] if
+    /// `group_info` is not a GroupInfo message (or a received custom proposal is not
+    /// a plaintext message), and [`MlsError::MissingExternalPubExtension`] if the
+    /// GroupInfo has no external public key. Errors from tree validation, leaf
+    /// generation and commit creation are propagated.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn build(self, group_info: MlsMessage) -> Result<(Group<C>, MlsMessage), MlsError> {
+        let protocol_version = group_info.version;
+
+        if !self.config.version_supported(protocol_version) {
+            return Err(MlsError::UnsupportedProtocolVersion(protocol_version));
+        }
+
+        let group_info = group_info
+            .into_group_info()
+            .ok_or(MlsError::UnexpectedMessageType)?;
+
+        let cipher_suite = cipher_suite_provider(
+            self.config.crypto_provider(),
+            group_info.group_context.cipher_suite,
+        )?;
+
+        // The external public key is what allows a non-member to derive the
+        // init secret for the next epoch.
+        let external_pub_ext = group_info
+            .extensions
+            .get_as::<ExternalPubExt>()?
+            .ok_or(MlsError::MissingExternalPubExtension)?;
+
+        // `self.tree_data` is used when the GroupInfo carries no ratchet tree.
+        let public_tree = validate_group_info_joiner(
+            protocol_version,
+            &group_info,
+            self.tree_data,
+            &self.config.identity_provider(),
+            &cipher_suite,
+        )
+        .await?;
+
+        let (leaf_node, _) = LeafNode::generate(
+            &cipher_suite,
+            self.config.leaf_properties(),
+            self.signing_identity,
+            &self.signer,
+            self.config.lifetime(),
+        )
+        .await?;
+
+        // `kem_output` is sent in the ExternalInit proposal below.
+        let (init_secret, kem_output) =
+            InitSecret::encode_for_external(&cipher_suite, &external_pub_ext.external_pub).await?;
+
+        // The joiner has no secrets for the current epoch, so the group is joined
+        // with empty placeholders; the commit applied below moves it to a new epoch.
+        let epoch_secrets = EpochSecrets {
+            #[cfg(feature = "psk")]
+            resumption_secret: PreSharedKey::new(vec![]),
+            sender_data_secret: SenderDataSecret::from(vec![]),
+            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+            secret_tree: SecretTree::empty(),
+        };
+
+        let (mut group, _) = Group::join_with(
+            self.config,
+            group_info,
+            public_tree,
+            KeySchedule::new(init_secret),
+            epoch_secrets,
+            TreeKemPrivate::new_for_external(),
+            None,
+            self.signer,
+        )
+        .await?;
+
+        #[cfg(feature = "psk")]
+        let psk_ids = self
+            .external_psks
+            .into_iter()
+            .map(|psk_id| PreSharedKeyID::new(JustPreSharedKeyID::External(psk_id), &cipher_suite))
+            .collect::<Result<Vec<_>, MlsError>>()?;
+
+        // The ExternalInit proposal always comes first.
+        let mut proposals = vec![Proposal::ExternalInit(ExternalInit { kem_output })];
+
+        #[cfg(feature = "psk")]
+        proposals.extend(
+            psk_ids
+                .into_iter()
+                .map(|psk| Proposal::Psk(PreSharedKeyProposal { psk })),
+        );
+
+        #[cfg(feature = "custom_proposal")]
+        {
+            let mut custom_proposals = self.custom_proposals;
+            proposals.append(&mut custom_proposals);
+        }
+
+        // Received proposals are authenticated against the group's current state
+        // and then passed through the group's regular message processing
+        // (presumably caching them so the commit below references them — TODO confirm).
+        #[cfg(all(feature = "custom_proposal", feature = "by_ref_proposal"))]
+        for message in self.received_custom_proposals {
+            let MlsMessagePayload::Plain(plaintext) = message.payload else {
+                return Err(MlsError::UnexpectedMessageType);
+            };
+
+            let auth_content = AuthenticatedContent::from(plaintext.clone());
+
+            verify_plaintext_authentication(&cipher_suite, plaintext, None, None, &group.state)
+                .await?;
+
+            group
+                .process_event_or_content(EventOrContent::Content(auth_content), true, None)
+                .await?;
+        }
+
+        if let Some(r) = self.to_remove {
+            proposals.push(Proposal::Remove(RemoveProposal {
+                to_remove: LeafIndex(r),
+            }));
+        }
+
+        let commit_output = group
+            .commit_internal(
+                proposals,
+                Some(&leaf_node),
+                self.authenticated_data,
+                Default::default(),
+                None,
+                None,
+            )
+            .await?;
+
+        // The external joiner moves to the new epoch immediately.
+        group.apply_pending_commit().await?;
+
+        Ok((group, commit_output.commit_message))
+    }
+}
diff --git a/src/group/framing.rs b/src/group/framing.rs
new file mode 100644
index 0000000..8663b96
--- /dev/null
+++ b/src/group/framing.rs
@@ -0,0 +1,741 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use core::ops::Deref;
+
+use crate::{client::MlsError, tree_kem::node::LeafIndex, KeyPackage, KeyPackageRef};
+
+use super::{Commit, FramedContentAuthData, GroupInfo, MembershipTag, Welcome};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::{group::Proposal, mls_rules::ProposalRef};
+
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{
+ crypto::{CipherSuite, CipherSuiteProvider},
+ protocol_version::ProtocolVersion,
+};
+use zeroize::ZeroizeOnDrop;
+
+#[cfg(feature = "private_message")]
+use alloc::boxed::Box;
+
+#[cfg(feature = "custom_proposal")]
+use crate::group::proposal::{CustomProposal, ProposalOrRef};
+
+/// Kind of content carried by a framed MLS message.
+///
+/// Variants are feature-gated to match the [`Content`] variants they describe.
+#[derive(Copy, Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[repr(u8)]
+pub enum ContentType {
+    #[cfg(feature = "private_message")]
+    Application = 1u8,
+    #[cfg(feature = "by_ref_proposal")]
+    Proposal = 2u8,
+    Commit = 3u8,
+}
+
+impl From<&Content> for ContentType {
+    /// Map each content variant to its matching content type.
+    fn from(content: &Content) -> Self {
+        match *content {
+            #[cfg(feature = "private_message")]
+            Content::Application(..) => Self::Application,
+            #[cfg(feature = "by_ref_proposal")]
+            Content::Proposal(..) => Self::Proposal,
+            Content::Commit(..) => Self::Commit,
+        }
+    }
+}
+
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+#[non_exhaustive]
+/// Description of a [`MlsMessage`] sender.
+pub enum Sender {
+    /// A current group member, identified by leaf index.
+    Member(u32) = 1u8,
+    /// An external entity sending a proposal, identified by its index in the
+    /// [`ExternalSendersExt`](crate::extension::ExternalSendersExt) stored in
+    /// the group context extensions.
+    #[cfg(feature = "by_ref_proposal")]
+    External(u32) = 2u8,
+    /// A new member proposing their own addition to the group.
+    #[cfg(feature = "by_ref_proposal")]
+    NewMemberProposal = 3u8,
+    /// A new member joining the group by sending an external commit.
+    NewMemberCommit = 4u8,
+}
+
+impl From<LeafIndex> for Sender {
+    /// A leaf index always denotes a current group member.
+    fn from(leaf_index: LeafIndex) -> Self {
+        let index: u32 = *leaf_index;
+        Self::Member(index)
+    }
+}
+
+impl From<u32> for Sender {
+    /// Interpret a raw leaf index as a group member sender.
+    fn from(leaf_index: u32) -> Self {
+        Self::Member(leaf_index)
+    }
+}
+
+/// Application payload carried by a private message.
+///
+/// The bytes are zeroized when the value is dropped (`ZeroizeOnDrop`) and are
+/// redacted to a pretty-printed byte summary by the `Debug` implementation.
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, ZeroizeOnDrop)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct ApplicationData(
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+    Vec<u8>,
+);
+
+impl Debug for ApplicationData {
+    /// Render the payload through the crate's named pretty-bytes formatter.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let bytes: &[u8] = &self.0;
+        mls_rs_core::debug::pretty_bytes(bytes)
+            .named("ApplicationData")
+            .fmt(f)
+    }
+}
+
+impl From<Vec<u8>> for ApplicationData {
+    /// Take ownership of the bytes as an application payload.
+    fn from(data: Vec<u8>) -> Self {
+        ApplicationData(data)
+    }
+}
+
+impl Deref for ApplicationData {
+    type Target = [u8];
+
+    /// Expose the payload as a byte slice.
+    fn deref(&self) -> &[u8] {
+        self.0.as_slice()
+    }
+}
+
+impl ApplicationData {
+    /// Borrow the raw application payload.
+    pub fn as_bytes(&self) -> &[u8] {
+        &self.0[..]
+    }
+}
+
+/// The payload of a framed message: application data, a proposal or a commit.
+///
+/// Proposals and commits are boxed; each variant is gated by the same feature
+/// as its [`ContentType`] counterpart.
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+pub(crate) enum Content {
+    #[cfg(feature = "private_message")]
+    Application(ApplicationData) = 1u8,
+    #[cfg(feature = "by_ref_proposal")]
+    Proposal(alloc::boxed::Box<Proposal>) = 2u8,
+    Commit(alloc::boxed::Box<Commit>) = 3u8,
+}
+
+impl Content {
+    /// The [`ContentType`] tag corresponding to this content.
+    pub fn content_type(&self) -> ContentType {
+        ContentType::from(self)
+    }
+}
+
+/// A plaintext framed message: content, its authentication data, and an
+/// optional membership tag.
+///
+/// The codec is implemented by hand: on decode, the membership tag is read
+/// only when the sender is `Sender::Member`.
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub(crate) struct PublicMessage {
+    pub content: FramedContent,
+    pub auth: FramedContentAuthData,
+    pub membership_tag: Option<MembershipTag>,
+}
+
+impl MlsSize for PublicMessage {
+    /// Encoded length of content + auth data, plus the membership tag when present.
+    fn mls_encoded_len(&self) -> usize {
+        let tag_len = match &self.membership_tag {
+            Some(tag) => tag.mls_encoded_len(),
+            None => 0,
+        };
+
+        self.content.mls_encoded_len() + self.auth.mls_encoded_len() + tag_len
+    }
+}
+
+impl MlsEncode for PublicMessage {
+    /// Write content, auth data and, if present, the membership tag.
+    ///
+    /// NOTE(review): the decoder reads a tag exactly when the sender is a member,
+    /// so callers must keep `membership_tag` consistent with the sender.
+    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+        self.content.mls_encode(writer)?;
+        self.auth.mls_encode(writer)?;
+
+        if let Some(tag) = &self.membership_tag {
+            tag.mls_encode(writer)?;
+        }
+
+        Ok(())
+    }
+}
+
+impl MlsDecode for PublicMessage {
+    /// Read content, then auth data (decoded according to the content type), then a
+    /// membership tag only if the sender is a group member.
+    fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
+        let content = FramedContent::mls_decode(reader)?;
+        let auth = FramedContentAuthData::mls_decode(reader, content.content_type())?;
+
+        let membership_tag = if let Sender::Member(_) = content.sender {
+            Some(MembershipTag::mls_decode(reader)?)
+        } else {
+            None
+        };
+
+        Ok(Self {
+            content,
+            auth,
+            membership_tag,
+        })
+    }
+}
+
+/// The plaintext inside a [`PrivateMessage`]: content plus authentication data.
+///
+/// The content is encoded without its type tag, so decoding requires the content
+/// type from the enclosing message (see `PrivateMessageContent::mls_decode`).
+#[cfg(feature = "private_message")]
+#[derive(Clone, Debug, PartialEq)]
+pub(crate) struct PrivateMessageContent {
+    pub content: Content,
+    pub auth: FramedContentAuthData,
+}
+
+#[cfg(feature = "private_message")]
+impl MlsSize for PrivateMessageContent {
+    /// Length of the inner content value (no type tag) plus the auth data.
+    fn mls_encoded_len(&self) -> usize {
+        let body_len = match &self.content {
+            Content::Application(data) => data.mls_encoded_len(),
+            #[cfg(feature = "by_ref_proposal")]
+            Content::Proposal(proposal) => proposal.mls_encoded_len(),
+            Content::Commit(commit) => commit.mls_encoded_len(),
+        };
+        let auth_len = self.auth.mls_encoded_len();
+
+        body_len + auth_len
+    }
+}
+
+#[cfg(feature = "private_message")]
+impl MlsEncode for PrivateMessageContent {
+    /// Write the inner content value without its type tag, followed by the auth data.
+    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+        match &self.content {
+            Content::Application(data) => data.mls_encode(writer)?,
+            #[cfg(feature = "by_ref_proposal")]
+            Content::Proposal(proposal) => proposal.mls_encode(writer)?,
+            Content::Commit(commit) => commit.mls_encode(writer)?,
+        }
+
+        self.auth.mls_encode(writer)?;
+        Ok(())
+    }
+}
+
+#[cfg(feature = "private_message")]
+impl PrivateMessageContent {
+    /// Decode decrypted private-message content.
+    ///
+    /// The content type is not part of the encoded content, so the caller supplies
+    /// it (e.g. from the enclosing [`PrivateMessage`]). Every byte left in `reader`
+    /// after the authentication data is padding and must be zero. The padding is
+    /// inspected but not consumed.
+    ///
+    /// # Errors
+    ///
+    /// Propagates codec errors from decoding the content or authentication data,
+    /// and returns `mls_rs_codec::Error::Custom(5)` if any padding byte is non-zero.
+    pub(crate) fn mls_decode(
+        reader: &mut &[u8],
+        content_type: ContentType,
+    ) -> Result<Self, mls_rs_codec::Error> {
+        let content = match content_type {
+            ContentType::Application => Content::Application(ApplicationData::mls_decode(reader)?),
+            #[cfg(feature = "by_ref_proposal")]
+            ContentType::Proposal => Content::Proposal(Box::new(Proposal::mls_decode(reader)?)),
+            ContentType::Commit => Content::Commit(Box::new(Commit::mls_decode(reader)?)),
+        };
+
+        let auth = FramedContentAuthData::mls_decode(reader, content.content_type())?;
+
+        // Padding must be all zeros. `Error::Custom` carries a numeric code rather
+        // than a message, so 5 identifies "non-zero padding bytes".
+        if reader.iter().any(|&byte| byte != 0u8) {
+            return Err(mls_rs_codec::Error::Custom(5));
+        }
+
+        Ok(Self { content, auth })
+    }
+}
+
+/// Additional authenticated data for encrypting a [`PrivateMessage`]'s content:
+/// the message header fields that travel unencrypted.
+///
+/// Built from a message via `From<&PrivateMessage>`.
+#[cfg(feature = "private_message")]
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+pub struct PrivateContentAAD {
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    pub group_id: Vec<u8>,
+    pub epoch: u64,
+    pub content_type: ContentType,
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    pub authenticated_data: Vec<u8>,
+}
+
+#[cfg(feature = "private_message")]
+impl Debug for PrivateContentAAD {
+    /// Byte fields are shown through the crate's pretty-printers.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let group_id = mls_rs_core::debug::pretty_group_id(&self.group_id);
+        let authenticated_data = mls_rs_core::debug::pretty_bytes(&self.authenticated_data);
+
+        f.debug_struct("PrivateContentAAD")
+            .field("group_id", &group_id)
+            .field("epoch", &self.epoch)
+            .field("content_type", &self.content_type)
+            .field("authenticated_data", &authenticated_data)
+            .finish()
+    }
+}
+
+/// An encrypted framed message.
+///
+/// `group_id`, `epoch`, `content_type` and `authenticated_data` are sent in the
+/// clear; the sender data and the content are carried as encrypted byte strings.
+#[cfg(feature = "private_message")]
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub struct PrivateMessage {
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    pub group_id: Vec<u8>,
+    pub epoch: u64,
+    pub content_type: ContentType,
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    pub authenticated_data: Vec<u8>,
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    pub encrypted_sender_data: Vec<u8>,
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    pub ciphertext: Vec<u8>,
+}
+
+#[cfg(feature = "private_message")]
+impl Debug for PrivateMessage {
+    /// Byte fields are shown through the crate's pretty-printers.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let group_id = mls_rs_core::debug::pretty_group_id(&self.group_id);
+        let authenticated_data = mls_rs_core::debug::pretty_bytes(&self.authenticated_data);
+        let encrypted_sender_data =
+            mls_rs_core::debug::pretty_bytes(&self.encrypted_sender_data);
+        let ciphertext = mls_rs_core::debug::pretty_bytes(&self.ciphertext);
+
+        f.debug_struct("PrivateMessage")
+            .field("group_id", &group_id)
+            .field("epoch", &self.epoch)
+            .field("content_type", &self.content_type)
+            .field("authenticated_data", &authenticated_data)
+            .field("encrypted_sender_data", &encrypted_sender_data)
+            .field("ciphertext", &ciphertext)
+            .finish()
+    }
+}
+
+#[cfg(feature = "private_message")]
+impl From<&PrivateMessage> for PrivateContentAAD {
+    /// Copy the cleartext header fields of a private message.
+    fn from(ciphertext: &PrivateMessage) -> Self {
+        let PrivateMessage {
+            group_id,
+            epoch,
+            content_type,
+            authenticated_data,
+            ..
+        } = ciphertext;
+
+        Self {
+            group_id: group_id.clone(),
+            epoch: *epoch,
+            content_type: *content_type,
+            authenticated_data: authenticated_data.clone(),
+        }
+    }
+}
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    ::safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+/// An MLS protocol message for sending data over the wire.
+///
+/// Pairs a protocol version with one payload (public or private message,
+/// welcome, group info or key package). Use [`MlsMessage::from_bytes`] and
+/// [`MlsMessage::to_bytes`] to convert to and from transport bytes.
+pub struct MlsMessage {
+    pub(crate) version: ProtocolVersion,
+    pub(crate) payload: MlsMessagePayload,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+// NOTE(review): presumably allowed because some crate-internal accessors are
+// unused under certain feature combinations.
+#[allow(dead_code)]
+impl MlsMessage {
+    /// Wrap `payload` with the protocol version it is sent under.
+    pub(crate) fn new(version: ProtocolVersion, payload: MlsMessagePayload) -> MlsMessage {
+        Self { version, payload }
+    }
+
+    /// Consume the message, returning the public (plaintext) message if it is one.
+    #[inline(always)]
+    pub(crate) fn into_plaintext(self) -> Option<PublicMessage> {
+        match self.payload {
+            MlsMessagePayload::Plain(plaintext) => Some(plaintext),
+            _ => None,
+        }
+    }
+
+    /// Consume the message, returning the private (encrypted) message if it is one.
+    #[cfg(feature = "private_message")]
+    #[inline(always)]
+    pub(crate) fn into_ciphertext(self) -> Option<PrivateMessage> {
+        match self.payload {
+            MlsMessagePayload::Cipher(ciphertext) => Some(ciphertext),
+            _ => None,
+        }
+    }
+
+    /// Consume the message, returning the welcome if it is one.
+    #[inline(always)]
+    pub(crate) fn into_welcome(self) -> Option<Welcome> {
+        match self.payload {
+            MlsMessagePayload::Welcome(welcome) => Some(welcome),
+            _ => None,
+        }
+    }
+
+    /// Consume the message, returning the [`GroupInfo`] if it is one.
+    #[inline(always)]
+    pub fn into_group_info(self) -> Option<GroupInfo> {
+        match self.payload {
+            MlsMessagePayload::GroupInfo(info) => Some(info),
+            _ => None,
+        }
+    }
+
+    /// Borrow the [`GroupInfo`] if this message is one.
+    #[inline(always)]
+    pub fn as_group_info(&self) -> Option<&GroupInfo> {
+        match &self.payload {
+            MlsMessagePayload::GroupInfo(info) => Some(info),
+            _ => None,
+        }
+    }
+
+    /// Consume the message, returning the [`KeyPackage`] if it is one.
+    #[inline(always)]
+    pub fn into_key_package(self) -> Option<KeyPackage> {
+        match self.payload {
+            MlsMessagePayload::KeyPackage(kp) => Some(kp),
+            _ => None,
+        }
+    }
+
+    /// The wire format value describing the contents of this message.
+    pub fn wire_format(&self) -> WireFormat {
+        match self.payload {
+            MlsMessagePayload::Plain(_) => WireFormat::PublicMessage,
+            #[cfg(feature = "private_message")]
+            MlsMessagePayload::Cipher(_) => WireFormat::PrivateMessage,
+            MlsMessagePayload::Welcome(_) => WireFormat::Welcome,
+            MlsMessagePayload::GroupInfo(_) => WireFormat::GroupInfo,
+            MlsMessagePayload::KeyPackage(_) => WireFormat::KeyPackage,
+        }
+    }
+
+    /// The epoch that this message belongs to.
+    ///
+    /// Returns `None` if the message is [`WireFormat::KeyPackage`]
+    /// or [`WireFormat::Welcome`]
+    #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn epoch(&self) -> Option<u64> {
+        match &self.payload {
+            MlsMessagePayload::Plain(p) => Some(p.content.epoch),
+            #[cfg(feature = "private_message")]
+            MlsMessagePayload::Cipher(c) => Some(c.epoch),
+            MlsMessagePayload::GroupInfo(gi) => Some(gi.group_context.epoch),
+            _ => None,
+        }
+    }
+
+    /// The cipher suite named by this message.
+    ///
+    /// Only group info, welcome and key package messages carry one; public and
+    /// private messages return `None`.
+    #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn cipher_suite(&self) -> Option<CipherSuite> {
+        match &self.payload {
+            MlsMessagePayload::GroupInfo(i) => Some(i.group_context.cipher_suite),
+            MlsMessagePayload::Welcome(w) => Some(w.cipher_suite),
+            MlsMessagePayload::KeyPackage(k) => Some(k.cipher_suite),
+            _ => None,
+        }
+    }
+
+    /// The ID of the group this message belongs to.
+    ///
+    /// Returns `None` for key package and welcome messages.
+    pub fn group_id(&self) -> Option<&[u8]> {
+        match &self.payload {
+            MlsMessagePayload::Plain(p) => Some(&p.content.group_id),
+            #[cfg(feature = "private_message")]
+            MlsMessagePayload::Cipher(p) => Some(&p.group_id),
+            MlsMessagePayload::GroupInfo(p) => Some(&p.group_context.group_id),
+            MlsMessagePayload::KeyPackage(_) | MlsMessagePayload::Welcome(_) => None,
+        }
+    }
+
+    /// Deserialize a message from transport.
+    #[inline(never)]
+    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
+        Self::mls_decode(&mut &*bytes).map_err(Into::into)
+    }
+
+    /// Serialize a message for transport.
+    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
+        self.mls_encode_to_vec().map_err(Into::into)
+    }
+
+    /// If this is a plaintext commit message, return all custom proposals committed by value.
+    /// If this is not a plaintext or not a commit, this returns an empty list.
+    #[cfg(feature = "custom_proposal")]
+    pub fn custom_proposals_by_value(&self) -> Vec<&CustomProposal> {
+        match &self.payload {
+            MlsMessagePayload::Plain(plaintext) => match &plaintext.content.content {
+                Content::Commit(commit) => Self::find_custom_proposals(commit),
+                _ => Vec::new(),
+            },
+            _ => Vec::new(),
+        }
+    }
+
+    /// If this is a welcome message, return key package references of all members who can
+    /// join using this message. Any other message yields an empty list.
+    pub fn welcome_key_package_references(&self) -> Vec<&KeyPackageRef> {
+        let MlsMessagePayload::Welcome(welcome) = &self.payload else {
+            return Vec::new();
+        };
+
+        welcome.secrets.iter().map(|s| &s.new_member).collect()
+    }
+
+    /// If this is a key package, return its key package reference; otherwise `Ok(None)`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn key_package_reference<C: CipherSuiteProvider>(
+        &self,
+        cipher_suite: &C,
+    ) -> Result<Option<KeyPackageRef>, MlsError> {
+        let MlsMessagePayload::KeyPackage(kp) = &self.payload else {
+            return Ok(None);
+        };
+
+        kp.to_reference(cipher_suite).await.map(Some)
+    }
+
+    /// If this is a plaintext proposal, return the proposal reference that can be matched e.g. with
+    /// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals).
+    ///
+    /// Returns `Ok(None)` for every other message, including plaintext commits.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn into_proposal_reference<C: CipherSuiteProvider>(
+        self,
+        cipher_suite: &C,
+    ) -> Result<Option<Vec<u8>>, MlsError> {
+        let MlsMessagePayload::Plain(public_message) = self.payload else {
+            return Ok(None);
+        };
+
+        // A reference computed over non-proposal content (e.g. a commit) would not
+        // match any proposal, so do not produce one.
+        if !matches!(public_message.content.content, Content::Proposal(_)) {
+            return Ok(None);
+        }
+
+        ProposalRef::from_content(cipher_suite, &public_message.into())
+            .await
+            .map(|r| Some(r.to_vec()))
+    }
+}
+
+#[cfg(feature = "custom_proposal")]
+impl MlsMessage {
+    /// Collect the custom proposals a commit includes by value, in commit order.
+    /// Proposals included by reference are skipped.
+    fn find_custom_proposals(commit: &Commit) -> Vec<&CustomProposal> {
+        let mut found = Vec::new();
+
+        for item in commit.proposals.iter() {
+            if let ProposalOrRef::Proposal(proposal) = item {
+                if let crate::group::Proposal::Custom(custom) = proposal.as_ref() {
+                    found.push(custom);
+                }
+            }
+        }
+
+        found
+    }
+}
+
+// The variants differ considerably in size; they are deliberately left unboxed.
+#[allow(clippy::large_enum_variant)]
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[repr(u16)]
+/// Payload of an [`MlsMessage`]. The discriminants mirror the [`WireFormat`] values.
+pub(crate) enum MlsMessagePayload {
+    Plain(PublicMessage) = 1u16,
+    #[cfg(feature = "private_message")]
+    Cipher(PrivateMessage) = 2u16,
+    Welcome(Welcome) = 3u16,
+    GroupInfo(GroupInfo) = 4u16,
+    KeyPackage(KeyPackage) = 5u16,
+}
+
+impl From<PublicMessage> for MlsMessagePayload {
+    /// Wrap a public message as a plaintext payload.
+    fn from(m: PublicMessage) -> Self {
+        MlsMessagePayload::Plain(m)
+    }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
+#[derive(
+    Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, MlsSize, MlsEncode, MlsDecode,
+)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u16)]
+#[non_exhaustive]
+/// Content description of an [`MlsMessage`].
+///
+/// Unlike the message payload, `PrivateMessage` is defined regardless of the
+/// `private_message` feature.
+pub enum WireFormat {
+    /// A public (plaintext) group message.
+    PublicMessage = 1u16,
+    /// A private (encrypted) group message.
+    PrivateMessage = 2u16,
+    /// A welcome message.
+    Welcome = 3u16,
+    /// A group info message.
+    GroupInfo = 4u16,
+    /// A key package.
+    KeyPackage = 5u16,
+}
+
+/// Message content together with its framing: group, epoch, sender and
+/// plaintext authenticated data.
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct FramedContent {
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+    pub group_id: Vec<u8>,
+    pub epoch: u64,
+    pub sender: Sender,
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+    pub authenticated_data: Vec<u8>,
+    pub content: Content,
+}
+
+impl Debug for FramedContent {
+    /// Byte fields are shown through the crate's pretty-printers.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let group_id = mls_rs_core::debug::pretty_group_id(&self.group_id);
+        let authenticated_data = mls_rs_core::debug::pretty_bytes(&self.authenticated_data);
+
+        f.debug_struct("FramedContent")
+            .field("group_id", &group_id)
+            .field("epoch", &self.epoch)
+            .field("sender", &self.sender)
+            .field("authenticated_data", &authenticated_data)
+            .field("content", &self.content)
+            .finish()
+    }
+}
+
+impl FramedContent {
+    /// The [`ContentType`] of the framed content.
+    pub fn content_type(&self) -> ContentType {
+        ContentType::from(&self.content)
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    #[cfg(feature = "private_message")]
+    use crate::group::test_utils::random_bytes;
+
+    use crate::group::{AuthenticatedContent, MessageSignature};
+
+    use super::*;
+
+    use alloc::boxed::Box;
+
+    /// Public-message content holding an empty commit from member 1 at epoch 0,
+    /// with an empty signature and no confirmation tag.
+    pub(crate) fn get_test_auth_content() -> AuthenticatedContent {
+        // This is not a valid commit and should not be validated
+        let empty_commit = Box::new(Commit {
+            proposals: Default::default(),
+            path: None,
+        });
+
+        let content = FramedContent {
+            group_id: Vec::new(),
+            epoch: 0,
+            sender: Sender::Member(1),
+            authenticated_data: Vec::new(),
+            content: Content::Commit(empty_commit),
+        };
+
+        let auth = FramedContentAuthData {
+            signature: MessageSignature::empty(),
+            confirmation_tag: None,
+        };
+
+        AuthenticatedContent {
+            wire_format: WireFormat::PublicMessage,
+            content,
+            auth,
+        }
+    }
+
+    /// Private-message content with 1024 random application bytes and a
+    /// 128-byte random signature.
+    #[cfg(feature = "private_message")]
+    pub(crate) fn get_test_ciphertext_content() -> PrivateMessageContent {
+        let application_data: ApplicationData = random_bytes(1024).into();
+        let signature = MessageSignature::from(random_bytes(128));
+
+        PrivateMessageContent {
+            content: Content::Application(application_data),
+            auth: FramedContentAuthData {
+                signature,
+                confirmation_tag: None,
+            },
+        }
+    }
+
+    impl AsRef<[u8]> for ApplicationData {
+        fn as_ref(&self) -> &[u8] {
+            self.0.as_slice()
+        }
+    }
+}
+
+#[cfg(feature = "private_message")]
+#[cfg(test)]
+mod tests {
+    use assert_matches::assert_matches;
+
+    use crate::{
+        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+        crypto::test_utils::test_cipher_suite_provider,
+        group::{
+            framing::test_utils::get_test_ciphertext_content,
+            proposal_ref::test_utils::auth_content_from_proposal, RemoveProposal,
+        },
+    };
+
+    use super::*;
+
+    /// Encodes `content`, appends 128 copies of `padding_byte`, and decodes the
+    /// result using the content type of `content`.
+    fn decode_with_padding(
+        content: &PrivateMessageContent,
+        padding_byte: u8,
+    ) -> Result<PrivateMessageContent, mls_rs_codec::Error> {
+        let mut encoded = content.mls_encode_to_vec().unwrap();
+        encoded.extend_from_slice(&[padding_byte; 128]);
+
+        PrivateMessageContent::mls_decode(&mut &*encoded, (&content.content).into())
+    }
+
+    #[test]
+    fn test_mls_ciphertext_content_mls_encoding() {
+        let ciphertext_content = get_test_ciphertext_content();
+        let decoded = decode_with_padding(&ciphertext_content, 0u8).unwrap();
+
+        assert_eq!(ciphertext_content, decoded);
+    }
+
+    #[test]
+    fn test_mls_ciphertext_content_non_zero_padding_error() {
+        let ciphertext_content = get_test_ciphertext_content();
+        let decoded = decode_with_padding(&ciphertext_content, 1u8);
+
+        assert_matches!(decoded, Err(mls_rs_codec::Error::Custom(_)));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn proposal_ref() {
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let remove = Proposal::Remove(RemoveProposal {
+            to_remove: LeafIndex(0),
+        });
+        let test_auth = auth_content_from_proposal(remove, Sender::External(0));
+
+        let expected_ref = ProposalRef::from_content(&cs, &test_auth).await.unwrap();
+
+        let membership_tag: MembershipTag = cs.mac(&[1, 2, 3], &[1, 2, 3]).await.unwrap().into();
+
+        let public_message = PublicMessage {
+            content: test_auth.content,
+            auth: test_auth.auth,
+            membership_tag: Some(membership_tag),
+        };
+
+        let test_message = MlsMessage {
+            version: TEST_PROTOCOL_VERSION,
+            payload: MlsMessagePayload::Plain(public_message),
+        };
+
+        let computed_ref = test_message
+            .into_proposal_reference(&cs)
+            .await
+            .unwrap()
+            .unwrap();
+
+        assert_eq!(computed_ref, expected_ref.to_vec());
+    }
+}
diff --git a/src/group/group_info.rs b/src/group/group_info.rs
new file mode 100644
index 0000000..a5e7268
--- /dev/null
+++ b/src/group/group_info.rs
@@ -0,0 +1,95 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::extension::ExtensionList;
+
+use crate::{signer::Signable, tree_kem::node::LeafIndex};
+
+use super::{ConfirmationTag, GroupContext};
+
+/// A group's context and extensions, signed by the member at leaf `signer`.
+///
+/// The signature covers every field except `signature` itself (see the
+/// `Signable` implementation).
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+pub struct GroupInfo {
+    pub(crate) group_context: GroupContext,
+    pub(crate) extensions: ExtensionList,
+    pub(crate) confirmation_tag: ConfirmationTag,
+    pub(crate) signer: LeafIndex,
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    pub(crate) signature: Vec<u8>,
+}
+
+impl Debug for GroupInfo {
+    /// The signature is shown through the crate's byte pretty-printer.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let signature = mls_rs_core::debug::pretty_bytes(&self.signature);
+
+        f.debug_struct("GroupInfo")
+            .field("group_context", &self.group_context)
+            .field("extensions", &self.extensions)
+            .field("confirmation_tag", &self.confirmation_tag)
+            .field("signer", &self.signer)
+            .field("signature", &signature)
+            .finish()
+    }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl GroupInfo {
+    /// The group context this group info describes.
+    pub fn group_context(&self) -> &GroupContext {
+        &self.group_context
+    }
+
+    /// Extensions of the group info itself, such as the ratchet tree; these are
+    /// distinct from the group context extensions.
+    pub fn extensions(&self) -> &ExtensionList {
+        &self.extensions
+    }
+
+    /// Leaf index of the member who generated and signed this group info.
+    pub fn sender(&self) -> u32 {
+        let signer: u32 = *self.signer;
+        signer
+    }
+}
+
+/// The to-be-signed view of a [`GroupInfo`]: every field except the signature,
+/// borrowed rather than copied. Field order determines the encoded layout.
+#[derive(MlsEncode, MlsSize)]
+struct SignableGroupInfo<'a> {
+    group_context: &'a GroupContext,
+    extensions: &'a ExtensionList,
+    confirmation_tag: &'a ConfirmationTag,
+    signer: LeafIndex,
+}
+
+impl<'a> Signable<'a> for GroupInfo {
+    const SIGN_LABEL: &'static str = "GroupInfoTBS";
+    type SigningContext = ();
+
+    fn signature(&self) -> &[u8] {
+        self.signature.as_slice()
+    }
+
+    /// Encode every field except the signature; no external context is needed.
+    fn signable_content(
+        &self,
+        _context: &Self::SigningContext,
+    ) -> Result<Vec<u8>, mls_rs_codec::Error> {
+        let to_be_signed = SignableGroupInfo {
+            group_context: &self.group_context,
+            extensions: &self.extensions,
+            confirmation_tag: &self.confirmation_tag,
+            signer: self.signer,
+        };
+
+        to_be_signed.mls_encode_to_vec()
+    }
+
+    fn write_signature(&mut self, signature: Vec<u8>) {
+        self.signature = signature;
+    }
+}
diff --git a/src/group/interop_test_vectors.rs b/src/group/interop_test_vectors.rs
new file mode 100644
index 0000000..abd82fc
--- /dev/null
+++ b/src/group/interop_test_vectors.rs
@@ -0,0 +1,9 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+mod framing;
+mod passive_client;
+mod serialization;
+mod tree_kem;
+mod tree_modifications;
diff --git a/src/group/interop_test_vectors/framing.rs b/src/group/interop_test_vectors/framing.rs
new file mode 100644
index 0000000..30e4225
--- /dev/null
+++ b/src/group/interop_test_vectors/framing.rs
@@ -0,0 +1,461 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode};
+use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider, SignaturePublicKey};
+
+use crate::{
+ client::test_utils::{TestClientConfig, TEST_PROTOCOL_VERSION},
+ crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
+ group::{
+ confirmation_tag::ConfirmationTag,
+ epoch::EpochSecrets,
+ framing::{Content, WireFormat},
+ message_processor::{EventOrContent, MessageProcessor},
+ mls_rules::EncryptionOptions,
+ padding::PaddingMode,
+ proposal::{Proposal, RemoveProposal},
+ secret_tree::test_utils::get_test_tree,
+ test_utils::{random_bytes, test_group_custom_config},
+ AuthenticatedContent, Commit, Group, GroupContext, MlsMessage, Sender,
+ },
+ mls_rules::DefaultMlsRules,
+ test_utils::is_edwards,
+ tree_kem::{leaf_node::test_utils::get_basic_test_node, node::LeafIndex},
+};
+
+const FRAMING_N_LEAVES: u32 = 2; // secret-tree size used by make_group: receiver at leaf 0, sender at leaf 1
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct FramingTestCase { // one message-protection vector; field names are the JSON keys, bytes are hex
+ #[serde(flatten)] // the context fields sit at the top level of the JSON object
+ pub context: InteropGroupContext,
+ // Sender signature keys; Edwards private keys are stored halved (see `random` and `make_group`).
+ #[serde(with = "hex::serde")]
+ pub signature_priv: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub signature_pub: Vec<u8>,
+ // Epoch secrets that make_group installs into the group.
+ #[serde(with = "hex::serde")]
+ pub encryption_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub sender_data_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub membership_key: Vec<u8>,
+ // Proposal encoding plus its private- and public-message framings.
+ #[serde(with = "hex::serde")]
+ pub proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub proposal_priv: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub proposal_pub: Vec<u8>,
+ // Commit encoding plus its private- and public-message framings.
+ #[serde(with = "hex::serde")]
+ pub commit: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub commit_priv: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub commit_pub: Vec<u8>,
+ // Application data and its private-message framing (no public variant is stored).
+ #[serde(with = "hex::serde")]
+ pub application: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub application_priv: Vec<u8>,
+}
+
+impl FramingTestCase {
+    /// Fresh random keys and epoch secrets for `cs`; the message fields start out empty.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    async fn random<P: CipherSuiteProvider>(cs: &P) -> Self {
+        let context = InteropGroupContext::random(cs); // already records cs's cipher suite
+        let (secret_key, public_key) = cs.signature_key_generate().await.unwrap();
+        // Edwards keys keep only their first half; make_group re-appends the public key.
+        let signature_priv = match secret_key.to_vec() {
+            bytes if is_edwards(*cs.cipher_suite()) => bytes[..bytes.len() / 2].to_vec(),
+            bytes => bytes,
+        };
+        let secret_len = cs.kdf_extract_size();
+        Self {
+            context,
+            signature_priv,
+            signature_pub: public_key.to_vec(),
+            encryption_secret: random_bytes(secret_len),
+            sender_data_secret: random_bytes(secret_len),
+            membership_key: random_bytes(secret_len),
+            ..Default::default()
+        }
+    }
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct InteropGroupContext { // group-context part of a vector; flattened into FramingTestCase
+ pub cipher_suite: u16, // numeric cipher suite id
+ #[serde(with = "hex::serde")]
+ pub group_id: Vec<u8>,
+ pub epoch: u64,
+ #[serde(with = "hex::serde")]
+ pub tree_hash: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub confirmed_transcript_hash: Vec<u8>,
+}
+
+impl InteropGroupContext {
+    /// Random context for `cs`: byte fields use the suite's KDF extract size, epoch is fixed.
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn random<P: CipherSuiteProvider>(cs: &P) -> Self {
+        let hash_len = cs.kdf_extract_size();
+        let group_id = random_bytes(hash_len);
+        let tree_hash = random_bytes(hash_len);
+        let confirmed_transcript_hash = random_bytes(hash_len);
+        let cipher_suite: u16 = cs.cipher_suite().into();
+        Self { cipher_suite, group_id, epoch: 0x121212, tree_hash, confirmed_transcript_hash }
+    }
+}
+
+impl From<InteropGroupContext> for GroupContext {
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn from(ctx: InteropGroupContext) -> Self {
+        // Interop contexts carry no extensions and use the test protocol version.
+        let InteropGroupContext { group_id, epoch, tree_hash, .. } = ctx;
+        GroupContext {
+            protocol_version: TEST_PROTOCOL_VERSION,
+            cipher_suite: ctx.cipher_suite.into(),
+            group_id, epoch, tree_hash,
+            confirmed_transcript_hash: ctx.confirmed_transcript_hash.into(),
+            extensions: vec![].into(),
+        }
+    }
+}
+
+// Checks the proposal framings of the message-protection vector, which can be found here:
+// https://github.com/mlswg/mls-implementations/blob/main/test-vectors/message-protection.json
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn framing_proposal() {
+ #[cfg(not(mls_build_async))]
+ let test_cases: Vec<FramingTestCase> =
+ load_test_case_json!(framing, generate_framing_test_vector());
+ // Same load for async builds, where the generator has to be awaited.
+ #[cfg(mls_build_async)]
+ let test_cases: Vec<FramingTestCase> =
+ load_test_case_json!(framing, generate_framing_test_vector().await);
+ // Vectors whose cipher suite has no test provider are skipped.
+ for test_case in test_cases.into_iter() {
+ let Some(cs) = try_test_cipher_suite_provider(test_case.context.cipher_suite) else {
+ continue;
+ };
+ // Start with the vector's own private and public framings.
+ let to_check = vec![
+ test_case.proposal_priv.clone(),
+ test_case.proposal_pub.clone(),
+ ];
+
+ // Wasm uses incompatible signature secret key format, so locally framed messages are not added there
+ #[cfg(not(target_arch = "wasm32"))]
+ let mut to_check = to_check;
+ // Frame the proposal ourselves as the sender (leaf 1), encrypted (private) and unencrypted (public).
+ #[cfg(not(target_arch = "wasm32"))]
+ for enable_encryption in [true, false] {
+ let proposal = Proposal::mls_decode(&mut &*test_case.proposal).unwrap();
+
+ let built = make_group(&test_case, true, enable_encryption, &cs)
+ .await
+ .proposal_message(proposal, vec![])
+ .await
+ .unwrap()
+ .mls_encode_to_vec()
+ .unwrap();
+
+ to_check.push(built);
+ }
+ // Every framing must be processed by the receiver (leaf 0) into the same proposal.
+ let proposal = Proposal::mls_decode(&mut &*test_case.proposal).unwrap();
+
+ for message in to_check {
+ match process_message(&test_case, &message, &cs).await {
+ Content::Proposal(p) => assert_eq!(p.as_ref(), &proposal),
+ _ => panic!("received value not proposal"),
+ };
+ }
+ }
+}
+
+// Checks the application-message vectors; the test vector can be found here:
+// https://github.com/mlswg/mls-implementations/blob/main/test-vectors/message-protection.json
+// Skipped on wasm, which uses an incompatible signature secret key format.
+#[cfg(not(target_arch = "wasm32"))]
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn framing_application() {
+    #[cfg(not(mls_build_async))]
+    let test_cases: Vec<FramingTestCase> =
+        load_test_case_json!(framing, generate_framing_test_vector());
+
+    #[cfg(mls_build_async)]
+    let test_cases: Vec<FramingTestCase> =
+        load_test_case_json!(framing, generate_framing_test_vector().await);
+
+    for test_case in test_cases {
+        let Some(cs) = try_test_cipher_suite_provider(test_case.context.cipher_suite) else {
+            continue; // no test provider for this cipher suite
+        };
+        // Encrypt the vector's plaintext ourselves, acting as the sender (leaf 1).
+        let mut sender = make_group(&test_case, true, true, &cs).await;
+        let built_priv = sender
+            .encrypt_application_message(&test_case.application, vec![])
+            .await
+            .unwrap();
+        let built_priv = built_priv.mls_encode_to_vec().unwrap();
+
+        // The vector's ciphertext and ours must both decrypt to the same plaintext.
+        for message in [&test_case.application_priv, &built_priv] {
+            let Content::Application(data) = process_message(&test_case, message, &cs).await else {
+                panic!("decrypted value not application data");
+            };
+            assert_eq!(data.as_ref(), &test_case.application);
+        }
+    }
+}
+
+// Checks the commit framings of the message-protection vector, which can be found here:
+// https://github.com/mlswg/mls-implementations/blob/main/test-vectors/message-protection.json
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn framing_commit() {
+    #[cfg(not(mls_build_async))]
+    let test_cases: Vec<FramingTestCase> =
+        load_test_case_json!(framing, generate_framing_test_vector());
+
+    #[cfg(mls_build_async)]
+    let test_cases: Vec<FramingTestCase> =
+        load_test_case_json!(framing, generate_framing_test_vector().await);
+
+    for test_case in test_cases.into_iter() {
+        // Skip vectors whose cipher suite has no test provider.
+        let Some(cs) = try_test_cipher_suite_provider(test_case.context.cipher_suite) else {
+            continue;
+        };
+
+        let commit = Commit::mls_decode(&mut &*test_case.commit).unwrap();
+
+        // Start from the vector's own private and public framings of the commit.
+        let to_check = vec![test_case.commit_priv.clone(), test_case.commit_pub.clone()];
+
+        // Wasm uses incompatible signature secret key format
+        #[cfg(not(target_arch = "wasm32"))]
+        let to_check = {
+            let mut to_check = to_check;
+            // Edwards private keys are stored halved; append the public key before signing.
+            let mut signature_priv = test_case.signature_priv.clone();
+
+            if is_edwards(test_case.context.cipher_suite) {
+                signature_priv.extend(test_case.signature_pub.iter());
+            }
+
+            // Sign as member 1 with an empty confirmation tag, as the generator does.
+            let mut auth_content = AuthenticatedContent::new_signed(
+                &cs,
+                &test_case.context.clone().into(),
+                Sender::Member(1),
+                Content::Commit(alloc::boxed::Box::new(commit.clone())),
+                &signature_priv.into(),
+                WireFormat::PublicMessage,
+                vec![],
+            )
+            .await
+            .unwrap();
+
+            auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
+
+            // Frame it ourselves, both encrypted (private) and unencrypted (public).
+            for enable_encryption in [true, false] {
+                let built = make_group(&test_case, true, enable_encryption, &cs)
+                    .await
+                    .format_for_wire(auth_content.clone())
+                    .await
+                    .unwrap()
+                    .mls_encode_to_vec()
+                    .unwrap();
+
+                to_check.push(built);
+            }
+
+            to_check
+        };
+
+        // Each framing must be processed by the receiver (leaf 0) into the same commit.
+        // `commit_priv` is already in `to_check`, so it needs no separate check afterwards.
+        for message in to_check {
+            match process_message(&test_case, &message, &cs).await {
+                Content::Commit(c) => assert_eq!(&*c, &commit),
+                _ => panic!("received value not commit"),
+            };
+        }
+    }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn generate_framing_test_vector() -> Vec<FramingTestCase> { // one vector per suite in CipherSuite::all()
+ let mut test_vector = vec![];
+
+ for cs in CipherSuite::all() {
+ let cs = test_cipher_suite_provider(cs);
+ // Random keys and secrets first; the message fields are filled in below.
+ let mut test_case = FramingTestCase::random(&cs).await;
+
+ // Generate private application message (42 random bytes, sent from leaf 1)
+ test_case.application = cs.random_bytes_vec(42).unwrap();
+
+ let application_priv = make_group(&test_case, true, true, &cs)
+ .await
+ .encrypt_application_message(&test_case.application, vec![])
+ .await
+ .unwrap();
+
+ test_case.application_priv = application_priv.mls_encode_to_vec().unwrap();
+
+ // Generate public and private framings of a Remove proposal
+ let proposal = Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(2), // NOTE(review): not a leaf of the 2-leaf test tree; presumably fine since only framing is checked
+ });
+
+ test_case.proposal = proposal.mls_encode_to_vec().unwrap();
+ // Public framing (encryption disabled).
+ let mut group = make_group(&test_case, true, false, &cs).await;
+ let proposal_pub = group.proposal_message(proposal.clone(), vec![]).await;
+ test_case.proposal_pub = proposal_pub.unwrap().mls_encode_to_vec().unwrap();
+ // Private framing (encryption enabled).
+ let mut group = make_group(&test_case, true, true, &cs).await;
+ let proposal_priv = group.proposal_message(proposal, vec![]).await.unwrap();
+ test_case.proposal_priv = proposal_priv.mls_encode_to_vec().unwrap();
+
+ // Generate public and private framings of an empty commit
+ let commit = Commit {
+ proposals: vec![],
+ path: None,
+ };
+
+ test_case.commit = commit.mls_encode_to_vec().unwrap();
+ // Signed as member 1 with make_group's signer and context, both derived from test_case.
+ let mut auth_content = AuthenticatedContent::new_signed(
+ &cs,
+ group.context(),
+ Sender::Member(1),
+ Content::Commit(alloc::boxed::Box::new(commit.clone())),
+ &group.signer,
+ WireFormat::PublicMessage,
+ vec![],
+ )
+ .await
+ .unwrap();
+ // The vectors use an empty confirmation tag.
+ auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
+ // Public framing.
+ let mut group = make_group(&test_case, true, false, &cs).await;
+ let commit_pub = group.format_for_wire(auth_content.clone()).await.unwrap();
+ test_case.commit_pub = commit_pub.mls_encode_to_vec().unwrap();
+ // Re-sign for the private framing, this time with WireFormat::PrivateMessage.
+ let mut auth_content = AuthenticatedContent::new_signed(
+ &cs,
+ group.context(),
+ Sender::Member(1),
+ Content::Commit(alloc::boxed::Box::new(commit)),
+ &group.signer,
+ WireFormat::PrivateMessage,
+ vec![],
+ )
+ .await
+ .unwrap();
+
+ auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
+ // Private framing.
+ let mut group = make_group(&test_case, true, true, &cs).await;
+ let commit_priv = group.format_for_wire(auth_content.clone()).await.unwrap();
+ test_case.commit_priv = commit_priv.mls_encode_to_vec().unwrap();
+
+ test_vector.push(test_case);
+ }
+
+ test_vector
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn make_group<P: CipherSuiteProvider>( // test group rewired to act as a member of the vector's epoch
+ test_case: &FramingTestCase,
+ for_send: bool, // true: act as the sender (leaf 1); false: as the receiver (leaf 0)
+ control_encryption_enabled: bool, // forwarded to EncryptionOptions (padding is always PaddingMode::None)
+ cs: &P,
+) -> Group<TestClientConfig> {
+ let mut group =
+ test_group_custom_config(
+ TEST_PROTOCOL_VERSION,
+ test_case.context.cipher_suite.into(),
+ |b| {
+ b.mls_rules(DefaultMlsRules::default().with_encryption_options(
+ EncryptionOptions::new(control_encryption_enabled, PaddingMode::None),
+ ))
+ },
+ )
+ .await
+ .group;
+ // Start from an ordinary test group for the vector's suite with the requested encryption option.
+ // Add a leaf for the sender. It will get index 1.
+ let mut leaf = get_basic_test_node(cs.cipher_suite(), "leaf").await;
+ // The sender leaf's public signature key is the vector's signature_pub.
+ leaf.signing_identity.signature_key = SignaturePublicKey::from(test_case.signature_pub.clone());
+
+ group
+ .state
+ .public_tree
+ .add_leaves(vec![leaf], &group.config.0.identity_provider, cs)
+ .await
+ .unwrap();
+
+ // Convince the group that their index is 1 if they send or 0 if they receive.
+ group.private_tree.self_index = LeafIndex(if for_send { 1 } else { 0 });
+
+ // Convince the group that their signing key is the one from the test case
+ let mut signature_priv = test_case.signature_priv.clone();
+ // Edwards private keys are stored halved in the vector; append the public key to rebuild the full key.
+ if is_edwards(test_case.context.cipher_suite) {
+ signature_priv.extend(test_case.signature_pub.iter());
+ }
+
+ group.signer = signature_priv.into();
+
+ // Set the group context and secrets
+ let context = GroupContext::from(test_case.context.clone());
+ let secret_tree = get_test_tree(test_case.encryption_secret.clone(), FRAMING_N_LEAVES);
+ // resumption_secret is not part of the vector, so it is zero-filled.
+ let secrets = EpochSecrets {
+ secret_tree,
+ resumption_secret: vec![0_u8; cs.kdf_extract_size()].into(),
+ sender_data_secret: test_case.sender_data_secret.clone().into(),
+ };
+ // Install the vector's secrets, context and membership key.
+ group.epoch_secrets = secrets;
+ group.state.context = context;
+ let membership_key = test_case.membership_key.clone();
+ group.key_schedule.set_membership_key(membership_key);
+
+ group
+}
+
+/// Processes `message` as the receiver (leaf 0) and returns the framed content it carries.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn process_message<P: CipherSuiteProvider>(
+    test_case: &FramingTestCase,
+    message: &[u8],
+    cs: &P,
+) -> Content {
+    // Enabling encryption doesn't matter for processing.
+    let mut receiver = make_group(test_case, false, true, cs).await;
+    let decoded = MlsMessage::mls_decode(&mut &*message).unwrap();
+    let processed = receiver.get_event_from_incoming_message(decoded).await.unwrap();
+    let EventOrContent::Content(content) = processed else {
+        panic!("expected content, got event");
+    };
+    content.content.content
+}
diff --git a/src/group/interop_test_vectors/passive_client.rs b/src/group/interop_test_vectors/passive_client.rs
new file mode 100644
index 0000000..29588ed
--- /dev/null
+++ b/src/group/interop_test_vectors/passive_client.rs
@@ -0,0 +1,732 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+
+use itertools::Itertools;
+use mls_rs_core::{
+ crypto::{CipherSuite, CipherSuiteProvider, CryptoProvider},
+ identity::SigningIdentity,
+ protocol_version::ProtocolVersion,
+ psk::ExternalPskId,
+ time::MlsTime,
+};
+use rand::{seq::IteratorRandom, Rng, SeedableRng};
+
+use crate::{
+ client_builder::{ClientBuilder, MlsConfig},
+ crypto::test_utils::TestCryptoProvider,
+ group::{ClientConfig, CommitBuilder, ExportedTree},
+ identity::basic::BasicIdentityProvider,
+ key_package::KeyPackageGeneration,
+ mls_rules::CommitOptions,
+ storage_provider::in_memory::InMemoryKeyPackageStorage,
+ test_utils::{
+ all_process_message, generate_basic_client, get_test_basic_credential, get_test_groups,
+ make_test_ext_psk, TEST_EXT_PSK_ID,
+ },
+ tree_kem::Lifetime,
+ Client, Group, MlsMessage,
+};
+
+const VERSION: ProtocolVersion = ProtocolVersion::MLS_10; // protocol version of every generated client and group
+// Lifetime that never expires, presumably so stored vectors stay valid whenever they are replayed.
+const ETERNAL_LIFETIME: Lifetime = Lifetime {
+ not_before: 0,
+ not_after: u64::MAX,
+};
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestCase { // one passive-client vector; field names are the JSON keys
+ pub cipher_suite: u16,
+ // PSKs, key package (an encoded MlsMessage) and private keys of the passive client.
+ pub external_psks: Vec<TestExternalPsk>,
+ #[serde(with = "hex::serde")]
+ pub key_package: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub signature_priv: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub encryption_priv: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub init_priv: Vec<u8>,
+ // Joining: the welcome, an optional separately supplied ratchet tree, and the expected authenticator.
+ #[serde(with = "hex::serde")]
+ pub welcome: Vec<u8>,
+ pub ratchet_tree: Option<TestRatchetTree>,
+ #[serde(with = "hex::serde")]
+ pub initial_epoch_authenticator: Vec<u8>,
+ // Later epochs, replayed in order after joining.
+ pub epochs: Vec<TestEpoch>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestExternalPsk { // external PSK registered with the passive client before it joins
+ #[serde(with = "hex::serde")]
+ pub psk_id: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub psk: Vec<u8>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestEpoch { // one epoch transition as replayed by the passive client
+ pub proposals: Vec<TestMlsMessage>, // proposal messages processed before the commit
+ #[serde(with = "hex::serde")]
+ pub commit: Vec<u8>, // serialized commit MlsMessage
+ #[serde(with = "hex::serde")]
+ pub epoch_authenticator: Vec<u8>, // expected authenticator once the commit is processed
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestMlsMessage(#[serde(with = "hex::serde")] pub Vec<u8>); // a serialized MlsMessage, hex in JSON
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestRatchetTree(#[serde(with = "hex::serde")] pub Vec<u8>); // a serialized ExportedTree, hex in JSON
+
+impl TestEpoch {
+    /// Serializes one epoch's proposals and commit together with its expected authenticator.
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    pub fn new(
+        proposals: Vec<MlsMessage>,
+        commit: &MlsMessage,
+        epoch_authenticator: Vec<u8>,
+    ) -> Self {
+        // Keep the proposals in the order they are to be processed.
+        let mut encoded = Vec::with_capacity(proposals.len());
+        for proposal in proposals {
+            encoded.push(TestMlsMessage(proposal.to_bytes().unwrap()));
+        }
+
+        let commit = commit.to_bytes().unwrap();
+
+        Self { proposals: encoded, commit, epoch_authenticator }
+    }
+}
+
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn interop_passive_client() {
+    // Test vectors can be found here:
+    // * https://github.com/mlswg/mls-implementations/blob/main/test-vectors/passive-client-welcome.json
+    // * https://github.com/mlswg/mls-implementations/blob/main/test-vectors/passive-client-handle-commit.json
+    // * https://github.com/mlswg/mls-implementations/blob/main/test-vectors/passive-client-random.json
+    // For each vector: rebuild the passive client, join from the welcome, replay every epoch.
+    #[cfg(mls_build_async)]
+    let (test_cases_wel, test_cases_com, test_cases_rand) = {
+        let test_cases_wel: Vec<TestCase> = load_test_case_json!(
+            interop_passive_client_welcome,
+            generate_passive_client_welcome_tests().await
+        );
+
+        let test_cases_com: Vec<TestCase> = load_test_case_json!(
+            interop_passive_client_handle_commit,
+            generate_passive_client_proposal_tests().await
+        );
+
+        let test_cases_rand: Vec<TestCase> = load_test_case_json!(
+            interop_passive_client_random,
+            generate_passive_client_random_tests().await
+        );
+
+        (test_cases_wel, test_cases_com, test_cases_rand)
+    };
+
+    #[cfg(not(mls_build_async))]
+    let (test_cases_wel, test_cases_com, test_cases_rand) = {
+        let test_cases_wel: Vec<TestCase> = load_test_case_json!(
+            interop_passive_client_welcome,
+            generate_passive_client_welcome_tests()
+        );
+
+        let test_cases_com: Vec<TestCase> = load_test_case_json!(
+            interop_passive_client_handle_commit,
+            generate_passive_client_proposal_tests()
+        );
+
+        let test_cases_rand: Vec<TestCase> = load_test_case_json!(
+            interop_passive_client_random,
+            generate_passive_client_random_tests()
+        );
+
+        (test_cases_wel, test_cases_com, test_cases_rand)
+    };
+
+    // Replay the handle-commit, welcome and random vectors in turn.
+    for test_case in test_cases_com
+        .into_iter()
+        .chain(test_cases_wel)
+        .chain(test_cases_rand)
+    {
+        let crypto_provider = TestCryptoProvider::new();
+        let Some(cs) = crypto_provider.cipher_suite_provider(test_case.cipher_suite.into()) else {
+            continue;
+        };
+        // Rebuild the passive client from the vector's identity, signature key and PSKs.
+        let message = MlsMessage::from_bytes(&test_case.key_package).unwrap();
+        let key_package = message.into_key_package().unwrap();
+        let id = key_package.leaf_node.signing_identity.clone();
+        let key = test_case.signature_priv.clone().into();
+
+        let mut client_builder = ClientBuilder::new()
+            .crypto_provider(crypto_provider)
+            .identity_provider(BasicIdentityProvider::new());
+
+        for psk in test_case.external_psks {
+            client_builder = client_builder.psk(ExternalPskId::new(psk.psk_id), psk.psk.into());
+        }
+
+        let client = client_builder
+            .signing_identity(id, key, cs.cipher_suite())
+            .build();
+        // Pre-load the key package and its private keys so the client can join from the welcome.
+        let key_pckg_gen = KeyPackageGeneration {
+            reference: key_package.to_reference(&cs).await.unwrap(),
+            key_package,
+            init_secret_key: test_case.init_priv.into(),
+            leaf_node_secret_key: test_case.encryption_priv.into(),
+        };
+
+        let (id, pkg) = key_pckg_gen.to_storage().unwrap();
+        client.config.key_package_repo().insert(id, pkg);
+
+        let welcome = MlsMessage::from_bytes(&test_case.welcome).unwrap();
+        // Some vectors supply the ratchet tree separately from the welcome.
+        let tree = test_case
+            .ratchet_tree
+            .map(|t| ExportedTree::from_bytes(&t.0).unwrap());
+
+        let (mut group, _info) = client.join_group(tree, &welcome).await.unwrap();
+
+        assert_eq!(
+            group.epoch_authenticator().unwrap().to_vec(),
+            test_case.initial_epoch_authenticator
+        );
+        // Replay each epoch: standalone proposal messages first, then the commit.
+        for epoch in test_case.epochs {
+            for proposal in epoch.proposals.iter() {
+                let message = MlsMessage::from_bytes(&proposal.0).unwrap();
+
+                group
+                    .process_incoming_message_with_time(message, MlsTime::now())
+                    .await
+                    .unwrap();
+            }
+
+            let message = MlsMessage::from_bytes(&epoch.commit).unwrap();
+
+            group
+                .process_incoming_message_with_time(message, MlsTime::now())
+                .await
+                .unwrap();
+
+            assert_eq!(
+                epoch.epoch_authenticator,
+                group.epoch_authenticator().unwrap().to_vec()
+            );
+        }
+    }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn invite_passive_client<P: CipherSuiteProvider>( // groups[0] adds a fresh passive client; returns the data it needs to join
+ groups: &mut [Group<impl MlsConfig>],
+ with_psk: bool, // also inject the TEST_EXT_PSK_ID external PSK into the add commit
+ cs: &P,
+) -> TestCase {
+ let crypto_provider = TestCryptoProvider::new();
+
+ let (secret_key, public_key) = cs.signature_key_generate().await.unwrap();
+ let credential = get_test_basic_credential(b"Arnold".to_vec());
+ let identity = SigningIdentity::new(credential, public_key);
+ let key_package_repo = InMemoryKeyPackageStorage::new();
+ // Key packages use ETERNAL_LIFETIME and are kept in key_package_repo so their secrets can be exported.
+ let client = ClientBuilder::new()
+ .crypto_provider(crypto_provider)
+ .identity_provider(BasicIdentityProvider::new())
+ .key_package_repo(key_package_repo.clone())
+ .key_package_lifetime(ETERNAL_LIFETIME.not_after - ETERNAL_LIFETIME.not_before)
+ .key_package_not_before(ETERNAL_LIFETIME.not_before)
+ .signing_identity(identity.clone(), secret_key.clone(), cs.cipher_suite())
+ .build();
+
+ let key_pckg = client.generate_key_package_message().await.unwrap();
+ // Read the secrets back through the clone handed to the builder (assumes clones share storage).
+ let (_, key_pckg_secrets) = key_package_repo.key_packages()[0].clone();
+
+ let mut commit_builder = groups[0]
+ .commit_builder()
+ .add_member(key_pckg.clone())
+ .unwrap();
+
+ if with_psk {
+ commit_builder = commit_builder
+ .add_external_psk(ExternalPskId::new(TEST_EXT_PSK_ID.to_vec()))
+ .unwrap();
+ }
+
+ let commit = commit_builder.build().await.unwrap();
+ // Every existing member processes the commit (sender index 0).
+ all_process_message(groups, &commit.commit_message, 0, true).await;
+ // Listed in the vector only when with_psk is set.
+ let external_psk = TestExternalPsk {
+ psk_id: TEST_EXT_PSK_ID.to_vec(),
+ psk: make_test_ext_psk(),
+ };
+ // epochs and ratchet_tree are filled in by the caller when needed.
+ TestCase {
+ cipher_suite: cs.cipher_suite().into(),
+ key_package: key_pckg.to_bytes().unwrap(),
+ encryption_priv: key_pckg_secrets.leaf_node_key.to_vec(),
+ init_priv: key_pckg_secrets.init_key.to_vec(),
+ welcome: commit.welcome_messages[0].to_bytes().unwrap(),
+ initial_epoch_authenticator: groups[0].epoch_authenticator().unwrap().to_vec(),
+ epochs: vec![],
+ signature_priv: secret_key.to_vec(),
+ external_psks: if with_psk { vec![external_psk] } else { vec![] },
+ ratchet_tree: None,
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn generate_passive_client_proposal_tests() -> Vec<TestCase> { // "handle commit" vectors: one TestCase per commit variant and suite
+ let mut test_cases: Vec<TestCase> = vec![];
+
+ for cs in CipherSuite::all() {
+ let crypto_provider = TestCryptoProvider::new();
+ let Some(cs) = crypto_provider.cipher_suite_provider(cs) else {
+ continue;
+ };
+ // Seven members; the passive client then joins through groups[0] with the test external PSK.
+ let mut groups =
+ get_test_groups(VERSION, cs.cipher_suite(), 7, None, false, &crypto_provider).await;
+
+ let mut partial_test_case = invite_passive_client(&mut groups, true, &cs).await;
+
+ // Create a new epoch so the passive member can later process a resumption PSK for the epoch it joined in
+ let commit = groups[0].commit(vec![]).await.unwrap();
+ all_process_message(&mut groups, &commit.commit_message, 0, true).await;
+
+ partial_test_case.epochs.push(TestEpoch::new(
+ vec![],
+ &commit.commit_message,
+ groups[0].epoch_authenticator().unwrap().to_vec(),
+ ));
+
+ let psk = ExternalPskId::new(TEST_EXT_PSK_ID.to_vec());
+ let key_pckg = create_key_package(cs.cipher_suite()).await;
+
+ // Commits with proposals by value, each made on a clone so `groups` itself stays in this epoch
+ let test_case = commit_by_value(
+ &mut groups[3].clone(),
+ |b| b.add_member(key_pckg.clone()).unwrap(),
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ let test_case = commit_by_value(
+ &mut groups[3].clone(),
+ |b| b.remove_member(5).unwrap(),
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ let test_case = commit_by_value(
+ &mut groups[1].clone(),
+ |b| b.add_external_psk(psk.clone()).unwrap(),
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ let test_case = commit_by_value(
+ &mut groups[5].clone(),
+ |b| b.add_resumption_psk(groups[1].current_epoch() - 1).unwrap(),
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ let test_case = commit_by_value(
+ &mut groups[2].clone(),
+ |b| b.set_group_context_ext(Default::default()).unwrap(),
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ let test_case = commit_by_value(
+ &mut groups[3].clone(),
+ |b| {
+ b.add_member(key_pckg)
+ .unwrap()
+ .remove_member(5)
+ .unwrap()
+ .add_external_psk(psk.clone())
+ .unwrap()
+ .add_resumption_psk(groups[4].current_epoch() - 1)
+ .unwrap()
+ .set_group_context_ext(Default::default())
+ .unwrap()
+ },
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ // Proposals by reference, each paired with the index of the sending group in `groups`
+ let add = groups[0]
+ .propose_add(create_key_package(cs.cipher_suite()).await, vec![])
+ .await
+ .unwrap();
+
+ let add = (add, 0);
+
+ let update = (groups[1].propose_update(vec![]).await.unwrap(), 1);
+ let remove = (groups[2].propose_remove(2, vec![]).await.unwrap(), 2);
+
+ let ext_psk = groups[3]
+ .propose_external_psk(psk.clone(), vec![])
+ .await
+ .unwrap();
+
+ let ext_psk = (ext_psk, 3);
+
+ let last_ep = groups[3].current_epoch() - 1; // the epoch the passive client joined in
+
+ let res_psk = groups[3]
+ .propose_resumption_psk(last_ep, vec![])
+ .await
+ .unwrap();
+
+ let res_psk = (res_psk, 3);
+
+ let grp_ext = groups[4]
+ .propose_group_context_extensions(Default::default(), vec![])
+ .await
+ .unwrap();
+
+ let grp_ext = (grp_ext, 4);
+ // One case per proposal: groups[5] commits it on a fresh clone of all groups.
+ let proposals = [add, update, remove, ext_psk, res_psk, grp_ext];
+
+ for (p, sender) in &proposals {
+ let mut groups = groups.clone();
+
+ all_process_message(&mut groups, p, *sender, false).await;
+
+ let commit = groups[5].commit(vec![]).await.unwrap().commit_message;
+
+ groups[5].apply_pending_commit().await.unwrap();
+ let auth = groups[5].epoch_authenticator().unwrap().to_vec();
+
+ let mut test_case = partial_test_case.clone();
+ let epoch = TestEpoch::new(vec![p.clone()], &commit, auth);
+ test_case.epochs.push(epoch);
+
+ test_cases.push(test_case);
+ }
+ // Last case: groups[4] commits all six proposals, first processing every one except its own.
+ let mut group = groups[4].clone();
+
+ for (p, _) in proposals.iter().filter(|(_, i)| *i != 4) {
+ group.process_incoming_message(p.clone()).await.unwrap();
+ }
+
+ let commit = group.commit(vec![]).await.unwrap().commit_message;
+ group.apply_pending_commit().await.unwrap();
+ let auth = group.epoch_authenticator().unwrap().to_vec();
+ let mut test_case = partial_test_case.clone();
+ let proposals = proposals.into_iter().map(|(p, _)| p).collect();
+ let epoch = TestEpoch::new(proposals, &commit, auth);
+ test_case.epochs.push(epoch);
+ test_cases.push(test_case);
+ }
+
+ test_cases
+}
+
+/// Commits the proposals added by `proposal_adder` from `group`, applies the commit, and
+/// returns `partial_test_case` extended with the resulting epoch.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn commit_by_value<F, C: MlsConfig>(
+    group: &mut Group<C>,
+    proposal_adder: F,
+    mut partial_test_case: TestCase,
+) -> TestCase
+where
+    F: FnOnce(CommitBuilder<C>) -> CommitBuilder<C>,
+{
+    let commit = proposal_adder(group.commit_builder()).build().await.unwrap().commit_message;
+    group.apply_pending_commit().await.unwrap();
+
+    let authenticator = group.epoch_authenticator().unwrap().to_vec();
+    partial_test_case.epochs.push(TestEpoch::new(vec![], &commit, authenticator));
+    partial_test_case
+}
+
+/// Key package message from a throwaway basic client whose key packages never expire.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn create_key_package(cs: CipherSuite) -> MlsMessage {
+    let crypto = TestCryptoProvider::new();
+    let lifetime = Some(ETERNAL_LIFETIME);
+
+    // 0xbeef is a fixed client id.
+    let client =
+        generate_basic_client(cs, VERSION, 0xbeef, None, false, &crypto, lifetime).await;
+
+    client
+        .generate_key_package_message()
+        .await
+        .unwrap()
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn generate_passive_client_welcome_tests() -> Vec<TestCase> { // join-only vectors over tree delivery x PSK x commit path, per suite
+ let mut test_cases: Vec<TestCase> = vec![];
+
+ for cs in CipherSuite::all() {
+ let crypto_provider = TestCryptoProvider::new();
+ let Some(cs) = crypto_provider.cipher_suite_provider(cs) else {
+ continue;
+ };
+ // 2 x 2 x 2 = 8 join scenarios per supported cipher suite.
+ for with_tree_in_extension in [true, false] {
+ for (with_psk, with_path) in [false, true].into_iter().cartesian_product([true, false])
+ {
+ let options = CommitOptions::new()
+ .with_path_required(with_path)
+ .with_ratchet_tree_extension(with_tree_in_extension);
+
+ let mut groups = get_test_groups(
+ VERSION,
+ cs.cipher_suite(),
+ 16,
+ Some(options),
+ false,
+ &crypto_provider,
+ )
+ .await;
+
+ // Propose removing member 7 so the passive member joins in their place (presumably committed along with the add)
+ let proposal = groups[0].propose_remove(7, vec![]).await.unwrap();
+ all_process_message(&mut groups, &proposal, 0, false).await;
+
+ let mut test_case = invite_passive_client(&mut groups, with_psk, &cs).await;
+ // Without the ratchet tree extension, the tree is shipped alongside the welcome.
+ if !with_tree_in_extension {
+ let tree = groups[0].export_tree().to_bytes().unwrap();
+ test_case.ratchet_tree = Some(TestRatchetTree(tree));
+ }
+
+ test_cases.push(test_case);
+ }
+ }
+ }
+
+ test_cases
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn generate_passive_client_random_tests() -> Vec<TestCase> { // 100 rounds of random removes and adds per suite
+ let mut test_cases: Vec<TestCase> = vec![];
+
+ for cs in CipherSuite::all() {
+ let crypto = TestCryptoProvider::new();
+ let Some(csp) = crypto.cipher_suite_provider(cs) else {
+ continue;
+ };
+ // Creator group plus 10 added clients; then the passive client is invited (without PSK).
+ let creator =
+ generate_basic_client(cs, VERSION, 0, None, false, &crypto, Some(ETERNAL_LIFETIME))
+ .await;
+
+ let creator_group = creator.create_group(Default::default()).await.unwrap();
+
+ let mut groups = vec![creator_group];
+
+ let mut new_clients = Vec::new();
+
+ for i in 0..10 {
+ new_clients.push(
+ generate_basic_client(
+ cs,
+ VERSION,
+ i + 1,
+ None,
+ false,
+ &crypto,
+ Some(ETERNAL_LIFETIME),
+ )
+ .await,
+ )
+ }
+
+ add_random_members(0, &mut groups, new_clients, None).await;
+
+ let mut test_case = invite_passive_client(&mut groups, false, &csp).await;
+ // Creator plus 10 clients fill leaves 0..=10, so the passive client presumably sits at leaf 11.
+ let passive_client_index = 11;
+ // Random seed, printed when std is available so a run can be reproduced.
+ let seed: <rand::rngs::StdRng as SeedableRng>::Seed = rand::random();
+ let mut rng = rand::rngs::StdRng::from_seed(seed);
+ #[cfg(feature = "std")]
+ println!("generating random commits for seed {}", hex::encode(seed));
+ // Client ids 0..=10 are taken; new members get fresh ids from here on.
+ let mut next_free_idx = 11;
+ for _ in 0..100 {
+ // Keep the passive client and at least one other member to send the commits
+ let num_removed = rng.gen_range(0..groups.len() - 2);
+ let num_added = rng.gen_range(1..30);
+ // NOTE: the passive client has no entry in `groups` (invite_passive_client adds none), so this filter is defensive.
+ let mut members = (0..groups.len())
+ .filter(|i| groups[*i].current_member_index() != passive_client_index)
+ .choose_multiple(&mut rng, num_removed + 1);
+ // The last chosen member commits the removal of the others.
+ let sender = members.pop().unwrap();
+
+ remove_members(members, sender, &mut groups, Some(&mut test_case)).await;
+ // Pick a new committer for the adds among the remaining members.
+ let sender = (0..groups.len())
+ .filter(|i| groups[*i].current_member_index() != passive_client_index)
+ .choose(&mut rng)
+ .unwrap();
+
+ let mut new_clients = Vec::new();
+
+ for i in 0..num_added {
+ new_clients.push(
+ generate_basic_client(
+ cs,
+ VERSION,
+ next_free_idx + i,
+ None,
+ false,
+ &crypto,
+ Some(ETERNAL_LIFETIME),
+ )
+ .await,
+ );
+ }
+
+ add_random_members(sender, &mut groups, new_clients, Some(&mut test_case)).await;
+
+ next_free_idx += num_added;
+ }
+
+ test_cases.push(test_case);
+ }
+
+ test_cases
+}
+
+/// `groups[committer]` proposes an Add for each of `clients` and commits them; the new members
+/// join from the welcome and are appended to `groups`. The epoch goes to `test_case` if given.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn add_random_members<C: MlsConfig>(
+    committer: usize,
+    groups: &mut Vec<Group<C>>,
+    clients: Vec<Client<C>>,
+    test_case: Option<&mut TestCase>,
+) {
+    // Senders are identified by leaf index, which can differ from the position in `groups`.
+    let committer_index = groups[committer].current_member_index() as usize;
+
+    let mut key_packages = Vec::new();
+    for client in &clients {
+        let key_package = client.generate_key_package_message().await.unwrap();
+        key_packages.push(key_package);
+    }
+
+    let mut add_proposals = Vec::new();
+    let committer_group = &mut groups[committer];
+
+    for key_package in key_packages {
+        add_proposals.push(
+            committer_group
+                .propose_add(key_package, vec![])
+                .await
+                .unwrap(),
+        );
+    }
+
+    for p in &add_proposals {
+        all_process_message(groups, p, committer_index, false).await;
+    }
+
+    let commit_output = groups[committer].commit(vec![]).await.unwrap();
+    all_process_message(groups, &commit_output.commit_message, committer_index, true).await;
+
+    let auth = groups[committer].epoch_authenticator().unwrap().to_vec();
+    let epoch = TestEpoch::new(add_proposals, &commit_output.commit_message, auth);
+
+    if let Some(tc) = test_case {
+        tc.epochs.push(epoch)
+    };
+
+    let tree_data = groups[committer].export_tree().into_owned();
+
+    for client in &clients {
+        // Every new member joins from the first welcome; `join_group` only borrows it.
+        let welcome = &commit_output.welcome_messages[0];
+        let group = client
+            .join_group(Some(tree_data.clone()), welcome)
+            .await
+            .unwrap()
+            .0;
+
+        groups.push(group);
+    }
+}
+
+/// Has `groups[committer]` commit the removal of the groups at positions `removed_members`,
+/// records the epoch in `test_case` if given, and then drops those groups from `groups`.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn remove_members<C: MlsConfig>(
+    removed_members: Vec<usize>,
+    committer: usize,
+    groups: &mut Vec<Group<C>>,
+    test_case: Option<&mut TestCase>,
+) {
+    let mut leaf_indexes: Vec<u32> = Vec::with_capacity(removed_members.len());
+    for &position in &removed_members {
+        leaf_indexes.push(groups[position].current_member_index());
+    }
+    let builder = leaf_indexes
+        .into_iter()
+        .fold(groups[committer].commit_builder(), |acc, leaf| {
+            acc.remove_member(leaf).unwrap()
+        });
+    let commit = builder.build().await.unwrap().commit_message;
+    let sender_leaf = groups[committer].current_member_index() as usize;
+    all_process_message(groups, &commit, sender_leaf, true).await;
+
+    let auth = groups[committer].epoch_authenticator().unwrap().to_vec();
+    let epoch = TestEpoch::new(vec![], &commit, auth);
+    if let Some(tc) = test_case {
+        tc.epochs.push(epoch);
+    }
+
+    // `position` tracks each element's index before removal (retain visits in order).
+    let mut position = 0;
+    groups.retain(|_| {
+        let keep = !removed_members.contains(&position);
+        position += 1;
+        keep
+    });
+}
diff --git a/src/group/interop_test_vectors/serialization.rs b/src/group/interop_test_vectors/serialization.rs
new file mode 100644
index 0000000..cbaf6fa
--- /dev/null
+++ b/src/group/interop_test_vectors/serialization.rs
@@ -0,0 +1,169 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode};
+
+use mls_rs_core::extension::ExtensionList;
+
+use crate::{
+ group::{
+ framing::ContentType,
+ proposal::{
+ AddProposal, ExternalInit, PreSharedKeyProposal, ReInitProposal, RemoveProposal,
+ UpdateProposal,
+ },
+ Commit, GroupSecrets, MlsMessage,
+ },
+ tree_kem::node::NodeVec,
+};
+
+ /// One entry of the `messages.json` interop test vector: hex-encoded serializations of MLS
+ /// structures that `serialization` decodes and re-encodes, expecting identical bytes.
+ #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+ struct TestCase {
+ // Complete `MlsMessage` encodings of a welcome, a group info and a key package.
+ #[serde(with = "hex::serde")]
+ mls_welcome: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ mls_group_info: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ mls_key_package: Vec<u8>,
+
+ // A bare `NodeVec` (ratchet tree) and a bare `GroupSecrets`.
+ #[serde(with = "hex::serde")]
+ ratchet_tree: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ group_secrets: Vec<u8>,
+
+ // Bare proposal bodies; the group-context-extensions proposal is an `ExtensionList`.
+ #[serde(with = "hex::serde")]
+ add_proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ update_proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ remove_proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pre_shared_key_proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ re_init_proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ external_init_proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ group_context_extensions_proposal: Vec<u8>,
+
+ // A bare `Commit`.
+ #[serde(with = "hex::serde")]
+ commit: Vec<u8>,
+
+ // `MlsMessage` encodings: public messages carrying application, proposal and commit
+ // content, and one private (ciphertext) message.
+ #[serde(with = "hex::serde")]
+ public_message_application: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ public_message_proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ public_message_commit: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ private_message: Vec<u8>,
+ }
+
+ /// Decodes `bytes` as a `T`, re-encodes it, and asserts the original bytes are reproduced.
+ fn assert_round_trip<T: MlsDecode + MlsEncode>(bytes: &[u8]) {
+     let decoded = T::mls_decode(&mut &*bytes).unwrap();
+     assert_eq!(decoded.mls_encode_to_vec().unwrap(), bytes);
+ }
+
+ /// Parses `bytes` as an `MlsMessage`, asserts that it re-encodes byte-for-byte, and returns the
+ /// content type of the plaintext it must carry.
+ fn public_message_content_type(bytes: &[u8]) -> ContentType {
+     let message = MlsMessage::from_bytes(bytes).unwrap();
+     assert_eq!(message.mls_encode_to_vec().unwrap(), bytes);
+     message.into_plaintext().unwrap().content.content_type()
+ }
+
+ // The test vector can be found here:
+ // https://github.com/mlswg/mls-implementations/blob/main/test-vectors/messages.json
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn serialization() {
+     let test_cases: Vec<TestCase> = load_test_case_json!(serialization, Vec::<TestCase>::new());
+
+     for case in test_cases {
+         // Full MLS messages: check the message variant, then the byte-exact round trip.
+         let welcome = MlsMessage::from_bytes(&case.mls_welcome).unwrap();
+         welcome.clone().into_welcome().unwrap();
+         assert_eq!(&welcome.to_bytes().unwrap(), &case.mls_welcome);
+
+         let group_info = MlsMessage::from_bytes(&case.mls_group_info).unwrap();
+         group_info.clone().into_group_info().unwrap();
+         assert_eq!(&group_info.to_bytes().unwrap(), &case.mls_group_info);
+
+         let key_package = MlsMessage::from_bytes(&case.mls_key_package).unwrap();
+         key_package.clone().into_key_package().unwrap();
+         assert_eq!(&key_package.to_bytes().unwrap(), &case.mls_key_package);
+
+         // Bare structures.
+         assert_round_trip::<NodeVec>(&case.ratchet_tree);
+         assert_round_trip::<GroupSecrets>(&case.group_secrets);
+         assert_round_trip::<AddProposal>(&case.add_proposal);
+         assert_round_trip::<UpdateProposal>(&case.update_proposal);
+         assert_round_trip::<RemoveProposal>(&case.remove_proposal);
+         assert_round_trip::<ReInitProposal>(&case.re_init_proposal);
+         assert_round_trip::<PreSharedKeyProposal>(&case.pre_shared_key_proposal);
+         assert_round_trip::<ExternalInit>(&case.external_init_proposal);
+         assert_round_trip::<ExtensionList>(&case.group_context_extensions_proposal);
+         assert_round_trip::<Commit>(&case.commit);
+
+         // Public messages must round-trip and carry plaintext of the expected content type.
+         assert_eq!(
+             public_message_content_type(&case.public_message_application),
+             ContentType::Application
+         );
+         assert_eq!(
+             public_message_content_type(&case.public_message_proposal),
+             ContentType::Proposal
+         );
+         assert_eq!(
+             public_message_content_type(&case.public_message_commit),
+             ContentType::Commit
+         );
+
+         // The private message must round-trip and be a ciphertext.
+         let private = MlsMessage::from_bytes(&case.private_message).unwrap();
+         assert_eq!(private.mls_encode_to_vec().unwrap(), case.private_message);
+         private.into_ciphertext().unwrap();
+     }
+ }
diff --git a/src/group/interop_test_vectors/tree_kem.rs b/src/group/interop_test_vectors/tree_kem.rs
new file mode 100644
index 0000000..0a04312
--- /dev/null
+++ b/src/group/interop_test_vectors/tree_kem.rs
@@ -0,0 +1,185 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::test_utils::TEST_PROTOCOL_VERSION,
+ crypto::test_utils::try_test_cipher_suite_provider,
+ group::{
+ confirmation_tag::ConfirmationTag, framing::Content, message_processor::MessageProcessor,
+ message_signature::AuthenticatedContent, test_utils::GroupWithoutKeySchedule, Commit,
+ GroupContext, PathSecret, Sender,
+ },
+ identity::basic::BasicIdentityProvider,
+ tree_kem::{
+ node::{LeafIndex, NodeVec},
+ TreeKemPrivate, TreeKemPublic, UpdatePath,
+ },
+ WireFormat,
+};
+use alloc::vec;
+use alloc::vec::Vec;
+use mls_rs_codec::MlsDecode;
+use mls_rs_core::{crypto::CipherSuiteProvider, extension::ExtensionList};
+
+ /// One entry of the `treekem.json` interop test vector.
+ #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+ struct TreeKemTestCase {
+ pub cipher_suite: u16,
+
+ // Inputs for the `GroupContext` that `tree_kem` builds for this case.
+ #[serde(with = "hex::serde")]
+ pub group_id: Vec<u8>,
+ epoch: u64,
+ #[serde(with = "hex::serde")]
+ confirmed_transcript_hash: Vec<u8>,
+ // Serialized `NodeVec` of the public ratchet tree before any update path is processed.
+ #[serde(with = "hex::serde")]
+ ratchet_tree: Vec<u8>,
+
+ leaves_private: Vec<TestLeafPrivate>,
+ update_paths: Vec<TestUpdatePath>,
+ }
+
+ /// Private state of one member of the test tree.
+ #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+ struct TestLeafPrivate {
+ // Leaf index of this member.
+ index: u32,
+ // HPKE private key of the member's own leaf node.
+ #[serde(with = "hex::serde")]
+ encryption_priv: Vec<u8>,
+ // NOTE(review): not read by the `tree_kem` test in this file.
+ #[serde(with = "hex::serde")]
+ signature_priv: Vec<u8>,
+ // Known path secrets, looked up by node index for each node on the leaf's direct path.
+ path_secrets: Vec<TestPathSecretPrivate>,
+ }
+
+ /// A path secret and the tree node index it belongs to.
+ #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+ struct TestPathSecretPrivate {
+ node: u32,
+ #[serde(with = "hex::serde")]
+ path_secret: Vec<u8>,
+ }
+
+ /// An encoded `UpdatePath` sent by `sender`, plus the results expected after processing it.
+ #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+ struct TestUpdatePath {
+ // Leaf index of the member that produced `update_path`.
+ sender: u32,
+ // Serialized `UpdatePath`.
+ #[serde(with = "hex::serde")]
+ update_path: Vec<u8>,
+ // Expected tree hash of the public tree once the path has been merged.
+ #[serde(with = "hex::serde")]
+ tree_hash_after: Vec<u8>,
+ // Expected commit secret produced by processing the path.
+ #[serde(with = "hex::serde")]
+ commit_secret: Vec<u8>,
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn tree_kem() {
+     // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/treekem.json
+
+     let test_cases: Vec<TreeKemTestCase> =
+         load_test_case_json!(interop_tree_kem, Vec::<TreeKemTestCase>::new());
+
+     for test_case in test_cases {
+         // Skip cases whose cipher suite the active crypto provider does not support.
+         let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
+             continue;
+         };
+
+         // Import the public ratchet tree and bind its hash into the group context.
+         let nodes = NodeVec::mls_decode(&mut &*test_case.ratchet_tree).unwrap();
+
+         let mut tree =
+             TreeKemPublic::import_node_data(nodes, &BasicIdentityProvider, &Default::default())
+                 .await
+                 .unwrap();
+
+         let tree_hash = tree.tree_hash(&cs).await.unwrap();
+
+         let group_context = GroupContext {
+             protocol_version: TEST_PROTOCOL_VERSION,
+             cipher_suite: cs.cipher_suite(),
+             group_id: test_case.group_id,
+             epoch: test_case.epoch,
+             tree_hash,
+             confirmed_transcript_hash: test_case.confirmed_transcript_hash.into(),
+             extensions: ExtensionList::new(),
+         };
+
+         for leaf in &test_case.leaves_private {
+             // Private tree layout: the leaf's own key first, then one optional key per node on
+             // its direct path. Every key derived from a known path secret must match the
+             // public key stored at that node of the public tree.
+             let mut tree_private = TreeKemPrivate::new(LeafIndex(leaf.index));
+             tree_private.secret_keys = vec![Some(leaf.encryption_priv.clone().into())];
+
+             for node in tree.nodes.direct_copath(tree_private.self_index) {
+                 let node_index = node.path;
+
+                 let known_secret = leaf
+                     .path_secrets
+                     .iter()
+                     .find(|entry| entry.node == node_index)
+                     .map(|entry| entry.path_secret.clone());
+
+                 let mut private_key = None;
+
+                 if let Some(secret) = known_secret {
+                     let (secret_key, public_key) = PathSecret::from(secret)
+                         .to_hpke_key_pair(&cs)
+                         .await
+                         .unwrap();
+
+                     let parent = tree.nodes.borrow_as_parent(node_index).unwrap();
+                     assert_eq!(&public_key, &parent.public_key);
+
+                     private_key = Some(secret_key);
+                 }
+
+                 tree_private.secret_keys.push(private_key);
+             }
+
+             for update_path in &test_case.update_paths {
+                 // This leaf only processes update paths sent by other members.
+                 if update_path.sender == leaf.index {
+                     continue;
+                 }
+
+                 let mut group = GroupWithoutKeySchedule::new(cs.cipher_suite()).await;
+                 group.state.context = group_context.clone();
+                 group.state.public_tree = tree.clone();
+                 group.private_tree = tree_private.clone();
+
+                 let commit = Commit {
+                     proposals: vec![],
+                     path: Some(UpdatePath::mls_decode(&mut &*update_path.update_path).unwrap()),
+                 };
+
+                 let mut auth_content = AuthenticatedContent::new(
+                     &group_context,
+                     Sender::Member(update_path.sender),
+                     Content::Commit(alloc::boxed::Box::new(commit)),
+                     vec![],
+                     WireFormat::PublicMessage,
+                 );
+
+                 auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
+
+                 // Hack not to increment epoch: the stored epoch is lowered before processing.
+                 group.state.context.epoch -= 1;
+
+                 group.process_commit(auth_content, None).await.unwrap();
+
+                 // The expected commit secret and post-merge tree hash together show that the
+                 // path secrets were computed and the update path was merged correctly.
+                 let commit_secret = group.secrets.unwrap().1;
+                 assert_eq!(&*commit_secret, &update_path.commit_secret);
+
+                 let new_tree = &mut group.provisional_public_state.unwrap().public_tree;
+                 let new_tree_hash = new_tree.tree_hash(&cs).await.unwrap();
+                 assert_eq!(&new_tree_hash, &update_path.tree_hash_after);
+             }
+         }
+     }
+ }
diff --git a/src/group/interop_test_vectors/tree_modifications.rs b/src/group/interop_test_vectors/tree_modifications.rs
new file mode 100644
index 0000000..a172e0c
--- /dev/null
+++ b/src/group/interop_test_vectors/tree_modifications.rs
@@ -0,0 +1,177 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::boxed::Box;
+use alloc::vec;
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode};
+use mls_rs_core::crypto::CipherSuite;
+
+use crate::{
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
+ group::{
+ proposal::{AddProposal, Proposal, ProposalOrRef, RemoveProposal, UpdateProposal},
+ proposal_cache::test_utils::CommitReceiver,
+ proposal_ref::ProposalRef,
+ test_utils::TEST_GROUP,
+ LeafIndex, Sender, TreeKemPublic,
+ },
+ identity::basic::BasicIdentityProvider,
+ key_package::test_utils::test_key_package,
+ tree_kem::{
+ leaf_node::test_utils::default_properties, node::NodeVec, test_utils::TreeWithSigners,
+ },
+};
+
+ /// One entry of the `tree-operations.json` interop test vector: applying `proposal`, sent by
+ /// member `proposal_sender`, to `tree_before` must produce `tree_after`.
+ #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+ struct TreeModsTestCase {
+ // Serialized `NodeVec` before the proposal is applied.
+ #[serde(with = "hex::serde")]
+ pub tree_before: Vec<u8>,
+ // Serialized `Proposal`.
+ #[serde(with = "hex::serde")]
+ pub proposal: Vec<u8>,
+ pub proposal_sender: u32,
+ // Serialized `NodeVec` expected after the proposal is applied.
+ #[serde(with = "hex::serde")]
+ pub tree_after: Vec<u8>,
+ }
+
+ impl TreeModsTestCase {
+     /// Builds a test case by applying `proposal` from `proposal_sender` to `tree_before` and
+     /// storing the serialized inputs next to the serialized resulting tree.
+     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+     #[cfg_attr(coverage_nightly, coverage(off))]
+     async fn new(tree_before: TreeKemPublic, proposal: Proposal, proposal_sender: u32) -> Self {
+         let tree_after = apply_proposal(proposal.clone(), proposal_sender, &tree_before).await;
+         let encode_nodes = |tree: &TreeKemPublic| tree.nodes.mls_encode_to_vec().unwrap();
+
+         TreeModsTestCase {
+             tree_before: encode_nodes(&tree_before),
+             proposal: proposal.mls_encode_to_vec().unwrap(),
+             proposal_sender,
+             tree_after: encode_nodes(&tree_after),
+         }
+     }
+ }
+
+ /// Generates the tree-modification test cases: one update, three adds and three removes.
+ /// `tree_modifications_interop` passes this to `load_test_case_json!`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn generate_tree_mods_tests() -> Vec<TreeModsTestCase> {
+     let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+     let mut cases = Vec::new();
+
+     // Update: leaf 6 of a full 8-member tree refreshes its leaf node.
+     let full_tree = TreeWithSigners::make_full_tree(8, &provider).await;
+     let update = generate_update(6, &full_tree).await;
+     cases.push(TreeModsTestCase::new(full_tree.tree, update, 6).await);
+
+     // Add in the middle: member 3 of a 6-member tree is removed first.
+     let mut with_gap = TreeWithSigners::make_full_tree(6, &provider).await;
+     with_gap.remove_member(3);
+     cases.push(TreeModsTestCase::new(with_gap.tree, generate_add().await, 2).await);
+
+     // Add at the end of a 6-member tree, then of an 8-member tree (the tree grows).
+     for size in [6, 8] {
+         let tree = TreeWithSigners::make_full_tree(size, &provider).await;
+         cases.push(TreeModsTestCase::new(tree.tree, generate_add().await, 2).await);
+     }
+
+     // Remove in the middle and at the end of an 8-member tree, then at the end of a
+     // 9-member tree (the tree shrinks).
+     for (size, leaf) in [(8, 2), (8, 7), (9, 8)] {
+         let tree = TreeWithSigners::make_full_tree(size, &provider).await;
+         cases.push(TreeModsTestCase::new(tree.tree, generate_remove(leaf), 2).await);
+     }
+
+     cases
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn tree_modifications_interop() {
+     // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/tree-operations.json
+
+     // All test vectors use cipher suite 1; without it there is nothing to check.
+     let Some(_) = try_test_cipher_suite_provider(*CipherSuite::CURVE25519_AES128) else {
+         return;
+     };
+
+     #[cfg(not(mls_build_async))]
+     let test_cases: Vec<TreeModsTestCase> =
+         load_test_case_json!(tree_modifications_interop, generate_tree_mods_tests());
+
+     #[cfg(mls_build_async)]
+     let test_cases: Vec<TreeModsTestCase> =
+         load_test_case_json!(tree_modifications_interop, generate_tree_mods_tests().await);
+
+     for case in test_cases {
+         let nodes = NodeVec::mls_decode(&mut &*case.tree_before).unwrap();
+
+         let tree_before =
+             TreeKemPublic::import_node_data(nodes, &BasicIdentityProvider, &Default::default())
+                 .await
+                 .unwrap();
+
+         let proposal = Proposal::mls_decode(&mut &*case.proposal).unwrap();
+         let tree_after = apply_proposal(proposal, case.proposal_sender, &tree_before).await;
+
+         assert_eq!(tree_after.nodes.mls_encode_to_vec().unwrap(), case.tree_after);
+     }
+ }
+
+ /// Caches `proposal` as sent by member `sender`, commits it by reference through a
+ /// `CommitReceiver` built over `tree_before`, and returns the resulting public tree.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn apply_proposal(
+     proposal: Proposal,
+     sender: u32,
+     tree_before: &TreeKemPublic,
+ ) -> TreeKemPublic {
+     let provider = test_cipher_suite_provider(CipherSuite::CURVE25519_AES128);
+
+     // The commit refers to the proposal by reference, so it is cached under a fake one.
+     let proposal_ref = ProposalRef::new_fake(b"fake ref".to_vec());
+
+     CommitReceiver::new(tree_before, Sender::Member(0), LeafIndex(1), provider)
+         .cache(proposal_ref.clone(), proposal, Sender::Member(sender))
+         .receive(vec![ProposalOrRef::Reference(proposal_ref)])
+         .await
+         .unwrap()
+         .public_tree
+ }
+
+ /// Builds an `Add` proposal for a fresh test key package created with the name "Roger".
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn generate_add() -> Proposal {
+     let add = AddProposal {
+         key_package: test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "Roger").await,
+     };
+
+     Proposal::Add(Box::new(add))
+ }
+
+ /// Builds a `Remove` proposal targeting leaf `i`.
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_remove(i: u32) -> Proposal {
+     Proposal::Remove(RemoveProposal {
+         to_remove: LeafIndex(i),
+     })
+ }
+
+ /// Builds an `Update` proposal for leaf `i`. A copy of the leaf's current node is updated
+ /// using that leaf's signer from `tree`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn generate_update(i: u32, tree: &TreeWithSigners) -> Proposal {
+     let leaf = LeafIndex(i);
+     let signer = tree.signers[i as usize].as_ref().unwrap();
+     let mut leaf_node = tree.tree.get_leaf_node(leaf).unwrap().clone();
+     let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+     leaf_node
+         .update(&provider, TEST_GROUP, i, default_properties(), None, signer)
+         .await
+         .unwrap();
+
+     Proposal::Update(UpdateProposal { leaf_node })
+ }
diff --git a/src/group/key_schedule.rs b/src/group/key_schedule.rs
new file mode 100644
index 0000000..77c1d65
--- /dev/null
+++ b/src/group/key_schedule.rs
@@ -0,0 +1,988 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::extension::ExternalPubExt;
+use crate::group::{GroupContext, MembershipTag};
+use crate::psk::secret::PskSecret;
+#[cfg(feature = "psk")]
+use crate::psk::PreSharedKey;
+use crate::tree_kem::path_secret::PathSecret;
+use crate::CipherSuiteProvider;
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+use crate::group::SecretTree;
+
+use alloc::vec;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+use zeroize::Zeroizing;
+
+use crate::crypto::{HpkeContextR, HpkeContextS, HpkePublicKey, HpkeSecretKey};
+
+use super::epoch::{EpochSecrets, SenderDataSecret};
+use super::message_signature::AuthenticatedContent;
+
+ /// Secrets an epoch keeps after the key schedule has run (see `from_epoch_secret`), plus the
+ /// `init_secret` that `from_key_schedule` combines with the next commit secret.
+ #[derive(Clone, PartialEq, Eq, Default, MlsEncode, MlsDecode, MlsSize)]
+ #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+ pub struct KeySchedule {
+ // Input to `export_secret`.
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ exporter_secret: Zeroizing<Vec<u8>>,
+ // Derived with the label "authentication".
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ pub authentication_secret: Zeroizing<Vec<u8>>,
+ // Seed of the external HPKE key pair (`get_external_key_pair`).
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ external_secret: Zeroizing<Vec<u8>>,
+ // Key passed to `MembershipTag::create` (`get_membership_tag`).
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ membership_key: Zeroizing<Vec<u8>>,
+ init_secret: InitSecret,
+ }
+
+ impl Debug for KeySchedule {
+     /// Renders each byte secret through `mls_rs_core::debug::pretty_bytes`; `init_secret` is
+     /// formatted with its own `Debug` impl.
+     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+         use mls_rs_core::debug::pretty_bytes;
+
+         let mut out = f.debug_struct("KeySchedule");
+         out.field("exporter_secret", &pretty_bytes(&self.exporter_secret));
+         out.field("authentication_secret", &pretty_bytes(&self.authentication_secret));
+         out.field("external_secret", &pretty_bytes(&self.external_secret));
+         out.field("membership_key", &pretty_bytes(&self.membership_key));
+         out.field("init_secret", &self.init_secret);
+         out.finish()
+     }
+ }
+
+ /// Everything produced by one run of the key schedule.
+ pub(crate) struct KeyScheduleDerivationResult {
+ pub(crate) key_schedule: KeySchedule,
+ // Derived with the label "confirm"; not stored in `KeySchedule`.
+ pub(crate) confirmation_key: Zeroizing<Vec<u8>>,
+ // Needed to build welcome messages. Empty when the derivation started from an epoch
+ // secret (`from_epoch_secret`); filled in by `from_key_schedule`.
+ pub(crate) joiner_secret: JoinerSecret,
+ pub(crate) epoch_secrets: EpochSecrets,
+ }
+
+ impl KeySchedule {
+ /// Creates a key schedule holding only `init_secret`; every other secret is left at its
+ /// `Default` (empty) value.
+ pub fn new(init_secret: InitSecret) -> Self {
+ KeySchedule {
+ init_secret,
+ ..Default::default()
+ }
+ }
+
+ /// For an external commit: recovers the init secret from `kem_output` using the external
+ /// HPKE key pair derived from this schedule's `external_secret`. Returns a key schedule
+ /// seeded with that init secret.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn derive_for_external<P: CipherSuiteProvider>(
+ &self,
+ kem_output: &[u8],
+ cipher_suite: &P,
+ ) -> Result<KeySchedule, MlsError> {
+ let (secret, public) = self.get_external_key_pair(cipher_suite).await?;
+
+ let init_secret =
+ InitSecret::decode_for_external(cipher_suite, kem_output, &secret, &public).await?;
+
+ Ok(KeySchedule::new(init_secret))
+ }
+
+ /// Returns the derived epoch as well as the joiner secret required for building welcome
+ /// messages.
+ ///
+ /// `joiner_secret = ExpandWithLabel(KDF.Extract(init_secret, commit_secret), "joiner",
+ /// context)`, where `init_secret` comes from `last_key_schedule`. All remaining secrets are
+ /// derived from the joiner secret by `from_joiner`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn from_key_schedule<P: CipherSuiteProvider>(
+ last_key_schedule: &KeySchedule,
+ commit_secret: &PathSecret,
+ context: &GroupContext,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree_size: u32,
+ psk_secret: &PskSecret,
+ cipher_suite_provider: &P,
+ ) -> Result<KeyScheduleDerivationResult, MlsError> {
+ let joiner_seed = cipher_suite_provider
+ .kdf_extract(&last_key_schedule.init_secret.0, commit_secret)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ // The joiner secret is bound to the serialized `context`.
+ let joiner_secret = kdf_expand_with_label(
+ cipher_suite_provider,
+ &joiner_seed,
+ b"joiner",
+ &context.mls_encode_to_vec()?,
+ None,
+ )
+ .await?
+ .into();
+
+ let key_schedule_result = Self::from_joiner(
+ cipher_suite_provider,
+ &joiner_secret,
+ context,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree_size,
+ psk_secret,
+ )
+ .await?;
+
+ // `from_joiner` returns an empty joiner secret, so the real one is supplied here.
+ Ok(KeyScheduleDerivationResult {
+ key_schedule: key_schedule_result.key_schedule,
+ confirmation_key: key_schedule_result.confirmation_key,
+ joiner_secret,
+ epoch_secrets: key_schedule_result.epoch_secrets,
+ })
+ }
+
+ /// Derives `epoch_secret = ExpandWithLabel(KDF.Extract(joiner_secret, psk_secret), "epoch",
+ /// context)` and then every per-epoch secret from it (see `from_epoch_secret`). The
+ /// returned `joiner_secret` is empty.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn from_joiner<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ joiner_secret: &JoinerSecret,
+ context: &GroupContext,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree_size: u32,
+ psk_secret: &PskSecret,
+ ) -> Result<KeyScheduleDerivationResult, MlsError> {
+ let epoch_seed =
+ get_pre_epoch_secret(cipher_suite_provider, psk_secret, joiner_secret).await?;
+ let context = context.mls_encode_to_vec()?;
+
+ let epoch_secret =
+ kdf_expand_with_label(cipher_suite_provider, &epoch_seed, b"epoch", &context, None)
+ .await?;
+
+ Self::from_epoch_secret(
+ cipher_suite_provider,
+ &epoch_secret,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree_size,
+ )
+ .await
+ }
+
+ /// Derives a key schedule from a freshly generated random epoch secret of
+ /// `kdf_extract_size()` bytes. Nothing is inherited from a previous epoch.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn from_random_epoch_secret<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree_size: u32,
+ ) -> Result<KeyScheduleDerivationResult, MlsError> {
+ let epoch_secret = cipher_suite_provider
+ .random_bytes_vec(cipher_suite_provider.kdf_extract_size())
+ .map(Zeroizing::new)
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ Self::from_epoch_secret(
+ cipher_suite_provider,
+ &epoch_secret,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree_size,
+ )
+ .await
+ }
+
+ /// Derives every per-epoch secret as `DeriveSecret(epoch_secret, label)`, using the labels
+ /// of the RFC 9420 key schedule. `secret_tree_size` (when compiled in) is passed to
+ /// `SecretTree::new`. The returned `joiner_secret` is empty.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn from_epoch_secret<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ epoch_secret: &[u8],
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree_size: u32,
+ ) -> Result<KeyScheduleDerivationResult, MlsError> {
+ let secrets_producer = SecretsProducer::new(cipher_suite_provider, epoch_secret);
+
+ let epoch_secrets = EpochSecrets {
+ #[cfg(feature = "psk")]
+ resumption_secret: PreSharedKey::from(secrets_producer.derive(b"resumption").await?),
+ sender_data_secret: SenderDataSecret::from(
+ secrets_producer.derive(b"sender data").await?,
+ ),
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree: SecretTree::new(
+ secret_tree_size,
+ secrets_producer.derive(b"encryption").await?,
+ ),
+ };
+
+ let key_schedule = Self {
+ exporter_secret: secrets_producer.derive(b"exporter").await?,
+ authentication_secret: secrets_producer.derive(b"authentication").await?,
+ external_secret: secrets_producer.derive(b"external").await?,
+ membership_key: secrets_producer.derive(b"membership").await?,
+ init_secret: InitSecret(secrets_producer.derive(b"init").await?),
+ };
+
+ Ok(KeyScheduleDerivationResult {
+ key_schedule,
+ confirmation_key: secrets_producer.derive(b"confirm").await?,
+ // Not derivable from an epoch secret; `from_key_schedule` supplies it when needed.
+ joiner_secret: Zeroizing::new(vec![]).into(),
+ epoch_secrets,
+ })
+ }
+
+ /// MLS exporter: `ExpandWithLabel(DeriveSecret(exporter_secret, label), "exported",
+ /// Hash(context), len)`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn export_secret<P: CipherSuiteProvider>(
+ &self,
+ label: &[u8],
+ context: &[u8],
+ len: usize,
+ cipher_suite: &P,
+ ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+ let secret = kdf_derive_secret(cipher_suite, &self.exporter_secret, label).await?;
+
+ let context_hash = cipher_suite
+ .hash(context)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ kdf_expand_with_label(cipher_suite, &secret, b"exported", &context_hash, Some(len)).await
+ }
+
+ /// Computes the membership tag for `content` in `context`, keyed by this epoch's
+ /// `membership_key`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_membership_tag<P: CipherSuiteProvider>(
+ &self,
+ content: &AuthenticatedContent,
+ context: &GroupContext,
+ cipher_suite_provider: &P,
+ ) -> Result<MembershipTag, MlsError> {
+ MembershipTag::create(
+ content,
+ context,
+ &self.membership_key,
+ cipher_suite_provider,
+ )
+ .await
+ }
+
+ /// Derives the HPKE key pair for external commits from `external_secret` via the
+ /// provider's `kem_derive`. Returns `(secret, public)`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_external_key_pair<P: CipherSuiteProvider>(
+ &self,
+ cipher_suite: &P,
+ ) -> Result<(HpkeSecretKey, HpkePublicKey), MlsError> {
+ cipher_suite
+ .kem_derive(&self.external_secret)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+
+ /// Returns the public half of the external key pair wrapped in an `ExternalPubExt`. The
+ /// secret half is dropped.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_external_key_pair_ext<P: CipherSuiteProvider>(
+ &self,
+ cipher_suite: &P,
+ ) -> Result<ExternalPubExt, MlsError> {
+ let (_external_secret, external_pub) = self.get_external_key_pair(cipher_suite).await?;
+
+ Ok(ExternalPubExt { external_pub })
+ }
+ }
+
+ /// The `KDFLabel` structure that `kdf_expand_with_label` encodes and passes to `kdf_expand`.
+ /// It holds the output length, the "MLS 1.0 "-prefixed label and the context.
+ #[derive(MlsEncode, MlsSize)]
+ struct Label<'a> {
+ length: u16,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ label: Vec<u8>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ context: &'a [u8],
+ }
+
+ impl<'a> Label<'a> {
+     /// Builds a label whose text is `label` prefixed with "MLS 1.0 ".
+     fn new(length: u16, label: &'a [u8], context: &'a [u8]) -> Self {
+         let mut prefixed = b"MLS 1.0 ".to_vec();
+         prefixed.extend_from_slice(label);
+
+         Label {
+             length,
+             label: prefixed,
+             context,
+         }
+     }
+ }
+
+ /// `ExpandWithLabel(secret, label, context, len)`: expands `secret` using the encoded `Label`
+ /// as info. When `len` is `None`, the KDF extract size is used.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn kdf_expand_with_label<P: CipherSuiteProvider>(
+     cipher_suite_provider: &P,
+     secret: &[u8],
+     label: &[u8],
+     context: &[u8],
+     len: Option<usize>,
+ ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+     let output_len = len.unwrap_or(cipher_suite_provider.kdf_extract_size());
+
+     // NOTE(review): `as u16` silently truncates lengths above `u16::MAX`, so the encoded label
+     // would disagree with `output_len`. `export_secret` forwards a caller-chosen length here,
+     // so consider a checked conversion.
+     let info = Label::new(output_len as u16, label, context).mls_encode_to_vec()?;
+
+     cipher_suite_provider
+         .kdf_expand(secret, &info, output_len)
+         .await
+         .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+
+ /// `DeriveSecret(secret, label)`: `ExpandWithLabel` with an empty context and the default
+ /// (KDF extract size) output length.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn kdf_derive_secret<P: CipherSuiteProvider>(
+     cipher_suite_provider: &P,
+     secret: &[u8],
+     label: &[u8],
+ ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+     let empty_context: &[u8] = &[];
+     kdf_expand_with_label(cipher_suite_provider, secret, label, empty_context, None).await
+ }
+
+ /// Secret from which the pre-epoch secret is extracted (`get_pre_epoch_secret`), which in turn
+ /// yields both the epoch secret and the welcome secret.
+ #[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+ pub(crate) struct JoinerSecret(#[mls_codec(with = "mls_rs_codec::byte_vec")] Zeroizing<Vec<u8>>);
+
+ impl Debug for JoinerSecret {
+     /// Prints the secret through `pretty_bytes`, labelled "JoinerSecret".
+     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+         Debug::fmt(&mls_rs_core::debug::pretty_bytes(&self.0).named("JoinerSecret"), f)
+     }
+ }
+
+ impl From<Zeroizing<Vec<u8>>> for JoinerSecret {
+     /// Takes ownership of already-zeroizing bytes.
+     fn from(bytes: Zeroizing<Vec<u8>>) -> Self {
+         JoinerSecret(bytes)
+     }
+ }
+
+ /// Computes the pre-epoch secret `KDF.Extract(joiner_secret, psk_secret)`. Both the epoch
+ /// secret (`KeySchedule::from_joiner`) and the welcome secret (`get_welcome_secret`) are
+ /// derived from it.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn get_pre_epoch_secret<P: CipherSuiteProvider>(
+     cipher_suite_provider: &P,
+     psk_secret: &PskSecret,
+     joiner_secret: &JoinerSecret,
+ ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+     let extracted = cipher_suite_provider
+         .kdf_extract(&joiner_secret.0, psk_secret)
+         .await;
+
+     extracted.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+
+ /// Derives labelled secrets (`DeriveSecret(epoch_secret, label)`) from one epoch secret.
+ struct SecretsProducer<'a, P: CipherSuiteProvider> {
+ cipher_suite_provider: &'a P,
+ epoch_secret: &'a [u8],
+ }
+
+ impl<'a, P: CipherSuiteProvider> SecretsProducer<'a, P> {
+     fn new(cipher_suite_provider: &'a P, epoch_secret: &'a [u8]) -> Self {
+         SecretsProducer {
+             cipher_suite_provider,
+             epoch_secret,
+         }
+     }
+
+     // TODO: document in the crypto provider that the RFC sizes every secret as the KDF
+     // extract size but then uses some secrets as MAC keys etc., so the crypto provider is
+     // required to keep these lengths equal.
+     /// Returns `DeriveSecret(epoch_secret, label)`.
+     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+     async fn derive(&self, label: &[u8]) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+         kdf_derive_secret(self.cipher_suite_provider, self.epoch_secret, label).await
+     }
+ }
+
+ /// HPKE exporter context used to obtain the init secret of an external commit.
+ const EXPORTER_CONTEXT: &[u8] = b"MLS 1.0 external init secret";
+
+ /// Secret that seeds the next epoch's key schedule (`KeySchedule::from_key_schedule`). For
+ /// external commits it comes from the HPKE exporter (`encode_for_external` /
+ /// `decode_for_external`).
+ #[derive(Clone, Eq, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
+ #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+ pub struct InitSecret(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ Zeroizing<Vec<u8>>,
+ );
+
+ impl Debug for InitSecret {
+     /// Prints the secret through `pretty_bytes`, labelled "InitSecret".
+     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+         Debug::fmt(&mls_rs_core::debug::pretty_bytes(&self.0).named("InitSecret"), f)
+     }
+ }
+
+ impl InitSecret {
+     /// Returns init secret and KEM output to be used when creating an external commit.
+     ///
+     /// An HPKE sender context is set up against `external_pub`. The init secret is exported
+     /// from it under `EXPORTER_CONTEXT`, with the KDF extract size as length.
+     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+     pub async fn encode_for_external<P: CipherSuiteProvider>(
+         cipher_suite: &P,
+         external_pub: &HpkePublicKey,
+     ) -> Result<(Self, Vec<u8>), MlsError> {
+         let (kem_output, sender_context) = cipher_suite
+             .hpke_setup_s(external_pub, &[])
+             .await
+             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+         let init_secret = sender_context
+             .export(EXPORTER_CONTEXT, cipher_suite.kdf_extract_size())
+             .await
+             .map(Zeroizing::new)
+             .map(InitSecret)
+             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+         Ok((init_secret, kem_output))
+     }
+
+     /// Recovers the init secret of an external commit from `kem_output`. It opens an HPKE
+     /// receiver context with the external key pair and exports under `EXPORTER_CONTEXT`.
+     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+     pub async fn decode_for_external<P: CipherSuiteProvider>(
+         cipher_suite: &P,
+         kem_output: &[u8],
+         external_secret: &HpkeSecretKey,
+         external_pub: &HpkePublicKey,
+     ) -> Result<Self, MlsError> {
+         let receiver_context = cipher_suite
+             .hpke_setup_r(kem_output, external_secret, external_pub, &[])
+             .await
+             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+         let exported = receiver_context
+             .export(EXPORTER_CONTEXT, cipher_suite.kdf_extract_size())
+             .await
+             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+         Ok(InitSecret(Zeroizing::new(exported)))
+     }
+ }
+
+ /// AEAD key and nonce derived from the welcome secret (see `from_joiner_secret`). They are
+ /// bundled with the cipher suite that `encrypt` / `decrypt` use.
+ pub(crate) struct WelcomeSecret<'a, P: CipherSuiteProvider> {
+ cipher_suite: &'a P,
+ key: Zeroizing<Vec<u8>>,
+ nonce: Zeroizing<Vec<u8>>,
+ }
+
+impl<'a, P: CipherSuiteProvider> WelcomeSecret<'a, P> {
+    /// Derives the welcome AEAD key and nonce.
+    ///
+    /// The welcome secret comes from `get_welcome_secret(joiner_secret, psk_secret)`.
+    /// The key and nonce are then expanded from it with the labels `"key"` and
+    /// `"nonce"`, an empty context, and the cipher suite's AEAD key and nonce sizes.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn from_joiner_secret(
+        cipher_suite: &'a P,
+        joiner_secret: &JoinerSecret,
+        psk_secret: &PskSecret,
+    ) -> Result<WelcomeSecret<'a, P>, MlsError> {
+        let secret = get_welcome_secret(cipher_suite, joiner_secret, psk_secret).await?;
+
+        let key = kdf_expand_with_label(
+            cipher_suite,
+            &secret,
+            b"key",
+            &[],
+            Some(cipher_suite.aead_key_size()),
+        )
+        .await?;
+
+        let nonce = kdf_expand_with_label(
+            cipher_suite,
+            &secret,
+            b"nonce",
+            &[],
+            Some(cipher_suite.aead_nonce_size()),
+        )
+        .await?;
+
+        Ok(WelcomeSecret {
+            key,
+            nonce,
+            cipher_suite,
+        })
+    }
+
+    /// Seals `plaintext` with the welcome key and nonce, with no additional data.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, MlsError> {
+        let sealed = self
+            .cipher_suite
+            .aead_seal(&self.key, plaintext, None, &self.nonce)
+            .await;
+
+        sealed.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+    }
+
+    /// Opens `ciphertext` with the welcome key and nonce, with no additional data.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn decrypt(&self, ciphertext: &[u8]) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+        let opened = self
+            .cipher_suite
+            .aead_open(&self.key, ciphertext, None, &self.nonce)
+            .await;
+
+        opened.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+    }
+}
+
+/// Computes the welcome secret.
+///
+/// The pre-epoch secret is derived from `psk_secret` and `joiner_secret`, and then
+/// `DeriveSecret(pre_epoch, "welcome")` is applied to it.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn get_welcome_secret<P: CipherSuiteProvider>(
+    cipher_suite: &P,
+    joiner_secret: &JoinerSecret,
+    psk_secret: &PskSecret,
+) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+    let pre_epoch = get_pre_epoch_secret(cipher_suite, psk_secret, joiner_secret).await?;
+    let welcome_secret = kdf_derive_secret(cipher_suite, &pre_epoch, b"welcome").await?;
+
+    Ok(welcome_secret)
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use alloc::vec;
+    use alloc::vec::Vec;
+    use mls_rs_core::crypto::CipherSuiteProvider;
+    use zeroize::Zeroizing;
+
+    use crate::{cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider};
+
+    use super::{InitSecret, JoinerSecret, KeySchedule};
+
+    #[cfg(all(feature = "rfc_compliant", not(mls_build_async)))]
+    use mls_rs_core::error::IntoAnyError;
+
+    #[cfg(all(feature = "rfc_compliant", not(mls_build_async)))]
+    use super::MlsError;
+
+    /// Test-only conversion exposing the raw joiner secret bytes.
+    impl From<JoinerSecret> for Vec<u8> {
+        fn from(mut secret: JoinerSecret) -> Self {
+            // Leaves an empty buffer behind in `secret`.
+            core::mem::take(&mut secret.0)
+        }
+    }
+
+    /// Builds a `KeySchedule` for `cipher_suite` in which every secret is `kdf_extract_size()`
+    /// bytes of `1`, except the init secret, which is the same length of `0` bytes.
+    pub(crate) fn get_test_key_schedule(cipher_suite: CipherSuite) -> KeySchedule {
+        let secret_len = test_cipher_suite_provider(cipher_suite).kdf_extract_size();
+        let ones = Zeroizing::new(vec![1u8; secret_len]);
+
+        KeySchedule {
+            exporter_secret: ones.clone(),
+            authentication_secret: ones.clone(),
+            external_secret: ones.clone(),
+            membership_key: ones,
+            init_secret: InitSecret::new(vec![0u8; secret_len]),
+        }
+    }
+
+    impl InitSecret {
+        /// Wraps raw bytes as an init secret.
+        pub fn new(init_secret: Vec<u8>) -> Self {
+            Self(Zeroizing::new(init_secret))
+        }
+
+        /// Draws `kdf_extract_size()` random bytes from `cipher_suite` as an init secret.
+        #[cfg(all(feature = "rfc_compliant", test, not(mls_build_async)))]
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        pub fn random<P: CipherSuiteProvider>(cipher_suite: &P) -> Result<Self, MlsError> {
+            let len = cipher_suite.kdf_extract_size();
+
+            let bytes = cipher_suite
+                .random_bytes_vec(len)
+                .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+            Ok(Self::new(bytes))
+        }
+    }
+
+    #[cfg(feature = "rfc_compliant")]
+    impl KeySchedule {
+        /// Replaces the membership key with `key`.
+        pub fn set_membership_key(&mut self, key: Vec<u8>) {
+            self.membership_key = Zeroizing::new(key);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::client::test_utils::TEST_PROTOCOL_VERSION;
+ use crate::crypto::test_utils::try_test_cipher_suite_provider;
+ use crate::group::key_schedule::{
+ get_welcome_secret, kdf_derive_secret, kdf_expand_with_label,
+ };
+ use crate::group::GroupContext;
+ use alloc::string::String;
+ use alloc::vec::Vec;
+ use mls_rs_codec::MlsEncode;
+ use mls_rs_core::crypto::CipherSuiteProvider;
+ use mls_rs_core::extension::ExtensionList;
+
+ #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
+ use crate::{
+ crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
+ group::{
+ key_schedule::KeyScheduleDerivationResult, test_utils::random_bytes, InitSecret,
+ PskSecret,
+ },
+ };
+
+ #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
+ use alloc::{string::ToString, vec};
+
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+ use zeroize::Zeroizing;
+
+ use super::test_utils::get_test_key_schedule;
+ use super::KeySchedule;
+
+ /// One key schedule test vector: a cipher suite, a group id, the starting init secret
+ /// and the epochs derived from it. Byte fields are hex-encoded in JSON.
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct TestCase {
+ /// Numeric cipher suite identifier.
+ cipher_suite: u16,
+ #[serde(with = "hex::serde")]
+ group_id: Vec<u8>,
+ /// Init secret installed in the key schedule before the first epoch.
+ #[serde(with = "hex::serde")]
+ initial_init_secret: Vec<u8>,
+ /// Consecutive epochs; an epoch's number is its position in this list.
+ epochs: Vec<KeyScheduleEpoch>,
+ }
+
+ /// Inputs and expected outputs for a single epoch of the key schedule.
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct KeyScheduleEpoch {
+ // Inputs to the derivation.
+ #[serde(with = "hex::serde")]
+ commit_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ psk_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ confirmed_transcript_hash: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ tree_hash: Vec<u8>,
+
+ /// Expected MLS encoding of the group context built from the fields above.
+ #[serde(with = "hex::serde")]
+ group_context: Vec<u8>,
+
+ // Expected derived secrets.
+ #[serde(with = "hex::serde")]
+ joiner_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ welcome_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ init_secret: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ sender_data_secret: Vec<u8>,
+ /// Expected secret tree root (only present with secret tree support).
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ #[serde(with = "hex::serde")]
+ encryption_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ exporter_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ epoch_authenticator: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ external_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ confirmation_key: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ membership_key: Vec<u8>,
+ #[cfg(feature = "psk")]
+ #[serde(with = "hex::serde")]
+ resumption_psk: Vec<u8>,
+
+ /// Expected public key of the external key pair derived from the key schedule.
+ #[serde(with = "hex::serde")]
+ external_pub: Vec<u8>,
+
+ /// Exporter input and expected output for this epoch.
+ exporter: KeyScheduleExporter,
+ }
+
+ /// Arguments to `KeySchedule::export_secret` and the expected result.
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct KeyScheduleExporter {
+ label: String,
+ #[serde(with = "hex::serde")]
+ context: Vec<u8>,
+ length: usize,
+ #[serde(with = "hex::serde")]
+ secret: Vec<u8>,
+ }
+
+    /// Replays the `key_schedule_test_vector` cases. `generate_test_vector` is the
+    /// generator handed to `load_test_case_json!`. Each epoch's derived secrets are
+    /// checked against the recorded values.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_key_schedule() {
+        let test_cases: Vec<TestCase> =
+            load_test_case_json!(key_schedule_test_vector, generate_test_vector());
+
+        for test_case in test_cases {
+            let Some(cs_provider) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
+                // Cipher suite not available in this build.
+                continue;
+            };
+
+            let mut key_schedule = get_test_key_schedule(cs_provider.cipher_suite());
+            key_schedule.init_secret.0 = Zeroizing::new(test_case.initial_init_secret);
+
+            for (epoch_id, epoch) in (0u64..).zip(test_case.epochs) {
+                let context = GroupContext {
+                    protocol_version: TEST_PROTOCOL_VERSION,
+                    cipher_suite: cs_provider.cipher_suite(),
+                    group_id: test_case.group_id.clone(),
+                    epoch: epoch_id,
+                    tree_hash: epoch.tree_hash,
+                    confirmed_transcript_hash: epoch.confirmed_transcript_hash.into(),
+                    extensions: ExtensionList::new(),
+                };
+
+                assert_eq!(context.mls_encode_to_vec().unwrap(), epoch.group_context);
+
+                let psk = epoch.psk_secret.into();
+                let commit = epoch.commit_secret.into();
+
+                let derived = KeySchedule::from_key_schedule(
+                    &key_schedule,
+                    &commit,
+                    &context,
+                    #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+                    32,
+                    &psk,
+                    &cs_provider,
+                )
+                .await
+                .unwrap();
+
+                // The next epoch is derived from this one.
+                key_schedule = derived.key_schedule;
+
+                let welcome = get_welcome_secret(&cs_provider, &derived.joiner_secret, &psk)
+                    .await
+                    .unwrap();
+                assert_eq!(*welcome, epoch.welcome_secret);
+
+                let joiner: Vec<u8> = derived.joiner_secret.into();
+                assert_eq!(epoch.joiner_secret, joiner);
+
+                assert_eq!(key_schedule.init_secret.0.to_vec(), epoch.init_secret);
+
+                let secrets = &derived.epoch_secrets;
+                assert_eq!(epoch.sender_data_secret, secrets.sender_data_secret.to_vec());
+
+                #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+                assert_eq!(
+                    epoch.encryption_secret,
+                    *secrets.secret_tree.get_root_secret()
+                );
+
+                assert_eq!(epoch.exporter_secret, key_schedule.exporter_secret.to_vec());
+                assert_eq!(
+                    epoch.epoch_authenticator,
+                    key_schedule.authentication_secret.to_vec()
+                );
+                assert_eq!(epoch.external_secret, key_schedule.external_secret.to_vec());
+                assert_eq!(epoch.confirmation_key, derived.confirmation_key.to_vec());
+                assert_eq!(epoch.membership_key, key_schedule.membership_key.to_vec());
+
+                #[cfg(feature = "psk")]
+                {
+                    let resumption: Vec<u8> = secrets.resumption_secret.to_vec();
+                    assert_eq!(epoch.resumption_psk, resumption);
+                }
+
+                let (_external_sec, external_pub) = key_schedule
+                    .get_external_key_pair(&cs_provider)
+                    .await
+                    .unwrap();
+                assert_eq!(epoch.external_pub, *external_pub);
+
+                let exporter = epoch.exporter;
+                let exported = key_schedule
+                    .export_secret(
+                        exporter.label.as_bytes(),
+                        &exporter.context,
+                        exporter.length,
+                        &cs_provider,
+                    )
+                    .await
+                    .unwrap();
+                assert_eq!(exported.to_vec(), exporter.secret);
+            }
+        }
+    }
+
+    /// Produces a fresh key schedule test vector. For every supported cipher suite it
+    /// derives two consecutive epochs from random init, commit and PSK secrets and
+    /// random group context hashes.
+    #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn generate_test_vector() -> Vec<TestCase> {
+        TestCryptoProvider::all_supported_cipher_suites()
+            .into_iter()
+            .map(|cipher_suite| {
+                let cs_provider = test_cipher_suite_provider(cipher_suite);
+                let key_size = cs_provider.kdf_extract_size();
+
+                let mut group_context = GroupContext {
+                    protocol_version: TEST_PROTOCOL_VERSION,
+                    cipher_suite: cs_provider.cipher_suite(),
+                    group_id: b"my group 5".to_vec(),
+                    epoch: 0,
+                    tree_hash: random_bytes(key_size),
+                    confirmed_transcript_hash: random_bytes(key_size).into(),
+                    extensions: Default::default(),
+                };
+
+                let initial_init_secret = InitSecret::random(&cs_provider).unwrap();
+                let mut key_schedule = get_test_key_schedule(cs_provider.cipher_suite());
+                key_schedule.init_secret = initial_init_secret.clone();
+
+                let mut epochs = vec![];
+
+                for epoch_number in 0..2 {
+                    if epoch_number > 0 {
+                        // Move to the next epoch with fresh transcript and tree hashes.
+                        group_context.epoch += 1;
+                        group_context.confirmed_transcript_hash = random_bytes(key_size).into();
+                        group_context.tree_hash = random_bytes(key_size);
+                    }
+
+                    let commit_secret = random_bytes(key_size).into();
+                    let psk_secret = PskSecret::new(&cs_provider);
+
+                    let key_schedule_res = KeySchedule::from_key_schedule(
+                        &key_schedule,
+                        &commit_secret,
+                        &group_context,
+                        #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+                        32,
+                        &psk_secret,
+                        &cs_provider,
+                    )
+                    .unwrap();
+
+                    // The following epoch builds on the key schedule derived here.
+                    key_schedule = key_schedule_res.key_schedule.clone();
+
+                    epochs.push(KeyScheduleEpoch::new(
+                        key_schedule_res,
+                        psk_secret,
+                        commit_secret.to_vec(),
+                        &group_context,
+                        &cs_provider,
+                    ));
+                }
+
+                TestCase {
+                    cipher_suite: cs_provider.cipher_suite().into(),
+                    group_id: group_context.group_id.clone(),
+                    initial_init_secret: initial_init_secret.0.to_vec(),
+                    epochs,
+                }
+            })
+            .collect()
+    }
+
+    /// Fallback used when the key schedule vector cannot be generated. Generation needs
+    /// the synchronous API and the `rfc_compliant`-gated test helpers (e.g.
+    /// `InitSecret::random`), so this variant is compiled in async mode *or* when
+    /// `rfc_compliant` is disabled.
+    #[cfg(not(all(not(mls_build_async), feature = "rfc_compliant")))]
+    fn generate_test_vector() -> Vec<TestCase> {
+        panic!("Tests cannot be generated in async mode or without the rfc_compliant feature");
+    }
+
+    #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
+    impl KeyScheduleEpoch {
+        /// Records every value derived for one epoch so it can be serialized into the
+        /// test vector. This includes an exporter output for a fixed label and context,
+        /// and the external public key.
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        fn new<P: CipherSuiteProvider>(
+            key_schedule_res: KeyScheduleDerivationResult,
+            psk_secret: PskSecret,
+            commit_secret: Vec<u8>,
+            group_context: &GroupContext,
+            cs: &P,
+        ) -> Self {
+            let schedule = &key_schedule_res.key_schedule;
+            let secrets = &key_schedule_res.epoch_secrets;
+
+            let (_external_sec, external_pub) = schedule.get_external_key_pair(cs).unwrap();
+
+            let label = "exporter label 15".to_string();
+            let context = b"exporter context".to_vec();
+            let length = 64;
+
+            let secret = schedule
+                .export_secret(label.as_bytes(), &context, length, cs)
+                .unwrap()
+                .to_vec();
+
+            let exporter = KeyScheduleExporter {
+                label,
+                context,
+                length,
+                secret,
+            };
+
+            let welcome_secret =
+                get_welcome_secret(cs, &key_schedule_res.joiner_secret, &psk_secret)
+                    .unwrap()
+                    .to_vec();
+
+            KeyScheduleEpoch {
+                commit_secret,
+                welcome_secret,
+                psk_secret: psk_secret.to_vec(),
+                group_context: group_context.mls_encode_to_vec().unwrap(),
+                init_secret: schedule.init_secret.0.to_vec(),
+                sender_data_secret: secrets.sender_data_secret.to_vec(),
+                #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+                encryption_secret: secrets.secret_tree.get_root_secret(),
+                exporter_secret: schedule.exporter_secret.to_vec(),
+                epoch_authenticator: schedule.authentication_secret.to_vec(),
+                external_secret: schedule.external_secret.to_vec(),
+                confirmation_key: key_schedule_res.confirmation_key.to_vec(),
+                membership_key: schedule.membership_key.to_vec(),
+                #[cfg(feature = "psk")]
+                resumption_psk: secrets.resumption_secret.to_vec(),
+                external_pub: external_pub.to_vec(),
+                exporter,
+                confirmed_transcript_hash: group_context.confirmed_transcript_hash.to_vec(),
+                tree_hash: group_context.tree_hash.clone(),
+                // Moved last, after every borrow of `key_schedule_res` above is done.
+                joiner_secret: key_schedule_res.joiner_secret.into(),
+            }
+        }
+    }
+
+ /// `kdf_expand_with_label(secret, label, context, length)` input and expected `out`.
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ struct ExpandWithLabelTestCase {
+ #[serde(with = "hex::serde")]
+ secret: Vec<u8>,
+ label: String,
+ #[serde(with = "hex::serde")]
+ context: Vec<u8>,
+ length: usize,
+ #[serde(with = "hex::serde")]
+ out: Vec<u8>,
+ }
+
+ /// `kdf_derive_secret(secret, label)` input and expected `out`.
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ struct DeriveSecretTestCase {
+ #[serde(with = "hex::serde")]
+ secret: Vec<u8>,
+ label: String,
+ #[serde(with = "hex::serde")]
+ out: Vec<u8>,
+ }
+
+ /// One entry of the `basic_crypto` interop vector, for a single cipher suite.
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ pub struct InteropTestCase {
+ cipher_suite: u16,
+ expand_with_label: ExpandWithLabelTestCase,
+ derive_secret: DeriveSecretTestCase,
+ }
+
+    /// Checks `kdf_expand_with_label` and `kdf_derive_secret` against the interop
+    /// `basic_crypto` vectors, skipping cipher suites unavailable in this build.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_basic_crypto_test_vectors() {
+        // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/crypto-basics.json
+        let test_cases: Vec<InteropTestCase> =
+            load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());
+
+        for test_case in test_cases {
+            let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
+                continue;
+            };
+
+            let ExpandWithLabelTestCase {
+                secret,
+                label,
+                context,
+                length,
+                out,
+            } = &test_case.expand_with_label;
+
+            let expanded =
+                kdf_expand_with_label(&cs, secret, label.as_bytes(), context, Some(*length))
+                    .await
+                    .unwrap();
+
+            assert_eq!(&expanded.to_vec(), out);
+
+            let DeriveSecretTestCase { secret, label, out } = &test_case.derive_secret;
+
+            let derived = kdf_derive_secret(&cs, secret, label.as_bytes())
+                .await
+                .unwrap();
+
+            assert_eq!(&derived.to_vec(), out);
+        }
+    }
+}
diff --git a/src/group/membership_tag.rs b/src/group/membership_tag.rs
new file mode 100644
index 0000000..b28edea
--- /dev/null
+++ b/src/group/membership_tag.rs
@@ -0,0 +1,163 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::crypto::CipherSuiteProvider;
+use crate::group::message_signature::{AuthenticatedContentTBS, FramedContentAuthData};
+use crate::group::GroupContext;
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+
+use super::message_signature::AuthenticatedContent;
+
+/// Data covered by the membership tag: the to-be-signed content plus its auth data.
+/// It is MLS-encoded and MACed in `MembershipTag::create`.
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)]
+struct AuthenticatedContentTBM<'a> {
+ content_tbs: AuthenticatedContentTBS<'a>,
+ auth: &'a FramedContentAuthData,
+}
+
+impl<'a> AuthenticatedContentTBM<'a> {
+    /// Builds the to-be-MACed view of `auth_content`. The TBS part always includes
+    /// `group_context` and uses the context's protocol version.
+    pub fn from_authenticated_content(
+        auth_content: &'a AuthenticatedContent,
+        group_context: &'a GroupContext,
+    ) -> AuthenticatedContentTBM<'a> {
+        let protocol_version = group_context.protocol_version;
+
+        let content_tbs = AuthenticatedContentTBS::from_authenticated_content(
+            auth_content,
+            Some(group_context),
+            protocol_version,
+        );
+
+        Self {
+            content_tbs,
+            auth: &auth_content.auth,
+        }
+    }
+}
+
+/// MAC over the MLS-encoded `AuthenticatedContentTBM`, keyed with the group's
+/// membership key (see `MembershipTag::create`). Encoded with `mls_rs_codec::byte_vec`.
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub struct MembershipTag(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>);
+
+/// Renders the tag bytes through `pretty_bytes`, labelled `MembershipTag`.
+impl Debug for MembershipTag {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let rendered = mls_rs_core::debug::pretty_bytes(&self.0).named("MembershipTag");
+        rendered.fmt(f)
+    }
+}
+
+/// Exposes the raw tag bytes.
+impl Deref for MembershipTag {
+    type Target = Vec<u8>;
+
+    fn deref(&self) -> &Self::Target {
+        let Self(bytes) = self;
+        bytes
+    }
+}
+
+/// Wraps already-computed tag bytes as-is (no validation).
+impl From<Vec<u8>> for MembershipTag {
+    fn from(bytes: Vec<u8>) -> Self {
+        Self(bytes)
+    }
+}
+
+impl MembershipTag {
+    /// Computes the membership tag for `authenticated_content`.
+    ///
+    /// The content and `group_context` are assembled into an `AuthenticatedContentTBM`,
+    /// MLS-encoded, and MACed with `membership_key` by the cipher suite provider.
+    ///
+    /// # Errors
+    ///
+    /// Encoding failures are propagated (converted into `MlsError`). MAC failures are
+    /// reported as `MlsError::CryptoProviderError`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn create<P: CipherSuiteProvider>(
+        authenticated_content: &AuthenticatedContent,
+        group_context: &GroupContext,
+        membership_key: &[u8],
+        cipher_suite_provider: &P,
+    ) -> Result<Self, MlsError> {
+        let tbm_bytes =
+            AuthenticatedContentTBM::from_authenticated_content(authenticated_content, group_context)
+                .mls_encode_to_vec()?;
+
+        cipher_suite_provider
+            .mac(membership_key, &tbm_bytes)
+            .await
+            .map(MembershipTag)
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::crypto::test_utils::try_test_cipher_suite_provider;
+    use crate::group::{
+        framing::test_utils::get_test_auth_content, test_utils::get_test_group_context,
+    };
+
+    // Only the (sync-only) vector generator needs these.
+    #[cfg(not(mls_build_async))]
+    use crate::crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider};
+
+    #[cfg(target_arch = "wasm32")]
+    use wasm_bindgen_test::wasm_bindgen_test as test;
+
+    /// Expected membership tag for one cipher suite.
+    #[derive(Debug, serde::Serialize, serde::Deserialize)]
+    struct TestCase {
+        cipher_suite: u16,
+        #[serde(with = "hex::serde")]
+        tag: Vec<u8>,
+    }
+
+    /// Computes the tag for every supported cipher suite over `get_test_auth_content()`,
+    /// `get_test_group_context(1, ..)` and the key `b"membership_key"`.
+    #[cfg(not(mls_build_async))]
+    #[cfg_attr(coverage_nightly, coverage(off))]
+    fn generate_test_cases() -> Vec<TestCase> {
+        let mut test_cases = Vec::new();
+
+        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+            let tag = MembershipTag::create(
+                &get_test_auth_content(),
+                &get_test_group_context(1, cipher_suite),
+                b"membership_key".as_ref(),
+                &test_cipher_suite_provider(cipher_suite),
+            )
+            .unwrap();
+
+            test_cases.push(TestCase {
+                cipher_suite: cipher_suite.into(),
+                tag: tag.to_vec(),
+            });
+        }
+
+        test_cases
+    }
+
+    #[cfg(mls_build_async)]
+    fn generate_test_cases() -> Vec<TestCase> {
+        panic!("Tests cannot be generated in async mode");
+    }
+
+    fn load_test_cases() -> Vec<TestCase> {
+        load_test_case_json!(membership_tag, generate_test_cases())
+    }
+
+    /// Recomputes each recorded tag and compares it with the stored value.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_membership_tag() {
+        for case in load_test_cases() {
+            let Some(cs_provider) = try_test_cipher_suite_provider(case.cipher_suite) else {
+                continue;
+            };
+
+            let tag = MembershipTag::create(
+                &get_test_auth_content(),
+                &get_test_group_context(1, cs_provider.cipher_suite()).await,
+                b"membership_key".as_ref(),
+                // Reuse the provider already obtained instead of building a second one.
+                &cs_provider,
+            )
+            .await
+            .unwrap();
+
+            assert_eq!(**tag, case.tag);
+        }
+    }
+}
diff --git a/src/group/message_processor.rs b/src/group/message_processor.rs
new file mode 100644
index 0000000..8084a58
--- /dev/null
+++ b/src/group/message_processor.rs
@@ -0,0 +1,1039 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::{
+ commit_sender,
+ confirmation_tag::ConfirmationTag,
+ framing::{
+ ApplicationData, Content, ContentType, MlsMessage, MlsMessagePayload, PublicMessage, Sender,
+ },
+ message_signature::AuthenticatedContent,
+ mls_rules::{CommitDirection, MlsRules},
+ proposal_filter::ProposalBundle,
+ state::GroupState,
+ transcript_hash::InterimTranscriptHash,
+ transcript_hashes, validate_group_info_member, GroupContext, GroupInfo, Welcome,
+};
+use crate::{
+ client::MlsError,
+ key_package::validate_key_package_properties,
+ time::MlsTime,
+ tree_kem::{
+ leaf_node_validator::{LeafNodeValidator, ValidationContext},
+ node::LeafIndex,
+ path_secret::PathSecret,
+ validate_update_path, TreeKemPrivate, TreeKemPublic, ValidatedUpdatePath,
+ },
+ CipherSuiteProvider, KeyPackage,
+};
+#[cfg(mls_build_async)]
+use alloc::boxed::Box;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_core::{
+ identity::IdentityProvider, protocol_version::ProtocolVersion, psk::PreSharedKeyStorage,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use super::proposal_ref::ProposalRef;
+
+#[cfg(not(feature = "by_ref_proposal"))]
+use crate::group::proposal_cache::resolve_for_commit;
+
+#[cfg(feature = "by_ref_proposal")]
+use super::proposal::Proposal;
+
+#[cfg(feature = "custom_proposal")]
+use super::proposal_filter::ProposalInfo;
+
+#[cfg(feature = "state_update")]
+use mls_rs_core::{
+ crypto::CipherSuite,
+ group::{MemberUpdate, RosterUpdate},
+};
+
+#[cfg(all(feature = "state_update", feature = "psk"))]
+use mls_rs_core::psk::ExternalPskId;
+
+#[cfg(feature = "state_update")]
+use crate::tree_kem::UpdatePath;
+
+#[cfg(feature = "state_update")]
+use super::{member_from_key_package, member_from_leaf_node};
+
+#[cfg(all(feature = "state_update", feature = "custom_proposal"))]
+use super::proposal::CustomProposal;
+
+#[cfg(feature = "private_message")]
+use crate::group::framing::PrivateMessage;
+
+#[cfg(feature = "by_ref_proposal")]
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+/// Candidate group state: public tree, applied proposals and group context, plus
+/// bookkeeping about external init, added key packages and unused proposals.
+/// NOTE(review): appears to be computed while validating a commit before it is
+/// adopted — confirm against callers.
+#[derive(Debug)]
+pub(crate) struct ProvisionalState {
+ pub(crate) public_tree: TreeKemPublic,
+ pub(crate) applied_proposals: ProposalBundle,
+ pub(crate) group_context: GroupContext,
+ /// Leaf index tied to an external init, if any.
+ pub(crate) external_init_index: Option<LeafIndex>,
+ /// Leaf indexes of the key packages added by the applied proposals.
+ pub(crate) indexes_of_added_kpkgs: Vec<LeafIndex>,
+ /// Proposals that were not applied (tracked only with `by_ref_proposal`).
+ #[cfg(feature = "by_ref_proposal")]
+ pub(crate) unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
+}
+
+/// Returns `true` when a commit covering `proposals` must carry an `UpdatePath`.
+///
+/// By default the `path` field of a Commit MUST be populated. It MAY be omitted only
+/// when (a) the commit covers at least one proposal and (b) none of the covered
+/// proposals is of a "path required" type. The only proposal types that do not
+/// require a path are:
+///
+/// * add
+/// * psk
+/// * reinit
+///
+/// So a path is required for an empty proposal list, or for any external init,
+/// remove, group context extensions, or (with `by_ref_proposal`) update proposal.
+pub(crate) fn path_update_required(proposals: &ProposalBundle) -> bool {
+    if proposals.length() == 0 {
+        return true;
+    }
+
+    #[cfg(feature = "by_ref_proposal")]
+    let has_update = !proposals.update_proposals().is_empty();
+    #[cfg(not(feature = "by_ref_proposal"))]
+    let has_update = false;
+
+    has_update
+        || !proposals.external_init_proposals().is_empty()
+        || !proposals.remove_proposals().is_empty()
+        || proposals.group_context_extensions_proposal().is_some()
+}
+
+/// Representation of changes made by a [commit](crate::Group::commit).
+#[cfg(feature = "state_update")]
+#[derive(Clone, Debug, PartialEq)]
+pub struct StateUpdate {
+ /// Roster changes; exposed via `roster_update()`.
+ pub(crate) roster_update: RosterUpdate,
+ /// External PSKs added by the commit; exposed via `added_psks()`.
+ #[cfg(feature = "psk")]
+ pub(crate) added_psks: Vec<ExternalPskId>,
+ /// Cipher suite of a pending reinit, if one was committed.
+ pub(crate) pending_reinit: Option<CipherSuite>,
+ /// `false` once the processing member has been removed.
+ pub(crate) active: bool,
+ /// New epoch number.
+ pub(crate) epoch: u64,
+ #[cfg(feature = "custom_proposal")]
+ pub(crate) custom_proposals: Vec<ProposalInfo<CustomProposal>>,
+ /// Proposals received in the prior epoch but not committed to.
+ #[cfg(feature = "by_ref_proposal")]
+ pub(crate) unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
+}
+
+/// Representation of changes made by a [commit](crate::Group::commit).
+///
+/// Without the `state_update` feature this type carries no information.
+#[cfg(not(feature = "state_update"))]
+#[non_exhaustive]
+#[derive(Clone, Debug, PartialEq)]
+pub struct StateUpdate {}
+
+#[cfg(feature = "state_update")]
+impl StateUpdate {
+    /// Roster changes caused by the committed proposals.
+    pub fn roster_update(&self) -> &RosterUpdate {
+        &self.roster_update
+    }
+
+    /// External pre-shared keys added to the group by this commit.
+    #[cfg(feature = "psk")]
+    pub fn added_psks(&self) -> &[ExternalPskId] {
+        self.added_psks.as_slice()
+    }
+
+    /// Whether the group now awaits reinitialization because a
+    /// [`ReInit`](crate::group::proposal::Proposal::ReInit) proposal was committed.
+    pub fn is_pending_reinit(&self) -> bool {
+        matches!(self.pending_reinit, Some(_))
+    }
+
+    /// Whether the group is still active. This is `false` when the member processing
+    /// the commit has been removed from the group.
+    pub fn is_active(&self) -> bool {
+        self.active
+    }
+
+    /// Epoch number of the group state after this commit.
+    pub fn new_epoch(&self) -> u64 {
+        self.epoch
+    }
+
+    /// Custom proposals included in the commit.
+    #[cfg(feature = "custom_proposal")]
+    pub fn custom_proposals(&self) -> &[ProposalInfo<CustomProposal>] {
+        self.custom_proposals.as_slice()
+    }
+
+    /// Proposals received during the previous epoch that the commit did not include.
+    #[cfg(feature = "by_ref_proposal")]
+    pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
+        self.unused_proposals.as_slice()
+    }
+
+    /// Cipher suite of the pending reinitialization, or `None` when no
+    /// [`ReInit`](crate::group::proposal::Proposal::ReInit) proposal is pending.
+    pub fn pending_reinit_ciphersuite(&self) -> Option<CipherSuite> {
+        self.pending_reinit
+    }
+}
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Debug, Clone)]
+// Variants hold their descriptions inline (unboxed), which trips clippy's size lint.
+#[allow(clippy::large_enum_variant)]
+/// An event generated as a result of processing a message for a group with
+/// [`Group::process_incoming_message`](crate::group::Group::process_incoming_message).
+pub enum ReceivedMessage {
+ /// An application message was decrypted.
+ ApplicationMessage(ApplicationMessageDescription),
+ /// A new commit was processed creating a new group state.
+ Commit(CommitMessageDescription),
+ /// A proposal was received.
+ Proposal(ProposalMessageDescription),
+ /// Validated GroupInfo object
+ GroupInfo(GroupInfo),
+ /// Validated welcome message (its contents are not carried in this variant).
+ Welcome,
+ /// Validated key package
+ KeyPackage(KeyPackage),
+}
+
+/// Wraps a decrypted application message; this conversion never fails.
+impl TryFrom<ApplicationMessageDescription> for ReceivedMessage {
+    type Error = MlsError;
+
+    fn try_from(description: ApplicationMessageDescription) -> Result<Self, Self::Error> {
+        Ok(Self::ApplicationMessage(description))
+    }
+}
+
+impl From<CommitMessageDescription> for ReceivedMessage {
+    fn from(description: CommitMessageDescription) -> Self {
+        Self::Commit(description)
+    }
+}
+
+impl From<ProposalMessageDescription> for ReceivedMessage {
+    fn from(description: ProposalMessageDescription) -> Self {
+        Self::Proposal(description)
+    }
+}
+
+impl From<GroupInfo> for ReceivedMessage {
+    fn from(group_info: GroupInfo) -> Self {
+        Self::GroupInfo(group_info)
+    }
+}
+
+/// The welcome itself is discarded; only the fact that it was validated is reported.
+impl From<Welcome> for ReceivedMessage {
+    fn from(_welcome: Welcome) -> Self {
+        Self::Welcome
+    }
+}
+
+impl From<KeyPackage> for ReceivedMessage {
+    fn from(key_package: KeyPackage) -> Self {
+        Self::KeyPackage(key_package)
+    }
+}
+
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, PartialEq, Eq)]
+/// Description of an MLS application message.
+pub struct ApplicationMessageDescription {
+    /// Index, in the group state, of the member that sent this message.
+    pub sender_index: u32,
+    /// Received application data; read it through `data()`.
+    data: ApplicationData,
+    /// Plaintext authenticated data in the received MLS packet.
+    pub authenticated_data: Vec<u8>,
+}
+
+/// Prints all fields; `authenticated_data` is rendered with `pretty_bytes`.
+impl Debug for ApplicationMessageDescription {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let authenticated_data = mls_rs_core::debug::pretty_bytes(&self.authenticated_data);
+
+        let mut builder = f.debug_struct("ApplicationMessageDescription");
+        builder.field("sender_index", &self.sender_index);
+        builder.field("data", &self.data);
+        builder.field("authenticated_data", &authenticated_data);
+        builder.finish()
+    }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl ApplicationMessageDescription {
+    /// Decrypted application payload as raw bytes.
+    pub fn data(&self) -> &[u8] {
+        let payload = &self.data;
+        payload.as_bytes()
+    }
+}
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, PartialEq)]
+#[non_exhaustive]
+/// Description of a processed MLS commit message.
+///
+/// Marked `#[non_exhaustive]`, so it cannot be constructed outside this crate.
+pub struct CommitMessageDescription {
+ /// True if this is the result of an external commit.
+ pub is_external: bool,
+ /// The index in the group state of the member who performed this commit.
+ pub committer: u32,
+ /// A full description of group state changes as a result of this commit.
+ pub state_update: StateUpdate,
+ /// Plaintext authenticated data in the received MLS packet.
+ pub authenticated_data: Vec<u8>,
+}
+
+/// Prints all fields; `authenticated_data` is rendered with `pretty_bytes`.
+impl Debug for CommitMessageDescription {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let authenticated_data = mls_rs_core::debug::pretty_bytes(&self.authenticated_data);
+
+        let mut builder = f.debug_struct("CommitMessageDescription");
+        builder.field("is_external", &self.is_external);
+        builder.field("committer", &self.committer);
+        builder.field("state_update", &self.state_update);
+        builder.field("authenticated_data", &authenticated_data);
+        builder.finish()
+    }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+/// Proposal sender type.
+///
+/// Built from a framing `Sender` via `TryFrom`; `Sender::NewMemberCommit` has no
+/// counterpart here and is rejected with `MlsError::InvalidSender`.
+pub enum ProposalSender {
+ /// A current member of the group by index in the group state.
+ Member(u32),
+ /// An external entity by index within an
+ /// [`ExternalSendersExt`](crate::extension::built_in::ExternalSendersExt).
+ External(u32),
+ /// A new member proposing their addition to the group.
+ NewMember,
+}
+
+/// Maps a framing sender onto a proposal sender. A new-member *commit* sender cannot
+/// send proposals and is rejected.
+impl TryFrom<Sender> for ProposalSender {
+    type Error = MlsError;
+
+    fn try_from(value: Sender) -> Result<Self, Self::Error> {
+        let converted = match value {
+            Sender::Member(index) => Self::Member(index),
+            #[cfg(feature = "by_ref_proposal")]
+            Sender::External(index) => Self::External(index),
+            #[cfg(feature = "by_ref_proposal")]
+            Sender::NewMemberProposal => Self::NewMember,
+            Sender::NewMemberCommit => return Err(MlsError::InvalidSender),
+        };
+
+        Ok(converted)
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone)]
+#[non_exhaustive]
+/// Description of a processed MLS proposal message.
+///
+/// Can be turned into a serializable `CachedProposal` with `cached_proposal()`.
+pub struct ProposalMessageDescription {
+ /// Sender of the proposal.
+ pub sender: ProposalSender,
+ /// Proposal content.
+ pub proposal: Proposal,
+ /// Plaintext authenticated data in the received MLS packet.
+ pub authenticated_data: Vec<u8>,
+ /// Proposal reference.
+ pub proposal_ref: ProposalRef,
+}
+
+/// Prints all fields; `authenticated_data` is rendered with `pretty_bytes`.
+#[cfg(feature = "by_ref_proposal")]
+impl Debug for ProposalMessageDescription {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let authenticated_data = mls_rs_core::debug::pretty_bytes(&self.authenticated_data);
+
+        let mut builder = f.debug_struct("ProposalMessageDescription");
+        builder.field("sender", &self.sender);
+        builder.field("proposal", &self.proposal);
+        builder.field("authenticated_data", &authenticated_data);
+        builder.field("proposal_ref", &self.proposal_ref);
+        builder.finish()
+    }
+}
+
+/// Serializable snapshot of a received proposal: the proposal, its reference and its
+/// sender. Produced by `ProposalMessageDescription::cached_proposal` and round-tripped
+/// with `to_bytes` / `from_bytes`.
+#[cfg(feature = "by_ref_proposal")]
+#[derive(MlsSize, MlsEncode, MlsDecode)]
+pub struct CachedProposal {
+ pub(crate) proposal: Proposal,
+ pub(crate) proposal_ref: ProposalRef,
+ pub(crate) sender: Sender,
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl CachedProposal {
+    /// Deserialize the proposal
+    ///
+    /// Decodes an MLS-encoded `CachedProposal` such as the output of `to_bytes`.
+    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
+        let mut reader = bytes;
+        Self::mls_decode(&mut reader).map_err(MlsError::from)
+    }
+
+    /// Serialize the proposal
+    ///
+    /// MLS-encodes this cached proposal; `from_bytes` reverses it.
+    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
+        self.mls_encode_to_vec().map_err(MlsError::from)
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl ProposalMessageDescription {
+    /// Consumes the description and returns a [`CachedProposal`] carrying the
+    /// proposal, its reference and the equivalent framing-level [`Sender`].
+    pub fn cached_proposal(self) -> CachedProposal {
+        let framing_sender = match self.sender {
+            ProposalSender::NewMember => Sender::NewMemberProposal,
+            ProposalSender::External(idx) => Sender::External(idx),
+            ProposalSender::Member(idx) => Sender::Member(idx),
+        };
+
+        CachedProposal {
+            sender: framing_sender,
+            proposal_ref: self.proposal_ref,
+            proposal: self.proposal,
+        }
+    }
+
+    /// Returns an owned copy of the proposal reference bytes.
+    pub fn proposal_ref(&self) -> Vec<u8> {
+        self.proposal_ref.to_vec()
+    }
+}
+
+#[cfg(not(feature = "by_ref_proposal"))]
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Debug, Clone)]
+/// Description of a processed MLS proposal message.
+///
+/// Without the `by_ref_proposal` feature this type carries no data. It still
+/// exists because `MessageProcessor::OutputType` is bounded by
+/// `From<ProposalMessageDescription>` regardless of features.
+pub struct ProposalMessageDescription {}
+
+#[allow(clippy::large_enum_variant)]
+/// Intermediate result of reading an incoming message.
+///
+/// `Event` is a message that is already fully handled and converted to the
+/// processor's output type (e.g. group info, welcome or key package payloads);
+/// `Content` is authenticated framed content that still has to be processed.
+pub(crate) enum EventOrContent<E> {
+    // The dead_code lint is silenced unless both `private_message` and
+    // `external_client` are enabled.
+    #[cfg_attr(
+        not(all(feature = "private_message", feature = "external_client")),
+        allow(dead_code)
+    )]
+    Event(E),
+    Content(AuthenticatedContent),
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
+#[cfg_attr(
+    all(not(target_arch = "wasm32"), mls_build_async),
+    maybe_async::must_be_async
+)]
+/// Shared pipeline for processing incoming MLS messages against a group state.
+///
+/// The provided methods do three things:
+/// - check message metadata;
+/// - authenticate or decrypt the payload;
+/// - dispatch application, proposal and commit content.
+///
+/// Implementors supply:
+/// - access to the group state and the identity/crypto/PSK providers;
+/// - the ciphertext decryption, plaintext authentication and key-schedule
+///   update steps.
+pub(crate) trait MessageProcessor: Send + Sync {
+    /// Value returned for every processed message.
+    type OutputType: TryFrom<ApplicationMessageDescription, Error = MlsError>
+        + From<CommitMessageDescription>
+        + From<ProposalMessageDescription>
+        + From<GroupInfo>
+        + From<Welcome>
+        + From<KeyPackage>
+        + Send;
+
+    type MlsRules: MlsRules;
+    type IdentityProvider: IdentityProvider;
+    type CipherSuiteProvider: CipherSuiteProvider;
+    type PreSharedKeyStorage: PreSharedKeyStorage;
+
+    /// Processes `message` without a `time_sent` value.
+    ///
+    /// With `by_ref_proposal`, `cache_proposal` controls whether a received
+    /// proposal is inserted into the group's proposal cache.
+    async fn process_incoming_message(
+        &mut self,
+        message: MlsMessage,
+        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
+    ) -> Result<Self::OutputType, MlsError> {
+        self.process_incoming_message_with_time(
+            message,
+            #[cfg(feature = "by_ref_proposal")]
+            cache_proposal,
+            None,
+        )
+        .await
+    }
+
+    /// Checks, authenticates or decrypts, and then processes `message`.
+    ///
+    /// `time_sent` is forwarded to commit processing.
+    async fn process_incoming_message_with_time(
+        &mut self,
+        message: MlsMessage,
+        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
+        time_sent: Option<MlsTime>,
+    ) -> Result<Self::OutputType, MlsError> {
+        let event_or_content = self.get_event_from_incoming_message(message).await?;
+
+        self.process_event_or_content(
+            event_or_content,
+            #[cfg(feature = "by_ref_proposal")]
+            cache_proposal,
+            time_sent,
+        )
+        .await
+    }
+
+    /// Validates `message` and turns it into either a finished event or
+    /// authenticated content that still has to be processed.
+    ///
+    /// Group info, welcome and key package payloads become events. Ciphertexts
+    /// return whatever `process_ciphertext` produces.
+    async fn get_event_from_incoming_message(
+        &mut self,
+        message: MlsMessage,
+    ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+        self.check_metadata(&message)?;
+
+        match message.payload {
+            MlsMessagePayload::Plain(plaintext) => {
+                self.verify_plaintext_authentication(plaintext).await
+            }
+            #[cfg(feature = "private_message")]
+            MlsMessagePayload::Cipher(cipher_text) => self.process_ciphertext(&cipher_text).await,
+            MlsMessagePayload::GroupInfo(group_info) => {
+                validate_group_info_member(
+                    self.group_state(),
+                    message.version,
+                    &group_info,
+                    self.cipher_suite_provider(),
+                )
+                .await?;
+
+                Ok(EventOrContent::Event(group_info.into()))
+            }
+            MlsMessagePayload::Welcome(welcome) => {
+                self.validate_welcome(&welcome, message.version)?;
+
+                Ok(EventOrContent::Event(welcome.into()))
+            }
+            MlsMessagePayload::KeyPackage(key_package) => {
+                self.validate_key_package(&key_package, message.version)
+                    .await?;
+
+                Ok(EventOrContent::Event(key_package.into()))
+            }
+        }
+    }
+
+    /// Returns an event unchanged, or processes authenticated content into an
+    /// output value.
+    async fn process_event_or_content(
+        &mut self,
+        event_or_content: EventOrContent<Self::OutputType>,
+        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
+        time_sent: Option<MlsTime>,
+    ) -> Result<Self::OutputType, MlsError> {
+        let msg = match event_or_content {
+            EventOrContent::Event(event) => event,
+            EventOrContent::Content(content) => {
+                self.process_auth_content(
+                    content,
+                    #[cfg(feature = "by_ref_proposal")]
+                    cache_proposal,
+                    time_sent,
+                )
+                .await?
+            }
+        };
+
+        Ok(msg)
+    }
+
+    /// Dispatches authenticated content on its type (application, commit or
+    /// proposal) and converts the resulting description into the output type.
+    async fn process_auth_content(
+        &mut self,
+        auth_content: AuthenticatedContent,
+        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
+        time_sent: Option<MlsTime>,
+    ) -> Result<Self::OutputType, MlsError> {
+        let event = match auth_content.content.content {
+            #[cfg(feature = "private_message")]
+            Content::Application(data) => {
+                let authenticated_data = auth_content.content.authenticated_data;
+                let sender = auth_content.content.sender;
+
+                self.process_application_message(data, sender, authenticated_data)
+                    .and_then(Self::OutputType::try_from)
+            }
+            Content::Commit(_) => self
+                .process_commit(auth_content, time_sent)
+                .await
+                .map(Self::OutputType::from),
+            #[cfg(feature = "by_ref_proposal")]
+            Content::Proposal(ref proposal) => self
+                .process_proposal(&auth_content, proposal, cache_proposal)
+                .await
+                .map(Self::OutputType::from),
+        }?;
+
+        Ok(event)
+    }
+
+    /// Builds the description of an application message.
+    ///
+    /// Only `Sender::Member` senders are accepted; any other sender fails with
+    /// `MlsError::InvalidSender`.
+    #[cfg(feature = "private_message")]
+    fn process_application_message(
+        &self,
+        data: ApplicationData,
+        sender: Sender,
+        authenticated_data: Vec<u8>,
+    ) -> Result<ApplicationMessageDescription, MlsError> {
+        let Sender::Member(sender_index) = sender else {
+            return Err(MlsError::InvalidSender);
+        };
+
+        Ok(ApplicationMessageDescription {
+            authenticated_data,
+            sender_index,
+            data,
+        })
+    }
+
+    /// Computes the proposal reference and returns a description of the proposal.
+    ///
+    /// When `cache_proposal` is set, the proposal is also inserted into the
+    /// group's proposal cache.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn process_proposal(
+        &mut self,
+        auth_content: &AuthenticatedContent,
+        proposal: &Proposal,
+        cache_proposal: bool,
+    ) -> Result<ProposalMessageDescription, MlsError> {
+        let proposal_ref =
+            ProposalRef::from_content(self.cipher_suite_provider(), auth_content).await?;
+
+        // Convert the sender before touching the cache, so that a sender that
+        // cannot author a proposal leaves the group state unchanged.
+        let sender: ProposalSender = auth_content.content.sender.try_into()?;
+
+        if cache_proposal {
+            self.group_state_mut().proposals.insert(
+                proposal_ref.clone(),
+                proposal.clone(),
+                auth_content.content.sender,
+            );
+        }
+
+        Ok(ProposalMessageDescription {
+            authenticated_data: auth_content.content.authenticated_data.clone(),
+            proposal: proposal.clone(),
+            sender,
+            proposal_ref,
+        })
+    }
+
+    /// Summarizes how a provisional state changes the roster.
+    ///
+    /// The summary covers members added, removed and updated, plus the
+    /// external PSKs, pending ReInit and (feature-dependent) custom and unused
+    /// proposals.
+    #[cfg(feature = "state_update")]
+    async fn make_state_update(
+        &self,
+        provisional: &ProvisionalState,
+        path: Option<&UpdatePath>,
+        sender: LeafIndex,
+    ) -> Result<StateUpdate, MlsError> {
+        let mut added = provisional
+            .applied_proposals
+            .additions
+            .iter()
+            .zip(provisional.indexes_of_added_kpkgs.iter())
+            .map(|(p, index)| member_from_key_package(&p.proposal.key_package, *index))
+            .collect::<Vec<_>>();
+
+        let old_tree = &self.group_state().public_tree;
+
+        let removed = provisional
+            .applied_proposals
+            .removals
+            .iter()
+            .map(|p| {
+                let index = p.proposal.to_remove;
+                let node = old_tree.nodes.borrow_as_leaf(index)?;
+                Ok(member_from_leaf_node(node, index))
+            })
+            .collect::<Result<_, MlsError>>()?;
+
+        #[cfg(feature = "by_ref_proposal")]
+        let mut updated = provisional
+            .applied_proposals
+            .update_senders
+            .iter()
+            .map(|index| {
+                let prior = old_tree
+                    .get_leaf_node(*index)
+                    .map(|n| member_from_leaf_node(n, *index))?;
+
+                let new = provisional
+                    .public_tree
+                    .get_leaf_node(*index)
+                    .map(|n| member_from_leaf_node(n, *index))?;
+
+                Ok::<_, MlsError>(MemberUpdate::new(prior, new))
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+
+        #[cfg(not(feature = "by_ref_proposal"))]
+        let mut updated = Vec::new();
+
+        if let Some(path) = path {
+            // With an external init the committer joins as a new member;
+            // otherwise the path replaces the committer's existing leaf.
+            if !provisional
+                .applied_proposals
+                .external_initializations
+                .is_empty()
+            {
+                added.push(member_from_leaf_node(&path.leaf_node, sender))
+            } else {
+                let prior = old_tree
+                    .get_leaf_node(sender)
+                    .map(|n| member_from_leaf_node(n, sender))?;
+
+                let new = member_from_leaf_node(&path.leaf_node, sender);
+
+                updated.push(MemberUpdate::new(prior, new))
+            }
+        }
+
+        #[cfg(feature = "psk")]
+        let psks = provisional
+            .applied_proposals
+            .psks
+            .iter()
+            .filter_map(|psk| psk.proposal.external_psk_id().cloned())
+            .collect::<Vec<_>>();
+
+        let roster_update = RosterUpdate::new(added, removed, updated);
+
+        let update = StateUpdate {
+            roster_update,
+            #[cfg(feature = "psk")]
+            added_psks: psks,
+            pending_reinit: provisional
+                .applied_proposals
+                .reinitializations
+                .first()
+                .map(|ri| ri.proposal.new_cipher_suite()),
+            active: true,
+            epoch: provisional.group_context.epoch,
+            #[cfg(feature = "custom_proposal")]
+            custom_proposals: provisional.applied_proposals.custom_proposals.clone(),
+            #[cfg(feature = "by_ref_proposal")]
+            unused_proposals: provisional.unused_proposals.clone(),
+        };
+
+        Ok(update)
+    }
+
+    /// Applies a received commit.
+    ///
+    /// Steps performed:
+    /// 1. Resolve the commit's proposals into a provisional state.
+    /// 2. Validate the update path, if one is present.
+    /// 3. Recompute the tree hashes.
+    /// 4. Update the key schedule.
+    ///
+    /// Errors:
+    /// - `GroupUsedAfterReInit` when a ReInit is already pending.
+    /// - `CommitMissingPath` when a path is required but absent.
+    /// - `InvalidConfirmationTag` when the content has no confirmation tag.
+    ///
+    /// If `can_continue_processing` returns `false`, a description is returned
+    /// right after the path-required check, without applying the commit.
+    async fn process_commit(
+        &mut self,
+        auth_content: AuthenticatedContent,
+        time_sent: Option<MlsTime>,
+    ) -> Result<CommitMessageDescription, MlsError> {
+        if self.group_state().pending_reinit.is_some() {
+            return Err(MlsError::GroupUsedAfterReInit);
+        }
+
+        // Update the new GroupContext's confirmed and interim transcript hashes using the new Commit.
+        let (interim_transcript_hash, confirmed_transcript_hash) = transcript_hashes(
+            self.cipher_suite_provider(),
+            &self.group_state().interim_transcript_hash,
+            &auth_content,
+        )
+        .await?;
+
+        #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
+        let commit = match auth_content.content.content {
+            Content::Commit(commit) => Ok(commit),
+            _ => Err(MlsError::UnexpectedMessageType),
+        }?;
+
+        #[cfg(not(any(feature = "private_message", feature = "by_ref_proposal")))]
+        let Content::Commit(commit) = auth_content.content.content;
+
+        let group_state = self.group_state();
+        let id_provider = self.identity_provider();
+
+        #[cfg(feature = "by_ref_proposal")]
+        let proposals = group_state
+            .proposals
+            .resolve_for_commit(auth_content.content.sender, commit.proposals)?;
+
+        #[cfg(not(feature = "by_ref_proposal"))]
+        let proposals = resolve_for_commit(auth_content.content.sender, commit.proposals)?;
+
+        let mut provisional_state = group_state
+            .apply_resolved(
+                auth_content.content.sender,
+                proposals,
+                commit.path.as_ref().map(|path| &path.leaf_node),
+                &id_provider,
+                self.cipher_suite_provider(),
+                &self.psk_storage(),
+                &self.mls_rules(),
+                time_sent,
+                CommitDirection::Receive,
+            )
+            .await?;
+
+        let sender = commit_sender(&auth_content.content.sender, &provisional_state)?;
+
+        #[cfg(feature = "state_update")]
+        let mut state_update = self
+            .make_state_update(&provisional_state, commit.path.as_ref(), sender)
+            .await?;
+
+        #[cfg(not(feature = "state_update"))]
+        let state_update = StateUpdate {};
+
+        // Verify that the path value is populated if the proposals vector contains any Update
+        // or Remove proposals, or if it's empty. Otherwise, the path value MAY be omitted.
+        if path_update_required(&provisional_state.applied_proposals) && commit.path.is_none() {
+            return Err(MlsError::CommitMissingPath);
+        }
+
+        if !self.can_continue_processing(&provisional_state) {
+            #[cfg(feature = "state_update")]
+            {
+                state_update.active = false;
+            }
+
+            return Ok(CommitMessageDescription {
+                is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
+                authenticated_data: auth_content.content.authenticated_data,
+                committer: *sender,
+                state_update,
+            });
+        }
+
+        let update_path = match commit.path {
+            Some(update_path) => Some(
+                validate_update_path(
+                    &self.identity_provider(),
+                    self.cipher_suite_provider(),
+                    update_path,
+                    &provisional_state,
+                    sender,
+                    time_sent,
+                )
+                .await?,
+            ),
+            None => None,
+        };
+
+        let new_secrets = match update_path {
+            Some(update_path) => {
+                self.apply_update_path(sender, &update_path, &mut provisional_state)
+                    .await
+            }
+            None => Ok(None),
+        }?;
+
+        // Update the transcript hash to get the new context.
+        provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
+
+        // Update the parent hashes in the new context
+        provisional_state
+            .public_tree
+            .update_hashes(&[sender], self.cipher_suite_provider())
+            .await?;
+
+        // Update the tree hash in the new context
+        provisional_state.group_context.tree_hash = provisional_state
+            .public_tree
+            .tree_hash(self.cipher_suite_provider())
+            .await?;
+
+        if let Some(reinit) = provisional_state.applied_proposals.reinitializations.pop() {
+            self.group_state_mut().pending_reinit = Some(reinit.proposal);
+
+            #[cfg(feature = "state_update")]
+            {
+                state_update.active = false;
+            }
+        }
+
+        if let Some(confirmation_tag) = &auth_content.auth.confirmation_tag {
+            // Update the key schedule to calculate new private keys
+            self.update_key_schedule(
+                new_secrets,
+                interim_transcript_hash,
+                confirmation_tag,
+                provisional_state,
+            )
+            .await?;
+
+            Ok(CommitMessageDescription {
+                is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
+                authenticated_data: auth_content.content.authenticated_data,
+                committer: *sender,
+                state_update,
+            })
+        } else {
+            Err(MlsError::InvalidConfirmationTag)
+        }
+    }
+
+    /// Current group state.
+    fn group_state(&self) -> &GroupState;
+    /// Mutable access to the current group state.
+    fn group_state_mut(&mut self) -> &mut GroupState;
+    #[cfg(feature = "private_message")]
+    fn self_index(&self) -> Option<LeafIndex>;
+    fn mls_rules(&self) -> Self::MlsRules;
+    fn identity_provider(&self) -> Self::IdentityProvider;
+    fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider;
+    fn psk_storage(&self) -> Self::PreSharedKeyStorage;
+    /// Whether `process_commit` should go on to apply the commit after the
+    /// provisional state has been computed.
+    fn can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool;
+
+    /// Oldest epoch accepted for application messages; `None` disables that
+    /// check in `check_metadata`.
+    #[cfg(feature = "private_message")]
+    fn min_epoch_available(&self) -> Option<u64>;
+
+    /// Rejects messages that do not belong to this group or epoch.
+    ///
+    /// A message is rejected if any of these hold:
+    /// - its protocol version or group id does not match the group;
+    /// - it is a commit or proposal not sent in the current epoch;
+    /// - it is an application message older than `min_epoch_available`;
+    /// - it is an unencrypted application message.
+    ///
+    /// Payloads without framing (group info, welcome, key package) only get
+    /// the protocol version check.
+    fn check_metadata(&self, message: &MlsMessage) -> Result<(), MlsError> {
+        let context = &self.group_state().context;
+
+        if message.version != context.protocol_version {
+            return Err(MlsError::ProtocolVersionMismatch);
+        }
+
+        let framing = match &message.payload {
+            MlsMessagePayload::Plain(plaintext) => Some((
+                &plaintext.content.group_id,
+                plaintext.content.epoch,
+                plaintext.content.content_type(),
+            )),
+            #[cfg(feature = "private_message")]
+            MlsMessagePayload::Cipher(ciphertext) => Some((
+                &ciphertext.group_id,
+                ciphertext.epoch,
+                ciphertext.content_type,
+            )),
+            _ => None,
+        };
+
+        let Some((group_id, epoch, content_type)) = framing else {
+            return Ok(());
+        };
+
+        if group_id != &context.group_id {
+            return Err(MlsError::GroupIdMismatch);
+        }
+
+        match content_type {
+            // Commit messages must be sent in the current epoch.
+            ContentType::Commit => {
+                if epoch != context.epoch {
+                    return Err(MlsError::InvalidEpoch);
+                }
+            }
+            // Proposal messages must be sent in the current epoch.
+            #[cfg(feature = "by_ref_proposal")]
+            ContentType::Proposal => {
+                if epoch != context.epoch {
+                    return Err(MlsError::InvalidEpoch);
+                }
+            }
+            #[cfg(feature = "private_message")]
+            ContentType::Application => {
+                // Application messages may use any epoch not older than the
+                // oldest one still available.
+                if matches!(self.min_epoch_available(), Some(min) if epoch < min) {
+                    return Err(MlsError::InvalidEpoch);
+                }
+
+                // Unencrypted application messages are not allowed
+                if !matches!(&message.payload, MlsMessagePayload::Cipher(_)) {
+                    return Err(MlsError::UnencryptedApplicationMessage);
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Accepts a welcome only if its cipher suite and `version` match the
+    /// group context.
+    fn validate_welcome(
+        &self,
+        welcome: &Welcome,
+        version: ProtocolVersion,
+    ) -> Result<(), MlsError> {
+        let state = self.group_state();
+
+        (welcome.cipher_suite == state.context.cipher_suite
+            && version == state.context.protocol_version)
+            .then_some(())
+            .ok_or(MlsError::InvalidWelcomeMessage)
+    }
+
+    /// Validates a received key package with this processor's providers.
+    async fn validate_key_package(
+        &self,
+        key_package: &KeyPackage,
+        version: ProtocolVersion,
+    ) -> Result<(), MlsError> {
+        let cs = self.cipher_suite_provider();
+        let id = self.identity_provider();
+
+        validate_key_package(key_package, version, cs, &id).await
+    }
+
+    #[cfg(feature = "private_message")]
+    async fn process_ciphertext(
+        &mut self,
+        cipher_text: &PrivateMessage,
+    ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
+
+    async fn verify_plaintext_authentication(
+        &self,
+        message: PublicMessage,
+    ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
+
+    /// Applies `update_path` to the provisional public tree.
+    ///
+    /// This default derives no private secrets and always yields `None`.
+    /// NOTE(review): implementors that hold private tree keys presumably
+    /// override this; confirm against the implementations.
+    async fn apply_update_path(
+        &mut self,
+        sender: LeafIndex,
+        update_path: &ValidatedUpdatePath,
+        provisional_state: &mut ProvisionalState,
+    ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
+        provisional_state
+            .public_tree
+            .apply_update_path(
+                sender,
+                update_path,
+                &provisional_state.group_context.extensions,
+                self.identity_provider(),
+                self.cipher_suite_provider(),
+            )
+            .await
+            .map(|_| None)
+    }
+
+    async fn update_key_schedule(
+        &mut self,
+        secrets: Option<(TreeKemPrivate, PathSecret)>,
+        interim_transcript_hash: InterimTranscriptHash,
+        confirmation_tag: &ConfirmationTag,
+        provisional_public_state: ProvisionalState,
+    ) -> Result<(), MlsError>;
+}
+
+/// Validates a received key package.
+///
+/// The key package's leaf node is checked for an addition
+/// (`ValidationContext::Add`). The key package's own properties are then
+/// checked against `version` and the cipher suite.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn validate_key_package<C: CipherSuiteProvider, I: IdentityProvider>(
+    key_package: &KeyPackage,
+    version: ProtocolVersion,
+    cs: &C,
+    id: &I,
+) -> Result<(), MlsError> {
+    // A current time is only available with `std`; otherwise `None` is
+    // handed to the leaf node validator.
+    #[cfg(feature = "std")]
+    let now = Some(MlsTime::now());
+
+    #[cfg(not(feature = "std"))]
+    let now = None;
+
+    LeafNodeValidator::new(cs, id, None)
+        .check_if_valid(&key_package.leaf_node, ValidationContext::Add(now))
+        .await?;
+
+    validate_key_package_properties(key_package, version, cs).await?;
+
+    Ok(())
+}
diff --git a/src/group/message_signature.rs b/src/group/message_signature.rs
new file mode 100644
index 0000000..3c08935
--- /dev/null
+++ b/src/group/message_signature.rs
@@ -0,0 +1,274 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::framing::Content;
+use crate::client::MlsError;
+use crate::crypto::SignatureSecretKey;
+use crate::group::framing::{ContentType, FramedContent, PublicMessage, Sender, WireFormat};
+use crate::group::{ConfirmationTag, GroupContext};
+use crate::signer::Signable;
+use crate::CipherSuiteProvider;
+use alloc::vec;
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::protocol_version::ProtocolVersion;
+
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// Authentication data attached to framed content.
+///
+/// The MLS codec impls for this type are hand-written. The confirmation tag is
+/// encoded only when present, and decoding needs the content type to know
+/// whether a tag follows (see `FramedContentAuthData::mls_decode`).
+pub struct FramedContentAuthData {
+    /// Signature over the framed content.
+    pub signature: MessageSignature,
+    /// Confirmation tag; decoded only for commit content.
+    pub confirmation_tag: Option<ConfirmationTag>,
+}
+
+impl MlsSize for FramedContentAuthData {
+    /// Encoded length: the signature plus the confirmation tag when present.
+    /// An absent tag contributes no bytes.
+    fn mls_encoded_len(&self) -> usize {
+        let tag_len = match &self.confirmation_tag {
+            Some(tag) => tag.mls_encoded_len(),
+            None => 0,
+        };
+
+        self.signature.mls_encoded_len() + tag_len
+    }
+}
+
+impl MlsEncode for FramedContentAuthData {
+    /// Writes the signature, then the confirmation tag if there is one.
+    /// No presence marker is written for the optional tag.
+    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+        self.signature.mls_encode(writer)?;
+
+        match &self.confirmation_tag {
+            Some(tag) => tag.mls_encode(writer),
+            None => Ok(()),
+        }
+    }
+}
+
+impl FramedContentAuthData {
+    /// Decodes auth data belonging to content of type `content_type`.
+    ///
+    /// The signature is always read first. A confirmation tag follows only for
+    /// `ContentType::Commit`.
+    pub(crate) fn mls_decode(
+        reader: &mut &[u8],
+        content_type: ContentType,
+    ) -> Result<Self, mls_rs_codec::Error> {
+        let signature = MessageSignature::mls_decode(reader)?;
+
+        let confirmation_tag = match content_type {
+            ContentType::Commit => Some(ConfirmationTag::mls_decode(reader)?),
+            #[cfg(feature = "private_message")]
+            ContentType::Application => None,
+            #[cfg(feature = "by_ref_proposal")]
+            ContentType::Proposal => None,
+        };
+
+        Ok(Self {
+            signature,
+            confirmation_tag,
+        })
+    }
+}
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// Framed content together with its wire format and authentication data.
+///
+/// `MlsDecode` is implemented by hand because decoding `auth` depends on the
+/// content type of `content`.
+pub struct AuthenticatedContent {
+    pub(crate) wire_format: WireFormat,
+    pub(crate) content: FramedContent,
+    pub(crate) auth: FramedContentAuthData,
+}
+
+impl From<PublicMessage> for AuthenticatedContent {
+    /// Takes a public message's content and auth data and tags them with
+    /// `WireFormat::PublicMessage`.
+    ///
+    /// Other fields of the message, such as the membership tag, are dropped.
+    fn from(message: PublicMessage) -> Self {
+        let wire_format = WireFormat::PublicMessage;
+
+        Self {
+            auth: message.auth,
+            content: message.content,
+            wire_format,
+        }
+    }
+}
+
+impl AuthenticatedContent {
+    /// Builds unsigned content for the group id and epoch taken from
+    /// `context`, with an empty signature and no confirmation tag.
+    pub(crate) fn new(
+        context: &GroupContext,
+        sender: Sender,
+        content: Content,
+        authenticated_data: Vec<u8>,
+        wire_format: WireFormat,
+    ) -> AuthenticatedContent {
+        let framed = FramedContent {
+            group_id: context.group_id.clone(),
+            epoch: context.epoch,
+            sender,
+            authenticated_data,
+            content,
+        };
+
+        let unsigned = FramedContentAuthData {
+            signature: MessageSignature::empty(),
+            confirmation_tag: None,
+        };
+
+        AuthenticatedContent {
+            wire_format,
+            content: framed,
+            auth: unsigned,
+        }
+    }
+
+    /// Builds content like [`AuthenticatedContent::new`] and signs it with
+    /// `signer`.
+    ///
+    /// `context` and its protocol version serve as the signing context.
+    #[inline(never)]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn new_signed<P: CipherSuiteProvider>(
+        signature_provider: &P,
+        context: &GroupContext,
+        sender: Sender,
+        content: Content,
+        signer: &SignatureSecretKey,
+        wire_format: WireFormat,
+        authenticated_data: Vec<u8>,
+    ) -> Result<AuthenticatedContent, MlsError> {
+        let signing_context = MessageSigningContext {
+            group_context: Some(context),
+            protocol_version: context.protocol_version,
+        };
+
+        let mut signed =
+            Self::new(context, sender, content, authenticated_data, wire_format);
+
+        signed
+            .sign(signature_provider, signer, &signing_context)
+            .await?;
+
+        Ok(signed)
+    }
+}
+
+impl MlsDecode for AuthenticatedContent {
+    /// Reads, in order:
+    /// 1. the wire format;
+    /// 2. the framed content;
+    /// 3. the auth data, whose layout depends on the decoded content's type.
+    fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
+        let wire_format = WireFormat::mls_decode(reader)?;
+        let content = FramedContent::mls_decode(reader)?;
+        let content_type = content.content_type();
+        let auth = FramedContentAuthData::mls_decode(reader, content_type)?;
+
+        Ok(Self {
+            wire_format,
+            content,
+            auth,
+        })
+    }
+}
+
+#[derive(Clone, Debug, PartialEq)]
+/// To-be-signed view of [`AuthenticatedContent`].
+///
+/// It is encoded as protocol version, wire format and framed content,
+/// optionally followed by the group context.
+pub(crate) struct AuthenticatedContentTBS<'a> {
+    pub(crate) protocol_version: ProtocolVersion,
+    pub(crate) wire_format: WireFormat,
+    pub(crate) content: &'a FramedContent,
+    // Group context; encoded only when `Some`.
+    pub(crate) context: Option<&'a GroupContext>,
+}
+
+impl<'a> MlsSize for AuthenticatedContentTBS<'a> {
+    /// Sum of the encoded field lengths. An absent group context adds nothing.
+    fn mls_encoded_len(&self) -> usize {
+        let context_len = self.context.map_or(0, |ctx| ctx.mls_encoded_len());
+
+        [
+            self.protocol_version.mls_encoded_len(),
+            self.wire_format.mls_encoded_len(),
+            self.content.mls_encoded_len(),
+            context_len,
+        ]
+        .iter()
+        .sum()
+    }
+}
+
+impl<'a> MlsEncode for AuthenticatedContentTBS<'a> {
+    /// Writes protocol version, wire format and content, then the group
+    /// context if one is present.
+    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+        self.protocol_version.mls_encode(writer)?;
+        self.wire_format.mls_encode(writer)?;
+        self.content.mls_encode(writer)?;
+
+        match self.context {
+            Some(group_context) => group_context.mls_encode(writer),
+            None => Ok(()),
+        }
+    }
+}
+
+impl<'a> AuthenticatedContentTBS<'a> {
+    /// Builds the to-be-signed structure for `auth_content`.
+    ///
+    /// Handling of `group_context` depends on the sender:
+    /// - `Sender::Member` and `Sender::NewMemberCommit`: the context is
+    ///   included and must not be `None`.
+    /// - `Sender::External` and `Sender::NewMemberProposal`: the argument is
+    ///   ignored and no context is encoded.
+    pub(crate) fn from_authenticated_content(
+        auth_content: &'a AuthenticatedContent,
+        group_context: Option<&'a GroupContext>,
+        protocol_version: ProtocolVersion,
+    ) -> Self {
+        AuthenticatedContentTBS {
+            protocol_version,
+            wire_format: auth_content.wire_format,
+            content: &auth_content.content,
+            context: match auth_content.content.sender {
+                Sender::Member(_) | Sender::NewMemberCommit => group_context,
+                #[cfg(feature = "by_ref_proposal")]
+                Sender::External(_) => None,
+                #[cfg(feature = "by_ref_proposal")]
+                Sender::NewMemberProposal => None,
+            },
+        }
+    }
+}
+
+#[derive(Debug)]
+/// Context used when signing or verifying [`AuthenticatedContent`].
+pub(crate) struct MessageSigningContext<'a> {
+    /// Group context for the to-be-signed data. It is only encoded for senders
+    /// that include one (see `AuthenticatedContentTBS::from_authenticated_content`).
+    pub group_context: Option<&'a GroupContext>,
+    /// Protocol version written at the start of the to-be-signed data.
+    pub protocol_version: ProtocolVersion,
+}
+
+impl<'a> Signable<'a> for AuthenticatedContent {
+    const SIGN_LABEL: &'static str = "FramedContentTBS";
+
+    type SigningContext = MessageSigningContext<'a>;
+
+    /// Signature bytes currently stored in the auth data.
+    fn signature(&self) -> &[u8] {
+        self.auth.signature.as_slice()
+    }
+
+    /// Encodes the `AuthenticatedContentTBS` view of this content for `context`.
+    fn signable_content(
+        &self,
+        context: &MessageSigningContext,
+    ) -> Result<Vec<u8>, mls_rs_codec::Error> {
+        let tbs = AuthenticatedContentTBS::from_authenticated_content(
+            self,
+            context.group_context,
+            context.protocol_version,
+        );
+
+        tbs.mls_encode_to_vec()
+    }
+
+    /// Stores `signature` in the auth data, replacing any previous value.
+    fn write_signature(&mut self, signature: Vec<u8>) {
+        self.auth.signature = signature.into();
+    }
+}
+
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// Raw signature bytes carried in [`FramedContentAuthData`].
+///
+/// Encoded with the MLS codec as a byte vector (`mls_rs_codec::byte_vec`).
+pub struct MessageSignature(
+    #[mls_codec(with = "mls_rs_codec::byte_vec")]
+    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+    Vec<u8>,
+);
+
+impl Debug for MessageSignature {
+    /// Formats the bytes with `pretty_bytes`, labelled `MessageSignature`.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let pretty = mls_rs_core::debug::pretty_bytes(&self.0).named("MessageSignature");
+        pretty.fmt(f)
+    }
+}
+
+impl MessageSignature {
+    /// Signature with no bytes. `AuthenticatedContent::new` uses it as a
+    /// placeholder before content is signed.
+    pub(crate) fn empty() -> Self {
+        Self(vec![])
+    }
+}
+
+impl Deref for MessageSignature {
+    type Target = Vec<u8>;
+
+    /// Exposes the underlying signature bytes.
+    fn deref(&self) -> &Self::Target {
+        let Self(bytes) = self;
+        bytes
+    }
+}
+
+impl From<Vec<u8>> for MessageSignature {
+    /// Wraps raw signature bytes as-is, without validation.
+    fn from(bytes: Vec<u8>) -> Self {
+        Self(bytes)
+    }
+}
diff --git a/src/group/message_verifier.rs b/src/group/message_verifier.rs
new file mode 100644
index 0000000..7a2bc59
--- /dev/null
+++ b/src/group/message_verifier.rs
@@ -0,0 +1,680 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+#[cfg(feature = "by_ref_proposal")]
+use alloc::{vec, vec::Vec};
+
+use crate::{
+ client::MlsError,
+ crypto::SignaturePublicKey,
+ group::{GroupContext, PublicMessage, Sender},
+ signer::Signable,
+ tree_kem::{node::LeafIndex, TreeKemPublic},
+ CipherSuiteProvider,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::{extension::ExternalSendersExt, identity::SigningIdentity};
+
+use super::{
+ key_schedule::KeySchedule,
+ message_signature::{AuthenticatedContent, MessageSigningContext},
+ state::GroupState,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use super::proposal::Proposal;
+
+#[derive(Debug)]
+/// Where member signature public keys are looked up during verification.
+pub(crate) enum SignaturePublicKeysContainer<'a> {
+    /// Keys come from the leaf nodes of a ratchet tree.
+    RatchetTree(&'a TreeKemPublic),
+    /// Keys come from a list indexed by leaf index. A missing or `None` entry
+    /// yields `MlsError::LeafNotFound`.
+    #[cfg(feature = "private_message")]
+    List(&'a [Option<SignaturePublicKey>]),
+}
+
+/// Authenticates a received `PublicMessage` and returns its content.
+///
+/// Checks by sender kind:
+/// - Member senders: the membership tag is checked when `key_schedule` is
+///   provided, and messages from `self_index` are rejected.
+/// - Non-member senders: the message must not carry a membership tag.
+///
+/// The signature is then verified with the sender's key, resolved from the
+/// group's current tree, the external senders extension, or the content.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn verify_plaintext_authentication<P: CipherSuiteProvider>(
+    cipher_suite_provider: &P,
+    plaintext: PublicMessage,
+    key_schedule: Option<&KeySchedule>,
+    self_index: Option<LeafIndex>,
+    state: &GroupState,
+) -> Result<AuthenticatedContent, MlsError> {
+    let membership_tag = plaintext.membership_tag.clone();
+    let auth_content = AuthenticatedContent::from(plaintext);
+    let context = &state.context;
+
+    #[cfg(feature = "by_ref_proposal")]
+    let external_signers = external_signers(context);
+
+    if let Sender::Member(index) = &auth_content.content.sender {
+        if let Some(key_schedule) = key_schedule {
+            let expected = key_schedule
+                .get_membership_tag(&auth_content, context, cipher_suite_provider)
+                .await?;
+
+            // A missing tag is rejected exactly like a mismatching one.
+            if membership_tag.as_ref() != Some(&expected) {
+                return Err(MlsError::InvalidMembershipTag);
+            }
+        }
+
+        if self_index == Some(LeafIndex(*index)) {
+            return Err(MlsError::CantProcessMessageFromSelf);
+        }
+    } else if membership_tag.is_some() {
+        return Err(MlsError::MembershipTagForNonMember);
+    }
+
+    let current_tree = &state.public_tree;
+
+    // The signature must verify with the public key that belongs to the
+    // sender named in the content.
+    verify_auth_content_signature(
+        cipher_suite_provider,
+        SignaturePublicKeysContainer::RatchetTree(current_tree),
+        context,
+        &auth_content,
+        #[cfg(feature = "by_ref_proposal")]
+        &external_signers,
+    )
+    .await?;
+
+    Ok(auth_content)
+}
+
+#[cfg(feature = "by_ref_proposal")]
+/// Signing identities listed in the group's `ExternalSendersExt` extension.
+///
+/// Returns an empty list when the extension is absent or cannot be read.
+fn external_signers(context: &GroupContext) -> Vec<SigningIdentity> {
+    match context.extensions.get_as::<ExternalSendersExt>() {
+        Ok(Some(ext)) => ext.allowed_senders,
+        Ok(None) | Err(_) => vec![],
+    }
+}
+
+/// Verifies the signature on `auth_content` with its sender's public key.
+///
+/// Where the key comes from:
+/// - members: `signature_keys_container`;
+/// - external senders: `external_signers`;
+/// - new-member senders: the content itself.
+///
+/// `context` is used as the signing context.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn verify_auth_content_signature<P: CipherSuiteProvider>(
+    cipher_suite_provider: &P,
+    signature_keys_container: SignaturePublicKeysContainer<'_>,
+    context: &GroupContext,
+    auth_content: &AuthenticatedContent,
+    #[cfg(feature = "by_ref_proposal")] external_signers: &[SigningIdentity],
+) -> Result<(), MlsError> {
+    let signing_context = MessageSigningContext {
+        group_context: Some(context),
+        protocol_version: context.protocol_version,
+    };
+
+    let framed = &auth_content.content;
+
+    let sender_public_key = signing_identity_for_sender(
+        signature_keys_container,
+        &framed.sender,
+        &framed.content,
+        #[cfg(feature = "by_ref_proposal")]
+        external_signers,
+    )?;
+
+    auth_content
+        .verify(cipher_suite_provider, &sender_public_key, &signing_context)
+        .await?;
+
+    Ok(())
+}
+
+/// Resolves the signature public key for `sender` by dispatching on the
+/// sender kind.
+fn signing_identity_for_sender(
+    signature_keys_container: SignaturePublicKeysContainer,
+    sender: &Sender,
+    content: &super::framing::Content,
+    #[cfg(feature = "by_ref_proposal")] external_signers: &[SigningIdentity],
+) -> Result<SignaturePublicKey, MlsError> {
+    match *sender {
+        Sender::NewMemberCommit => signing_identity_for_new_member_commit(content),
+        Sender::Member(index) => {
+            signing_identity_for_member(signature_keys_container, LeafIndex(index))
+        }
+        #[cfg(feature = "by_ref_proposal")]
+        Sender::NewMemberProposal => signing_identity_for_new_member_proposal(content),
+        #[cfg(feature = "by_ref_proposal")]
+        Sender::External(index) => signing_identity_for_external(index, external_signers),
+    }
+}
+
+/// Signature public key of the member at `leaf_index`.
+///
+/// Fails with the tree lookup error for a ratchet tree, or with
+/// `MlsError::LeafNotFound` when the list has no key at that index.
+fn signing_identity_for_member(
+    signature_keys_container: SignaturePublicKeysContainer,
+    leaf_index: LeafIndex,
+) -> Result<SignaturePublicKey, MlsError> {
+    match signature_keys_container {
+        SignaturePublicKeysContainer::RatchetTree(tree) => {
+            let leaf = tree.get_leaf_node(leaf_index)?;
+            Ok(leaf.signing_identity.signature_key.clone())
+        }
+        #[cfg(feature = "private_message")]
+        SignaturePublicKeysContainer::List(keys) => keys
+            .get(leaf_index.0 as usize)
+            .and_then(|entry| entry.clone())
+            .ok_or(MlsError::LeafNotFound(*leaf_index)),
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+/// Signature key of the external sender at position `index` in
+/// `external_signers`.
+fn signing_identity_for_external(
+    index: u32,
+    external_signers: &[SigningIdentity],
+) -> Result<SignaturePublicKey, MlsError> {
+    let Some(identity) = external_signers.get(index as usize) else {
+        return Err(MlsError::UnknownSigningIdentityForExternalSender);
+    };
+
+    Ok(identity.signature_key.clone())
+}
+
+/// Signature key for a new-member commit: the key in the leaf node of the
+/// commit's update path.
+///
+/// Fails if the content is not a commit or the commit has no path.
+fn signing_identity_for_new_member_commit(
+    content: &super::framing::Content,
+) -> Result<SignaturePublicKey, MlsError> {
+    match content {
+        super::framing::Content::Commit(commit) => commit
+            .path
+            .as_ref()
+            .map(|path| path.leaf_node.signing_identity.signature_key.clone())
+            .ok_or(MlsError::CommitMissingPath),
+        #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
+        _ => Err(MlsError::ExpectedCommitForNewMemberCommit),
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+/// Signature key for a new-member proposal, taken from the key package of its
+/// Add proposal.
+///
+/// Any other content or proposal type is rejected.
+fn signing_identity_for_new_member_proposal(
+    content: &super::framing::Content,
+) -> Result<SignaturePublicKey, MlsError> {
+    let super::framing::Content::Proposal(proposal) = content else {
+        return Err(MlsError::ExpectedAddProposalForNewMemberProposal);
+    };
+
+    match proposal.as_ref() {
+        Proposal::Add(add) => Ok(add
+            .key_package
+            .leaf_node
+            .signing_identity
+            .signature_key
+            .clone()),
+        _ => Err(MlsError::ExpectedAddProposalForNewMemberProposal),
+    }
+}
+
+#[cfg(test)]
+mod tests { // Tests for plaintext authentication and content-signature verification.
+ use crate::{
+ client::{
+ test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ MlsError,
+ },
+ client_builder::test_utils::TestClientConfig,
+ crypto::test_utils::test_cipher_suite_provider,
+ group::{
+ membership_tag::MembershipTag,
+ message_signature::{AuthenticatedContent, MessageSignature},
+ test_utils::{test_group_custom, TestGroup},
+ Group, PublicMessage,
+ },
+ tree_kem::node::LeafIndex,
+ };
+ use alloc::vec;
+ use assert_matches::assert_matches;
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::{extension::ExternalSendersExt, ExtensionList};
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::{
+ crypto::SignatureSecretKey,
+ group::{
+ message_signature::MessageSigningContext,
+ proposal::{AddProposal, Proposal, RemoveProposal},
+ Content,
+ },
+ key_package::KeyPackageGeneration,
+ signer::Signable,
+ WireFormat,
+ };
+
+ #[cfg(feature = "by_ref_proposal")]
+ use alloc::boxed::Box;
+
+ use crate::group::{
+ test_utils::{test_group, test_member},
+ Sender,
+ };
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::identity::test_utils::get_test_signing_identity;
+
+ use super::{verify_auth_content_signature, verify_plaintext_authentication};
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn make_signed_plaintext(group: &mut Group<TestClientConfig>) -> PublicMessage {
+ group
+ .commit(vec![]) // commit with no by-value proposals, returned as a PublicMessage
+ .await
+ .unwrap()
+ .commit_message
+ .into_plaintext()
+ .unwrap()
+ }
+
+ struct TestEnv {
+ alice: TestGroup,
+ bob: TestGroup,
+ }
+
+ impl TestEnv {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn new() -> Self { // alice creates a group and adds bob; bob joins via the Welcome
+ let mut alice = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ None,
+ )
+ .await;
+
+ let (bob_client, bob_key_pkg) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let commit_output = alice
+ .group
+ .commit_builder()
+ .add_member(bob_key_pkg)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ alice.group.apply_pending_commit().await.unwrap();
+
+ let (bob, _) = Group::join(
+ &commit_output.welcome_messages[0],
+ None,
+ bob_client.config,
+ bob_client.signer.unwrap(),
+ )
+ .await
+ .unwrap();
+
+ TestEnv {
+ alice,
+ bob: TestGroup { group: bob },
+ }
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn valid_plaintext_is_verified() {
+ let mut env = TestEnv::new().await;
+
+ let message = make_signed_plaintext(&mut env.alice.group).await;
+
+ verify_plaintext_authentication(
+ &env.bob.group.cipher_suite_provider,
+ message,
+ Some(&env.bob.group.key_schedule),
+ None,
+ &env.bob.group.state,
+ )
+ .await
+ .unwrap();
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn valid_auth_content_is_verified() {
+ let mut env = TestEnv::new().await;
+
+ let message = AuthenticatedContent::from(make_signed_plaintext(&mut env.alice.group).await);
+
+ verify_auth_content_signature(
+ &env.bob.group.cipher_suite_provider,
+ super::SignaturePublicKeysContainer::RatchetTree(&env.bob.group.state.public_tree),
+ env.bob.group.context(),
+ &message,
+ #[cfg(feature = "by_ref_proposal")]
+ &[],
+ )
+ .await
+ .unwrap();
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn invalid_plaintext_is_not_verified() {
+ let mut env = TestEnv::new().await;
+ let mut message = make_signed_plaintext(&mut env.alice.group).await;
+ message.auth.signature = MessageSignature::from(b"test".to_vec());
+
+ message.membership_tag = env // recompute the tag so only the signature is invalid
+ .alice
+ .group
+ .key_schedule
+ .get_membership_tag(
+ &AuthenticatedContent::from(message.clone()),
+ env.alice.group.context(),
+ &test_cipher_suite_provider(env.alice.group.cipher_suite()),
+ )
+ .await
+ .unwrap()
+ .into();
+
+ let res = verify_plaintext_authentication(
+ &env.bob.group.cipher_suite_provider,
+ message,
+ Some(&env.bob.group.key_schedule),
+ None,
+ &env.bob.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidSignature));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn plaintext_from_member_requires_membership_tag() {
+ let mut env = TestEnv::new().await;
+ let mut message = make_signed_plaintext(&mut env.alice.group).await;
+ message.membership_tag = None;
+
+ let res = verify_plaintext_authentication(
+ &env.bob.group.cipher_suite_provider,
+ message,
+ Some(&env.bob.group.key_schedule),
+ None,
+ &env.bob.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidMembershipTag));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn plaintext_fails_with_invalid_membership_tag() {
+ let mut env = TestEnv::new().await;
+ let mut message = make_signed_plaintext(&mut env.alice.group).await;
+ message.membership_tag = Some(MembershipTag::from(b"test".to_vec()));
+
+ let res = verify_plaintext_authentication(
+ &env.bob.group.cipher_suite_provider,
+ message,
+ Some(&env.bob.group.key_schedule),
+ None,
+ &env.bob.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidMembershipTag));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_new_member_proposal<F>(
+ key_pkg_gen: KeyPackageGeneration,
+ signer: &SignatureSecretKey,
+ test_group: &TestGroup,
+ mut edit: F,
+ ) -> PublicMessage
+ where
+ F: FnMut(&mut AuthenticatedContent),
+ {
+ let mut content = AuthenticatedContent::new_signed(
+ &test_group.group.cipher_suite_provider,
+ test_group.group.context(),
+ Sender::NewMemberProposal,
+ Content::Proposal(Box::new(Proposal::Add(Box::new(AddProposal {
+ key_package: key_pkg_gen.key_package,
+ })))),
+ signer,
+ WireFormat::PublicMessage,
+ vec![],
+ )
+ .await
+ .unwrap();
+
+ edit(&mut content); // mutate first; the content is re-signed below so the signature matches
+
+ let signing_context = MessageSigningContext {
+ group_context: Some(test_group.group.context()),
+ protocol_version: test_group.group.protocol_version(),
+ };
+
+ content
+ .sign(
+ &test_group.group.cipher_suite_provider,
+ signer,
+ &signing_context,
+ )
+ .await
+ .unwrap();
+
+ PublicMessage {
+ content: content.content,
+ auth: content.auth,
+ membership_tag: None,
+ }
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn valid_proposal_from_new_member_is_verified() {
+ let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (key_pkg_gen, signer) =
+ test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+ let message = test_new_member_proposal(key_pkg_gen, &signer, &test_group, |_| {}).await;
+
+ verify_plaintext_authentication(
+ &test_group.group.cipher_suite_provider,
+ message,
+ Some(&test_group.group.key_schedule),
+ None,
+ &test_group.group.state,
+ )
+ .await
+ .unwrap();
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn proposal_from_new_member_must_not_have_membership_tag() {
+ let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (key_pkg_gen, signer) =
+ test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+
+ let mut message = test_new_member_proposal(key_pkg_gen, &signer, &test_group, |_| {}).await;
+ message.membership_tag = Some(MembershipTag::from(vec![]));
+
+ let res = verify_plaintext_authentication(
+ &test_group.group.cipher_suite_provider,
+ message,
+ Some(&test_group.group.key_schedule),
+ None,
+ &test_group.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::MembershipTagForNonMember));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn new_member_proposal_sender_must_be_add_proposal() {
+ let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (key_pkg_gen, signer) =
+ test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+
+ let message = test_new_member_proposal(key_pkg_gen, &signer, &test_group, |msg| {
+ msg.content.content = Content::Proposal(Box::new(Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(0),
+ })))
+ })
+ .await;
+
+ let res: Result<AuthenticatedContent, MlsError> = verify_plaintext_authentication(
+ &test_group.group.cipher_suite_provider,
+ message,
+ Some(&test_group.group.key_schedule),
+ None,
+ &test_group.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::ExpectedAddProposalForNewMemberProposal));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn new_member_commit_must_be_external_commit() {
+ let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (key_pkg_gen, signer) =
+ test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+
+ let message = test_new_member_proposal(key_pkg_gen, &signer, &test_group, |msg| {
+ msg.content.sender = Sender::NewMemberCommit;
+ })
+ .await;
+
+ let res = verify_plaintext_authentication(
+ &test_group.group.cipher_suite_provider,
+ message,
+ Some(&test_group.group.key_schedule),
+ None,
+ &test_group.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::ExpectedCommitForNewMemberCommit));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn valid_proposal_from_external_is_verified() {
+ let (bob_key_pkg_gen, _) =
+ test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+
+ let (ted_signing, ted_secret) = get_test_signing_identity(TEST_CIPHER_SUITE, b"ted").await;
+
+ let mut test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let mut extensions = ExtensionList::default();
+
+ extensions
+ .set_from(ExternalSendersExt {
+ allowed_senders: vec![ted_signing],
+ })
+ .unwrap();
+
+ test_group
+ .group
+ .commit_builder()
+ .set_group_context_ext(extensions)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ test_group.group.apply_pending_commit().await.unwrap();
+
+ let message = test_new_member_proposal(bob_key_pkg_gen, &ted_secret, &test_group, |msg| {
+ msg.content.sender = Sender::External(0)
+ })
+ .await;
+
+ verify_plaintext_authentication(
+ &test_group.group.cipher_suite_provider,
+ message,
+ Some(&test_group.group.key_schedule),
+ None,
+ &test_group.group.state,
+ )
+ .await
+ .unwrap();
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_proposal_must_be_from_valid_sender() {
+ let (bob_key_pkg_gen, _) =
+ test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+ let (_, ted_secret) = get_test_signing_identity(TEST_CIPHER_SUITE, b"ted").await;
+ let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let message = test_new_member_proposal(bob_key_pkg_gen, &ted_secret, &test_group, |msg| {
+ msg.content.sender = Sender::External(0)
+ })
+ .await;
+
+ let res = verify_plaintext_authentication(
+ &test_group.group.cipher_suite_provider,
+ message,
+ Some(&test_group.group.key_schedule),
+ None,
+ &test_group.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::UnknownSigningIdentityForExternalSender));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn proposal_from_external_sender_must_not_have_membership_tag() {
+ let (bob_key_pkg_gen, _) =
+ test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+ let (_, ted_secret) = get_test_signing_identity(TEST_CIPHER_SUITE, b"ted").await;
+ let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let mut message =
+ test_new_member_proposal(bob_key_pkg_gen, &ted_secret, &test_group, |msg| {
+ msg.content.sender = Sender::External(0) // otherwise the sender stays NewMemberProposal
+ })
+ .await;
+
+ message.membership_tag = Some(MembershipTag::from(vec![]));
+
+ let res = verify_plaintext_authentication(
+ &test_group.group.cipher_suite_provider,
+ message,
+ Some(&test_group.group.key_schedule),
+ None,
+ &test_group.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::MembershipTagForNonMember));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn plaintext_from_self_fails_verification() {
+ let mut env = TestEnv::new().await;
+
+ let message = make_signed_plaintext(&mut env.alice.group).await;
+
+ let res = verify_plaintext_authentication(
+ &env.alice.group.cipher_suite_provider,
+ message,
+ Some(&env.alice.group.key_schedule),
+ Some(LeafIndex::new(env.alice.group.current_member_index())), // verify as alice herself
+ &env.alice.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::CantProcessMessageFromSelf))
+ }
+}
diff --git a/src/group/mls_rules.rs b/src/group/mls_rules.rs
new file mode 100644
index 0000000..98b1dac
--- /dev/null
+++ b/src/group/mls_rules.rs
@@ -0,0 +1,283 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::group::{proposal_filter::ProposalBundle, Roster};
+
+#[cfg(feature = "private_message")]
+use crate::{
+ group::{padding::PaddingMode, Sender},
+ WireFormat,
+};
+
+use alloc::boxed::Box;
+use core::convert::Infallible;
+use mls_rs_core::{
+ error::IntoAnyError, extension::ExtensionList, group::Member, identity::SigningIdentity,
+};
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub enum CommitDirection { // Tells `MlsRules` whether a commit is being prepared or received.
+ Send, // this client is preparing (sending) the commit
+ Receive, // a received commit is being processed
+}
+
+/// The source of the commit: either a current member or a new member joining
+/// via external commit.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum CommitSource {
+ ExistingMember(Member), // the committer is already a member of the group
+ NewMember(SigningIdentity), // an external joiner, identified by its signing identity
+}
+
+/// Options controlling commit generation, as returned by [`MlsRules::commit_options`].
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+#[non_exhaustive]
+pub struct CommitOptions {
+ pub path_required: bool, // enforce an update path even when MLS does not mandate one
+ pub ratchet_tree_extension: bool, // put the ratchet tree in the Welcome of commits adding members
+ pub single_welcome_message: bool, // one Welcome for all added members vs. one per member
+ pub allow_external_commit: bool, // NOTE(review): consumed outside this file; confirm exact effect
+}
+
+impl Default for CommitOptions {
+ fn default() -> Self {
+ Self { // on: tree extension + single Welcome; off: forced path + external commit
+ ratchet_tree_extension: true,
+ single_welcome_message: true,
+ path_required: false,
+ allow_external_commit: false,
+ }
+ }
+}
+
+impl CommitOptions {
+ pub fn new() -> Self { // equivalent to `CommitOptions::default()`
+ Self::default()
+ }
+
+ /// Set whether to enforce an update path even when MLS does not mandate one.
+ #[must_use = "returns a modified copy; the original value is unchanged"]
+ pub fn with_path_required(mut self, path_required: bool) -> Self {
+ self.path_required = path_required;
+ self
+ }
+
+ /// Set whether to include the ratchet tree in the Welcome of commits that add members.
+ #[must_use = "returns a modified copy; the original value is unchanged"]
+ pub fn with_ratchet_tree_extension(mut self, ratchet_tree_extension: bool) -> Self {
+ self.ratchet_tree_extension = ratchet_tree_extension;
+ self
+ }
+
+ /// Set whether to generate a single Welcome message instead of one per added member.
+ #[must_use = "returns a modified copy; the original value is unchanged"]
+ pub fn with_single_welcome_message(mut self, single_welcome_message: bool) -> Self {
+ self.single_welcome_message = single_welcome_message;
+ self
+ }
+
+ /// Set the `allow_external_commit` flag (presumably whether external commits are allowed).
+ #[must_use = "returns a modified copy; the original value is unchanged"]
+ pub fn with_allow_external_commit(mut self, allow_external_commit: bool) -> Self {
+ self.allow_external_commit = allow_external_commit;
+ self
+ }
+}
+
+/// Options controlling encryption of control and application messages.
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
+#[non_exhaustive]
+pub struct EncryptionOptions { // has no fields unless the "private_message" feature is on
+ #[cfg(feature = "private_message")]
+ pub encrypt_control_messages: bool, // when set, members send proposals/commits encrypted
+ #[cfg(feature = "private_message")]
+ pub padding_mode: PaddingMode, // padding used for encrypted messages
+}
+
+#[cfg(feature = "private_message")]
+impl EncryptionOptions {
+ pub fn new(encrypt_control_messages: bool, padding_mode: PaddingMode) -> Self {
+ Self {
+ padding_mode,
+ encrypt_control_messages,
+ }
+ }
+
+ pub(crate) fn control_wire_format(&self, sender: Sender) -> WireFormat {
+ match (self.encrypt_control_messages, sender) {
+ (true, Sender::Member(_)) => WireFormat::PrivateMessage, // only member senders encrypt
+ _ => WireFormat::PublicMessage, // encryption disabled, or a non-member sender
+ }
+ }
+}
+
+/// A set of user-controlled rules that customize the behavior of MLS.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+pub trait MlsRules: Send + Sync {
+ type Error: IntoAnyError;
+
+ /// Called when preparing or receiving a commit to pre-process the set of committed
+ /// proposals. `direction` tells which of the two is happening and `source` identifies
+ /// the committer.
+ ///
+ /// Both proposals received by reference during the current epoch and proposals included
+ /// by value in the commit are presented for validation and filtering. The bundle arrives
+ /// unvalidated; standard MLS rules are applied internally to the bundle returned here.
+ ///
+ /// Each member of a group MUST apply the same proposal rules in order to
+ /// maintain a working group.
+ ///
+ /// Typically, any invalid proposal should result in an error. The exceptions are invalid
+ /// by-reference proposals processed when _preparing_ a commit, which should be filtered
+ /// out instead. This avoids a deadlock in which no commit can be generated after
+ /// receiving an invalid set of proposal messages.
+ ///
+ /// `ProposalBundle` can be arbitrarily modified. For example, a Remove proposal that
+ /// removes a moderator can result in adding a GroupContextExtensions proposal that updates
+ /// the moderator list in the group context. The resulting `ProposalBundle` is validated
+ /// by the library.
+ async fn filter_proposals(
+ &self,
+ direction: CommitDirection,
+ source: CommitSource,
+ current_roster: &Roster,
+ extension_list: &ExtensionList,
+ proposals: ProposalBundle,
+ ) -> Result<ProposalBundle, Self::Error>;
+
+ /// Called when preparing a commit to determine its [`CommitOptions`]: whether to enforce an
+ /// update path in case it is not mandated by MLS, whether to include the ratchet tree in the
+ /// Welcome message (if the commit adds members), whether to generate a single Welcome message
+ /// or one per added member, and the `allow_external_commit` flag.
+ ///
+ /// The `new_roster` and `new_extension_list` describe the group state after the commit.
+ fn commit_options(
+ &self,
+ new_roster: &Roster,
+ new_extension_list: &ExtensionList,
+ proposals: &ProposalBundle,
+ ) -> Result<CommitOptions, Self::Error>;
+
+ /// Called when sending any message. For proposals and commits, this determines whether to
+ /// encrypt them. For any encrypted message, this determines the padding mode used.
+ ///
+ /// Note that for commits, the `current_roster` and `current_extension_list` describe the
+ /// group state before the commit, unlike in [`commit_options`](MlsRules::commit_options).
+ fn encryption_options(
+ &self,
+ current_roster: &Roster,
+ current_extension_list: &ExtensionList,
+ ) -> Result<EncryptionOptions, Self::Error>;
+}
+
+macro_rules! delegate_mls_rules {
+ ($wrapper:ty) => {
+ // Forwards every `MlsRules` method of `$wrapper` to the `T` it dereferences to.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl<T: MlsRules + ?Sized> MlsRules for $wrapper {
+ type Error = T::Error;
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn filter_proposals(
+ &self,
+ direction: CommitDirection,
+ source: CommitSource,
+ roster: &Roster,
+ extensions: &ExtensionList,
+ proposals: ProposalBundle,
+ ) -> Result<ProposalBundle, Self::Error> {
+ T::filter_proposals(self, direction, source, roster, extensions, proposals).await
+ }
+
+ fn commit_options(
+ &self,
+ new_roster: &Roster,
+ new_extension_list: &ExtensionList,
+ proposals: &ProposalBundle,
+ ) -> Result<CommitOptions, Self::Error> {
+ T::commit_options(self, new_roster, new_extension_list, proposals)
+ }
+
+ fn encryption_options(
+ &self,
+ current_roster: &Roster,
+ current_extension_list: &ExtensionList,
+ ) -> Result<EncryptionOptions, Self::Error> {
+ T::encryption_options(self, current_roster, current_extension_list)
+ }
+ }
+ };
+}
+
+// Boxed and borrowed rules (`T` may be unsized) can be used wherever `MlsRules` is expected.
+delegate_mls_rules!(Box<T>);
+delegate_mls_rules!(&T);
+
+/// Default MLS rules: proposals pass through unfiltered and the stored options are returned.
+#[derive(Clone, Debug, Default)]
+#[non_exhaustive]
+pub struct DefaultMlsRules {
+ pub commit_options: CommitOptions, // returned as-is by `MlsRules::commit_options`
+ pub encryption_options: EncryptionOptions, // returned as-is by `MlsRules::encryption_options`
+}
+
+impl DefaultMlsRules {
+ /// Create rules with `CommitOptions::default()` (update path not enforced, ratchet tree
+ /// put in the extension) and `EncryptionOptions::default()`.
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ /// Replace the commit options, keeping the current encryption options.
+ pub fn with_commit_options(self, commit_options: CommitOptions) -> Self {
+ Self {
+ commit_options,
+ ..self
+ }
+ }
+
+ /// Replace the encryption options, keeping the current commit options.
+ pub fn with_encryption_options(self, encryption_options: EncryptionOptions) -> Self {
+ Self {
+ encryption_options,
+ ..self
+ }
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+impl MlsRules for DefaultMlsRules {
+ type Error = Infallible; // these rules can never fail
+
+ async fn filter_proposals(
+ &self,
+ _direction: CommitDirection,
+ _source: CommitSource,
+ _current_roster: &Roster,
+ _extension_list: &ExtensionList,
+ proposals: ProposalBundle,
+ ) -> Result<ProposalBundle, Self::Error> {
+ Ok(proposals) // pass-through; the library still applies standard MLS validation
+ }
+
+ fn commit_options(
+ &self,
+ _new_roster: &Roster,
+ _new_extension_list: &ExtensionList,
+ _proposals: &ProposalBundle,
+ ) -> Result<CommitOptions, Self::Error> {
+ Ok(self.commit_options) // same options regardless of the resulting group state
+ }
+
+ fn encryption_options(
+ &self,
+ _current_roster: &Roster,
+ _current_extension_list: &ExtensionList,
+ ) -> Result<EncryptionOptions, Self::Error> {
+ Ok(self.encryption_options) // same options for every message
+ }
+}
diff --git a/src/group/mod.rs b/src/group/mod.rs
new file mode 100644
index 0000000..0d84a84
--- /dev/null
+++ b/src/group/mod.rs
@@ -0,0 +1,4236 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+use mls_rs_core::secret::Secret;
+use mls_rs_core::time::MlsTime;
+
+use crate::cipher_suite::CipherSuite;
+use crate::client::MlsError;
+use crate::client_config::ClientConfig;
+use crate::crypto::{HpkeCiphertext, SignatureSecretKey};
+use crate::extension::RatchetTreeExt;
+use crate::identity::SigningIdentity;
+use crate::key_package::{KeyPackage, KeyPackageRef};
+use crate::protocol_version::ProtocolVersion;
+use crate::psk::secret::PskSecret;
+use crate::psk::PreSharedKeyID;
+use crate::signer::Signable;
+use crate::tree_kem::hpke_encryption::HpkeEncryptable;
+use crate::tree_kem::kem::TreeKem;
+use crate::tree_kem::node::LeafIndex;
+use crate::tree_kem::path_secret::PathSecret;
+pub use crate::tree_kem::Capabilities;
+use crate::tree_kem::{
+ leaf_node::LeafNode,
+ leaf_node_validator::{LeafNodeValidator, ValidationContext},
+};
+use crate::tree_kem::{math as tree_math, ValidatedUpdatePath};
+use crate::tree_kem::{TreeKemPrivate, TreeKemPublic};
+use crate::{CipherSuiteProvider, CryptoProvider};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::crypto::{HpkePublicKey, HpkeSecretKey};
+
+use crate::extension::ExternalPubExt;
+
+#[cfg(feature = "private_message")]
+use self::mls_rules::{EncryptionOptions, MlsRules};
+
+#[cfg(feature = "psk")]
+pub use self::resumption::ReinitClient;
+
+#[cfg(feature = "psk")]
+use crate::psk::{
+ resolver::PskResolver, secret::PskSecretInput, ExternalPskId, JustPreSharedKeyID, PskGroupId,
+ ResumptionPSKUsage, ResumptionPsk,
+};
+
+#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+use std::collections::HashMap;
+
+#[cfg(feature = "private_message")]
+use ciphertext_processor::*;
+
+use confirmation_tag::*;
+use framing::*;
+use key_schedule::*;
+use membership_tag::*;
+use message_signature::*;
+use message_verifier::*;
+use proposal::*;
+#[cfg(feature = "by_ref_proposal")]
+use proposal_cache::*;
+use state::*;
+use transcript_hash::*;
+
+#[cfg(test)]
+pub(crate) use self::commit::test_utils::CommitModifiers;
+
+#[cfg(all(test, feature = "private_message"))]
+pub use self::framing::PrivateMessage;
+
+#[cfg(feature = "psk")]
+use self::proposal_filter::ProposalInfo;
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+use secret_tree::*;
+
+#[cfg(feature = "prior_epoch")]
+use self::epoch::PriorEpoch;
+
+use self::epoch::EpochSecrets;
+pub use self::message_processor::{
+ ApplicationMessageDescription, CommitMessageDescription, ProposalMessageDescription,
+ ProposalSender, ReceivedMessage, StateUpdate,
+};
+use self::message_processor::{EventOrContent, MessageProcessor, ProvisionalState};
+#[cfg(feature = "by_ref_proposal")]
+use self::proposal_ref::ProposalRef;
+use self::state_repo::GroupStateRepository;
+pub use group_info::GroupInfo;
+
+pub use self::framing::{ContentType, Sender};
+pub use commit::*;
+pub use context::GroupContext;
+pub use roster::*;
+
+pub(crate) use transcript_hash::ConfirmedTranscriptHash;
+pub(crate) use util::*;
+
+#[cfg(all(feature = "by_ref_proposal", feature = "external_client"))]
+pub use self::message_processor::CachedProposal;
+
+#[cfg(feature = "private_message")]
+mod ciphertext_processor;
+
+mod commit;
+pub(crate) mod confirmation_tag;
+mod context;
+pub(crate) mod epoch;
+pub(crate) mod framing;
+mod group_info;
+pub(crate) mod key_schedule;
+mod membership_tag;
+pub(crate) mod message_processor;
+pub(crate) mod message_signature;
+pub(crate) mod message_verifier;
+pub mod mls_rules;
+#[cfg(feature = "private_message")]
+pub(crate) mod padding;
+/// Proposals to evolve a MLS [`Group`]
+pub mod proposal;
+mod proposal_cache;
+pub(crate) mod proposal_filter;
+#[cfg(feature = "by_ref_proposal")]
+pub(crate) mod proposal_ref;
+#[cfg(feature = "psk")]
+mod resumption;
+mod roster;
+pub(crate) mod snapshot;
+pub(crate) mod state;
+
+#[cfg(feature = "prior_epoch")]
+pub(crate) mod state_repo;
+#[cfg(not(feature = "prior_epoch"))]
+pub(crate) mod state_repo_light;
+#[cfg(not(feature = "prior_epoch"))]
+pub(crate) use state_repo_light as state_repo;
+
+pub(crate) mod transcript_hash;
+mod util;
+
+/// External commit building.
+pub mod external_commit;
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+pub(crate) mod secret_tree;
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+pub use secret_tree::MessageKeyData as MessageKey;
+
+#[cfg(all(test, feature = "rfc_compliant"))]
+mod interop_test_vectors;
+
+mod exported_tree;
+
+pub use exported_tree::ExportedTree;
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+struct GroupSecrets { // plaintext sealed for a joiner inside a Welcome (HPKE label "Welcome")
+ joiner_secret: JoinerSecret, // NOTE(review): derived codec presumably encodes in field order
+ path_secret: Option<PathSecret>, // `None` when no path secret is included
+ psks: Vec<PreSharedKeyID>, // IDs of the pre-shared keys referenced by this Welcome
+}
+
+impl HpkeEncryptable for GroupSecrets {
+ const ENCRYPT_LABEL: &'static str = "Welcome"; // label for this type's HPKE encryption
+
+ fn from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError> {
+ Self::mls_decode(&mut &bytes[..]).map_err(Into::into)
+ }
+
+ fn get_bytes(&self) -> Result<Vec<u8>, MlsError> {
+ MlsEncode::mls_encode_to_vec(self).map_err(Into::into)
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub(crate) struct EncryptedGroupSecrets { // one entry of `Welcome::secrets`
+ pub new_member: KeyPackageRef, // recipient's key package ref (presumably matched on join)
+ pub encrypted_group_secrets: HpkeCiphertext, // presumably the HPKE-sealed `GroupSecrets`
+}
+
+#[derive(Clone, Eq, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub(crate) struct Welcome { // `Debug` is hand-written so the encrypted bytes use `pretty_bytes`
+ pub cipher_suite: CipherSuite, // selects the cipher suite provider when joining
+ pub secrets: Vec<EncryptedGroupSecrets>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub encrypted_group_info: Vec<u8>,
+}
+
+impl Debug for Welcome {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ // `encrypted_group_info` is rendered through `pretty_bytes` rather than as a raw
+ // `Vec<u8>` element list; the other fields use their own `Debug` impls.
+ let group_info = mls_rs_core::debug::pretty_bytes(&self.encrypted_group_info);
+ let mut out = f.debug_struct("Welcome");
+ out.field("cipher_suite", &self.cipher_suite);
+ out.field("secrets", &self.secrets);
+ out.field("encrypted_group_info", &group_info);
+ out.finish()
+ }
+}
+
+#[derive(Clone, Debug)]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[non_exhaustive]
+/// Information provided to new members upon joining a group from a Welcome message.
+pub struct NewMemberInfo {
+ /// Group info extensions found within the Welcome message used to join
+ /// the group (passed through `ungrease` when built via `NewMemberInfo::new`).
+ pub group_info_extensions: ExtensionList,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl NewMemberInfo {
+ /// Wrap the group info extensions of a Welcome, passing them through `ungrease` first.
+ pub(crate) fn new(group_info_extensions: ExtensionList) -> Self {
+ let mut info = Self {
+ group_info_extensions,
+ };
+ // NOTE(review): `ungrease` is defined elsewhere; presumably it strips GREASE entries.
+ info.ungrease();
+ info
+ }
+
+ /// Borrow the group info extensions carried by the Welcome message that was
+ /// used to join the group.
+ #[cfg(feature = "ffi")]
+ pub fn group_info_extensions(&self) -> &ExtensionList {
+ &self.group_info_extensions
+ }
+}
+
+/// An MLS end-to-end encrypted group.
+///
+/// # Group Evolution
+///
+/// MLS groups evolve via a propose-then-commit system. Each group state
+/// produced by a commit is called an epoch and can produce and consume
+/// application, proposal, and commit messages. A [commit](Group::commit) is used
+/// to advance to the next epoch by applying existing proposals sent in
+/// the current epoch by-reference along with an optional set of proposals
+/// that are included by-value using a [`CommitBuilder`].
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))]
+#[derive(Clone)]
+pub struct Group<C>
+where
+ C: ClientConfig,
+{
+ config: C,
+ cipher_suite_provider: <C::CryptoProvider as CryptoProvider>::CipherSuiteProvider,
+ state_repo: GroupStateRepository<C::GroupStateStorage, C::KeyPackageRepository>,
+ pub(crate) state: GroupState,
+ epoch_secrets: EpochSecrets,
+ private_tree: TreeKemPrivate,
+ key_schedule: KeySchedule,
+ #[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+ pending_updates: HashMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>, // leaf HPKE public key -> (HPKE secret key, optional signature key)
+ #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
+ pending_updates: Vec<(HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>))>, // no_std: same mapping as a list of pairs
+ pending_commit: Option<CommitGeneration>, // presumably the local commit awaiting `apply_pending_commit`
+ #[cfg(feature = "psk")]
+ previous_psk: Option<PskSecretInput>,
+ #[cfg(test)]
+ pub(crate) commit_modifiers: CommitModifiers,
+ pub(crate) signer: SignatureSecretKey, // this member's signing key (e.g. used to generate its leaf)
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl<C> Group<C>
+where
+ C: ClientConfig + Clone,
+{
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn new( // Creates a group whose tree holds only this client's new leaf.
+ config: C,
+ group_id: Option<Vec<u8>>,
+ cipher_suite: CipherSuite,
+ protocol_version: ProtocolVersion,
+ signing_identity: SigningIdentity,
+ group_context_extensions: ExtensionList,
+ signer: SignatureSecretKey,
+ ) -> Result<Self, MlsError> {
+ let cipher_suite_provider = cipher_suite_provider(config.crypto_provider(), cipher_suite)?;
+
+ let (leaf_node, leaf_node_secret) = LeafNode::generate(
+ &cipher_suite_provider,
+ config.leaf_properties(),
+ signing_identity,
+ &signer,
+ config.lifetime(),
+ )
+ .await?;
+
+ let identity_provider = config.identity_provider(); // reused for validation and tree derivation
+
+ let leaf_node_validator = LeafNodeValidator::new(
+ &cipher_suite_provider,
+ &identity_provider,
+ Some(&group_context_extensions),
+ );
+
+ leaf_node_validator
+ .check_if_valid(&leaf_node, ValidationContext::Add(None))
+ .await?;
+
+ let (mut public_tree, private_tree) = TreeKemPublic::derive(
+ leaf_node,
+ leaf_node_secret,
+ &identity_provider,
+ &group_context_extensions,
+ )
+ .await?;
+
+ let tree_hash = public_tree.tree_hash(&cipher_suite_provider).await?;
+
+ let group_id = group_id.map(Ok).unwrap_or_else(|| { // caller's id, else fresh random bytes
+ cipher_suite_provider
+ .random_bytes_vec(cipher_suite_provider.kdf_extract_size())
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ })?;
+
+ let context = GroupContext::new_group(
+ protocol_version,
+ cipher_suite,
+ group_id,
+ tree_hash,
+ group_context_extensions,
+ );
+
+ let state_repo = GroupStateRepository::new(
+ #[cfg(feature = "prior_epoch")]
+ context.group_id.clone(),
+ config.group_state_storage(),
+ config.key_package_repo(),
+ None,
+ )?;
+
+ let key_schedule_result = KeySchedule::from_random_epoch_secret(
+ &cipher_suite_provider,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ public_tree.total_leaf_count(),
+ )
+ .await?;
+
+ let confirmation_tag = ConfirmationTag::create(
+ &key_schedule_result.confirmation_key,
+ &vec![].into(), // presumably the empty confirmed transcript hash of the first epoch
+ &cipher_suite_provider,
+ )
+ .await?;
+
+ let interim_hash = InterimTranscriptHash::create(
+ &cipher_suite_provider,
+ &vec![].into(),
+ &confirmation_tag,
+ )
+ .await?;
+
+ Ok(Self {
+ config,
+ state: GroupState::new(context, public_tree, interim_hash, confirmation_tag),
+ private_tree,
+ key_schedule: key_schedule_result.key_schedule,
+ #[cfg(feature = "by_ref_proposal")]
+ pending_updates: Default::default(),
+ pending_commit: None,
+ #[cfg(test)]
+ commit_modifiers: Default::default(),
+ epoch_secrets: key_schedule_result.epoch_secrets,
+ state_repo,
+ cipher_suite_provider,
+ #[cfg(feature = "psk")]
+ previous_psk: None,
+ signer,
+ })
+ }
+
+    /// Join a group from a Welcome message, using only the PSKs resolvable from the
+    /// configured secret store (no externally supplied PSK).
+    ///
+    /// `tree_data` may supply the group's ratchet tree out of band.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn join(
+        welcome: &MlsMessage,
+        tree_data: Option<ExportedTree<'_>>,
+        config: C,
+        signer: SignatureSecretKey,
+    ) -> Result<(Self, NewMemberInfo), MlsError> {
+        // The plain join path never injects an additional PSK.
+        #[cfg(feature = "psk")]
+        let no_additional_psk: Option<PskSecretInput> = None;
+
+        Self::from_welcome_message(
+            welcome,
+            tree_data,
+            config,
+            signer,
+            #[cfg(feature = "psk")]
+            no_additional_psk,
+        )
+        .await
+    }
+
+    /// Join an existing group by processing a Welcome message.
+    ///
+    /// Steps performed, in order: check the protocol version, locate the key package the
+    /// Welcome was addressed to, decrypt the group secrets, compute the PSK secret,
+    /// decrypt and validate the `GroupInfo`, locate our own leaf, seed the private tree
+    /// and key schedule, and check the confirmation tag before handing off to
+    /// `join_with`.
+    ///
+    /// `additional_psk` (only with the `psk` feature) replaces normal PSK resolution with
+    /// a single caller-provided PSK; `join` always passes `None`.
+    /// NOTE(review): other callers (not visible here) presumably use it for
+    /// re-init/branch resumption flows — confirm.
+    ///
+    /// # Errors
+    ///
+    /// Returns `UnsupportedProtocolVersion`, `UnexpectedMessageType`,
+    /// `ProtocolVersionMismatch`, `UnexpectedPskId`, `WelcomeKeyPackageNotFound` or
+    /// `InvalidConfirmationTag` for the checks made directly here, and propagates errors
+    /// from the helpers it calls.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn from_welcome_message(
+        welcome: &MlsMessage,
+        tree_data: Option<ExportedTree<'_>>,
+        config: C,
+        signer: SignatureSecretKey,
+        #[cfg(feature = "psk")] additional_psk: Option<PskSecretInput>,
+    ) -> Result<(Self, NewMemberInfo), MlsError> {
+        let protocol_version = welcome.version;
+
+        // Refuse Welcomes for protocol versions this client's config does not support.
+        if !config.version_supported(protocol_version) {
+            return Err(MlsError::UnsupportedProtocolVersion(protocol_version));
+        }
+
+        // Only a Welcome payload can be processed here.
+        let MlsMessagePayload::Welcome(welcome) = &welcome.payload else {
+            return Err(MlsError::UnexpectedMessageType);
+        };
+
+        let cipher_suite_provider =
+            cipher_suite_provider(config.crypto_provider(), welcome.cipher_suite)?;
+
+        // Look up, in the key package repository, the key package generation this Welcome
+        // is addressed to, together with the matching encrypted group secrets entry.
+        let (encrypted_group_secrets, key_package_generation) =
+            find_key_package_generation(&config.key_package_repo(), &welcome.secrets).await?;
+
+        // The key package must carry the same protocol version as the Welcome message.
+        let key_package_version = key_package_generation.key_package.version;
+
+        if key_package_version != protocol_version {
+            return Err(MlsError::ProtocolVersionMismatch);
+        }
+
+        // Decrypt the encrypted_group_secrets using HPKE with the algorithms indicated by the
+        // cipher suite and the HPKE private key corresponding to the GroupSecrets. If a
+        // PreSharedKeyID is part of the GroupSecrets and the client is not in possession of
+        // the corresponding PSK, return an error
+        let group_secrets = GroupSecrets::decrypt(
+            &cipher_suite_provider,
+            &key_package_generation.init_secret_key,
+            &key_package_generation.key_package.hpke_init_key,
+            &welcome.encrypted_group_info,
+            &encrypted_group_secrets.encrypted_group_secrets,
+        )
+        .await?;
+
+        #[cfg(feature = "psk")]
+        let psk_secret = if let Some(psk) = additional_psk {
+            // A caller-supplied PSK is accepted only when the first PSK ID listed in the
+            // GroupSecrets is a resumption PSK whose usage is not `Application`. The
+            // supplied PSK adopts that entry's nonce and is the only PSK mixed in.
+            let psk_id = group_secrets
+                .psks
+                .first()
+                .ok_or(MlsError::UnexpectedPskId)?;
+
+            match &psk_id.key_id {
+                JustPreSharedKeyID::Resumption(r) if r.usage != ResumptionPSKUsage::Application => {
+                    Ok(())
+                }
+                _ => Err(MlsError::UnexpectedPskId),
+            }?;
+
+            let mut psk = psk;
+            psk.id.psk_nonce = psk_id.psk_nonce.clone();
+            PskSecret::calculate(&[psk], &cipher_suite_provider).await?
+        } else {
+            // Otherwise resolve every PSK listed in the GroupSecrets through the configured
+            // secret store. A joiner has no group context or epoch history yet, so those
+            // sources are left empty.
+            PskResolver::<
+                <C as ClientConfig>::GroupStateStorage,
+                <C as ClientConfig>::KeyPackageRepository,
+                <C as ClientConfig>::PskStore,
+            > {
+                group_context: None,
+                current_epoch: None,
+                prior_epochs: None,
+                psk_store: &config.secret_store(),
+            }
+            .resolve_to_secret(&group_secrets.psks, &cipher_suite_provider)
+            .await?
+        };
+
+        // Without PSK support, use the provider's default PSK secret
+        // (NOTE(review): presumably the all-zero value — confirm in `PskSecret::new`).
+        #[cfg(not(feature = "psk"))]
+        let psk_secret = PskSecret::new(&cipher_suite_provider);
+
+        // From the joiner_secret in the decrypted GroupSecrets object and the PSKs specified in
+        // the GroupSecrets, derive the welcome_secret and using that the welcome_key and
+        // welcome_nonce.
+        let welcome_secret = WelcomeSecret::from_joiner_secret(
+            &cipher_suite_provider,
+            &group_secrets.joiner_secret,
+            &psk_secret,
+        )
+        .await?;
+
+        // Use the key and nonce to decrypt the encrypted_group_info field.
+        let decrypted_group_info = welcome_secret
+            .decrypt(&welcome.encrypted_group_info)
+            .await?;
+
+        let group_info = GroupInfo::mls_decode(&mut &**decrypted_group_info)?;
+
+        // Validate the GroupInfo and obtain the public ratchet tree, from `tree_data` when
+        // provided (NOTE(review): presumably falling back to a ratchet tree extension in the
+        // GroupInfo otherwise — confirm in `validate_group_info_joiner`).
+        let public_tree = validate_group_info_joiner(
+            protocol_version,
+            &group_info,
+            tree_data,
+            &config.identity_provider(),
+            &cipher_suite_provider,
+        )
+        .await?;
+
+        // Identify a leaf in the tree array (any even-numbered node) whose leaf_node is identical
+        // to the leaf_node field of the KeyPackage. If no such field exists, return an error. Let
+        // index represent the index of this node among the leaves in the tree, namely the index of
+        // the node in the tree array divided by two.
+        let self_index = public_tree
+            .find_leaf_node(&key_package_generation.key_package.leaf_node)
+            .ok_or(MlsError::WelcomeKeyPackageNotFound)?;
+
+        // Remember which key package was consumed; it is handed to the state repository
+        // in `join_with`.
+        let used_key_package_ref = key_package_generation.reference;
+
+        // Seed the private tree with only our own leaf secret from the key package.
+        let mut private_tree =
+            TreeKemPrivate::new_self_leaf(self_index, key_package_generation.leaf_node_secret_key);
+
+        // If the path_secret value is set in the GroupSecrets object, derive the path
+        // secrets shared with the GroupInfo signer (the committer).
+        if let Some(path_secret) = group_secrets.path_secret {
+            private_tree
+                .update_secrets(
+                    &cipher_suite_provider,
+                    group_info.signer,
+                    path_secret,
+                    &public_tree,
+                )
+                .await?;
+        }
+
+        // Use the joiner_secret from the GroupSecrets object to generate the epoch secret and
+        // other derived secrets for the current epoch.
+        let key_schedule_result = KeySchedule::from_joiner(
+            &cipher_suite_provider,
+            &group_secrets.joiner_secret,
+            &group_info.group_context,
+            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+            public_tree.total_leaf_count(),
+            &psk_secret,
+        )
+        .await?;
+
+        // Verify the confirmation tag in the GroupInfo using the derived confirmation key and the
+        // confirmed_transcript_hash from the GroupInfo.
+        if !group_info
+            .confirmation_tag
+            .matches(
+                &key_schedule_result.confirmation_key,
+                &group_info.group_context.confirmed_transcript_hash,
+                &cipher_suite_provider,
+            )
+            .await?
+        {
+            return Err(MlsError::InvalidConfirmationTag);
+        }
+
+        Self::join_with(
+            config,
+            group_info,
+            public_tree,
+            key_schedule_result.key_schedule,
+            key_schedule_result.epoch_secrets,
+            private_tree,
+            Some(used_key_package_ref),
+            signer,
+        )
+        .await
+    }
+
+    /// Assemble a joined `Group` from already-validated join material.
+    ///
+    /// Computes the interim transcript hash from the GroupInfo's confirmed transcript
+    /// hash and confirmation tag, builds the group state repository (recording
+    /// `used_key_package_ref`), and returns the group together with a `NewMemberInfo`
+    /// carrying the GroupInfo extensions.
+    ///
+    /// # Errors
+    ///
+    /// Returns `UnsupportedCipherSuite` if the GroupInfo's cipher suite is unavailable
+    /// from the configured crypto provider, and propagates hashing/repository errors.
+    #[allow(clippy::too_many_arguments)]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn join_with(
+        config: C,
+        group_info: GroupInfo,
+        public_tree: TreeKemPublic,
+        key_schedule: KeySchedule,
+        epoch_secrets: EpochSecrets,
+        private_tree: TreeKemPrivate,
+        used_key_package_ref: Option<KeyPackageRef>,
+        signer: SignatureSecretKey,
+    ) -> Result<(Self, NewMemberInfo), MlsError> {
+        let suite = group_info.group_context.cipher_suite;
+
+        let provider = match config.crypto_provider().cipher_suite_provider(suite) {
+            Some(found) => found,
+            None => return Err(MlsError::UnsupportedCipherSuite(suite)),
+        };
+
+        // The interim transcript hash for the new state is derived from the confirmed
+        // transcript hash and the confirmation tag.
+        let interim_transcript_hash = InterimTranscriptHash::create(
+            &provider,
+            &group_info.group_context.confirmed_transcript_hash,
+            &group_info.confirmation_tag,
+        )
+        .await?;
+
+        let state_repo = GroupStateRepository::new(
+            #[cfg(feature = "prior_epoch")]
+            group_info.group_context.group_id.clone(),
+            config.group_state_storage(),
+            config.key_package_repo(),
+            used_key_package_ref,
+        )?;
+
+        let state = GroupState::new(
+            group_info.group_context,
+            public_tree,
+            interim_transcript_hash,
+            group_info.confirmation_tag,
+        );
+
+        let joined = Group {
+            config,
+            state,
+            private_tree,
+            key_schedule,
+            #[cfg(feature = "by_ref_proposal")]
+            pending_updates: Default::default(),
+            pending_commit: None,
+            #[cfg(test)]
+            commit_modifiers: Default::default(),
+            epoch_secrets,
+            state_repo,
+            cipher_suite_provider: provider,
+            #[cfg(feature = "psk")]
+            previous_psk: None,
+            signer,
+        };
+
+        let new_member_info = NewMemberInfo::new(group_info.extensions);
+
+        Ok((joined, new_member_info))
+    }
+
+    /// Public ratchet tree of the current epoch.
+    #[inline(always)]
+    pub(crate) fn current_epoch_tree(&self) -> &TreeKemPublic {
+        let state = &self.state;
+        &state.public_tree
+    }
+
+    /// The current epoch of the group.
+    ///
+    /// This value is incremented each time a [`Group::commit`] message is processed.
+    #[inline(always)]
+    pub fn current_epoch(&self) -> u64 {
+        let context = self.context();
+        context.epoch
+    }
+
+    /// Leaf index of the local group instance within the group's state.
+    ///
+    /// This matches the indexes used in content descriptions within
+    /// [`ReceivedMessage`].
+    #[inline(always)]
+    pub fn current_member_index(&self) -> u32 {
+        let LeafIndex(index) = self.private_tree.self_index;
+        index
+    }
+
+    /// Our own leaf node in the current epoch's tree.
+    fn current_user_leaf_node(&self) -> Result<&LeafNode, MlsError> {
+        let own_index = self.private_tree.self_index;
+        self.current_epoch_tree().get_leaf_node(own_index)
+    }
+
+    /// Signing identity currently in use by the local group instance.
+    ///
+    /// This is read from our own leaf node in the current epoch's tree.
+    #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn current_member_signing_identity(&self) -> Result<&SigningIdentity, MlsError> {
+        let own_leaf = self.current_user_leaf_node()?;
+        Ok(&own_leaf.signing_identity)
+    }
+
+    /// Member at a specific leaf index in the group state.
+    ///
+    /// Returns `None` when the current tree has no leaf node at `index`. These indexes
+    /// correspond to indexes in content descriptions within [`ReceivedMessage`].
+    pub fn member_at_index(&self, index: u32) -> Option<Member> {
+        let leaf_index = LeafIndex(index);
+        let leaf_node = self.current_epoch_tree().get_leaf_node(leaf_index).ok()?;
+
+        Some(member_from_leaf_node(leaf_node, leaf_index))
+    }
+
+    /// Sign `proposal` as a by-reference proposal from the local member, cache it for
+    /// the next commit, and frame it for the wire.
+    ///
+    /// With the `private_message` feature the wire format comes from the group's
+    /// encryption options; otherwise the proposal is always a `PublicMessage`.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn proposal_message(
+        &mut self,
+        proposal: Proposal,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let sender = Sender::Member(*self.private_tree.self_index);
+
+        #[cfg(feature = "private_message")]
+        let wire_format = self.encryption_options()?.control_wire_format(sender);
+
+        #[cfg(not(feature = "private_message"))]
+        let wire_format = WireFormat::PublicMessage;
+
+        // The signed content gets a clone; the original is kept for the proposal cache.
+        let content = Content::Proposal(alloc::boxed::Box::new(proposal.clone()));
+
+        let auth_content = AuthenticatedContent::new_signed(
+            &self.cipher_suite_provider,
+            self.context(),
+            sender,
+            content,
+            &self.signer,
+            wire_format,
+            authenticated_data,
+        )
+        .await?;
+
+        let proposal_ref =
+            ProposalRef::from_content(&self.cipher_suite_provider, &auth_content).await?;
+
+        let cached_sender = auth_content.content.sender;
+
+        self.state
+            .proposals
+            .insert(proposal_ref, proposal, cached_sender);
+
+        self.format_for_wire(auth_content).await
+    }
+
+    /// Unique identifier for this group, as stored in the group context.
+    pub fn group_id(&self) -> &[u8] {
+        let context = self.context();
+        &context.group_id
+    }
+
+    /// Compute the private tree we will hold once `provisional_state` is applied.
+    ///
+    /// Starts from a clone of the current private tree and clears the secret keys of
+    /// nodes on our direct path that are blank in the provisional public tree. If one of
+    /// our own pending Update proposals was applied, it also installs the matching leaf
+    /// secret from `pending_updates`.
+    ///
+    /// Returns the provisional private tree plus the replacement signer stored with that
+    /// pending update, if any.
+    ///
+    /// # Errors
+    ///
+    /// Returns `UpdateErrorNoSecretKey` if an applied update sent by us has no matching
+    /// entry in `pending_updates`, and propagates errors from `is_blank`.
+    fn provisional_private_tree(
+        &self,
+        provisional_state: &ProvisionalState,
+    ) -> Result<(TreeKemPrivate, Option<SignatureSecretKey>), MlsError> {
+        let mut provisional_private_tree = self.private_tree.clone();
+        let self_index = provisional_private_tree.self_index;
+
+        // Remove secret keys for blanked nodes
+        let path = provisional_state
+            .public_tree
+            .nodes
+            .direct_copath(self_index);
+
+        // Size the key slots to one more than the direct-path length; slot `i + 1` is
+        // paired with the `i`-th path entry in the loop below.
+        provisional_private_tree
+            .secret_keys
+            .resize(path.len() + 1, None);
+
+        for (i, n) in path.iter().enumerate() {
+            if provisional_state.public_tree.nodes.is_blank(n.path)? {
+                provisional_private_tree.secret_keys[i + 1] = None;
+            }
+        }
+
+        // Apply own update
+        let new_signer = None;
+
+        // Re-bind as mutable only when the loop below (which may assign it) is compiled in.
+        #[cfg(feature = "by_ref_proposal")]
+        let mut new_signer = new_signer;
+
+        #[cfg(feature = "by_ref_proposal")]
+        for p in &provisional_state.applied_proposals.updates {
+            if p.sender == Sender::Member(*self_index) {
+                let leaf_pk = &p.proposal.leaf_node.public_key;
+
+                // Update the leaf in the private tree if this is our update
+                #[cfg(feature = "std")]
+                let new_leaf_sk_and_signer = self.pending_updates.get(leaf_pk);
+
+                // Without `std`, `pending_updates` is searched linearly by public key.
+                #[cfg(not(feature = "std"))]
+                let new_leaf_sk_and_signer = self
+                    .pending_updates
+                    .iter()
+                    .find_map(|(pk, sk)| (pk == leaf_pk).then_some(sk));
+
+                let new_leaf_sk = new_leaf_sk_and_signer.map(|(sk, _)| sk.clone());
+                new_signer = new_leaf_sk_and_signer.and_then(|(_, sk)| sk.clone());
+
+                provisional_private_tree
+                    .update_leaf(new_leaf_sk.ok_or(MlsError::UpdateErrorNoSecretKey)?);
+
+                // Only the first applied update sent by us is considered.
+                break;
+            }
+        }
+
+        Ok((provisional_private_tree, new_signer))
+    }
+
+    /// Build the `EncryptedGroupSecrets` entry of a Welcome for one new member.
+    ///
+    /// * `key_package` – the joiner's key package; its HPKE init key is the encryption
+    ///   target and its reference identifies the entry.
+    /// * `leaf_index` – the leaf the joiner was placed at.
+    /// * `path_secrets` – when present, the single path secret for the lowest common
+    ///   ancestor of our leaf and `leaf_index` is selected from it (at position
+    ///   `lca_level - 1`).
+    /// * `psks` – PSK IDs to include (empty without the `psk` feature).
+    /// * `encrypted_group_info` – passed through to the HPKE encryption call.
+    ///
+    /// # Errors
+    ///
+    /// Returns `InvalidTreeKemPrivateKey` when `path_secrets` is provided but has no
+    /// secret at the required position, and propagates encryption/hash errors.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn encrypt_group_secrets(
+        &self,
+        key_package: &KeyPackage,
+        leaf_index: LeafIndex,
+        joiner_secret: &JoinerSecret,
+        path_secrets: Option<&Vec<Option<PathSecret>>>,
+        #[cfg(feature = "psk")] psks: Vec<PreSharedKeyID>,
+        encrypted_group_info: &[u8],
+    ) -> Result<EncryptedGroupSecrets, MlsError> {
+        let path_secret = path_secrets
+            .map(|secrets| {
+                let lca_level =
+                    tree_math::leaf_lca_level(*self.private_tree.self_index, *leaf_index) as usize;
+
+                // `checked_sub` turns a zero level into the same error as a missing
+                // secret. A plain `- 1` would panic in debug builds and wrap around in
+                // release builds.
+                lca_level
+                    .checked_sub(1)
+                    .and_then(|position| secrets.get(position))
+                    .cloned()
+                    .flatten()
+                    .ok_or(MlsError::InvalidTreeKemPrivateKey)
+            })
+            .transpose()?;
+
+        #[cfg(not(feature = "psk"))]
+        let psks = Vec::new();
+
+        let group_secrets = GroupSecrets {
+            joiner_secret: joiner_secret.clone(),
+            path_secret,
+            psks,
+        };
+
+        let encrypted_group_secrets = group_secrets
+            .encrypt(
+                &self.cipher_suite_provider,
+                &key_package.hpke_init_key,
+                encrypted_group_info,
+            )
+            .await?;
+
+        Ok(EncryptedGroupSecrets {
+            new_member: key_package
+                .to_reference(&self.cipher_suite_provider)
+                .await?,
+            encrypted_group_secrets,
+        })
+    }
+
+    /// Create a proposal message that adds a new member to the group.
+    ///
+    /// `key_package` must be an `MlsMessage` carrying a key package.
+    /// `authenticated_data` will be sent unencrypted along with the contents
+    /// of the proposal message.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_add(
+        &mut self,
+        key_package: MlsMessage,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let add = self.add_proposal(key_package)?;
+
+        self.proposal_message(add, authenticated_data).await
+    }
+
+    /// Wrap the key package carried by `key_package` in an Add proposal.
+    ///
+    /// Fails with `UnexpectedMessageType` if the message holds no key package.
+    fn add_proposal(&self, key_package: MlsMessage) -> Result<Proposal, MlsError> {
+        let Some(key_package) = key_package.into_key_package() else {
+            return Err(MlsError::UnexpectedMessageType);
+        };
+
+        Ok(Proposal::Add(alloc::boxed::Box::new(AddProposal { key_package })))
+    }
+
+    /// Create a proposal message that updates your own public keys.
+    ///
+    /// This proposal is useful for contributing additional forward secrecy
+    /// and post-compromise security to the group without having to perform
+    /// the necessary computation of a [`Group::commit`]. The current signer and
+    /// signing identity are kept.
+    ///
+    /// `authenticated_data` will be sent unencrypted along with the contents
+    /// of the proposal message.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_update(
+        &mut self,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let update = self.update_proposal(None, None).await?;
+
+        self.proposal_message(update, authenticated_data).await
+    }
+
+    /// Create a proposal message that updates your own public keys
+    /// as well as your credential.
+    ///
+    /// This proposal is useful for contributing additional forward secrecy
+    /// and post-compromise security to the group without having to perform
+    /// the necessary computation of a [`Group::commit`].
+    ///
+    /// Identity updates are allowed by the group by default assuming that the
+    /// new identity provided is considered
+    /// [valid](crate::IdentityProvider::validate_member)
+    /// by and matches the output of the
+    /// [identity](crate::IdentityProvider)
+    /// function of the current
+    /// [`IdentityProvider`](crate::IdentityProvider).
+    ///
+    /// `authenticated_data` will be sent unencrypted along with the contents
+    /// of the proposal message.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_update_with_identity(
+        &mut self,
+        signer: SignatureSecretKey,
+        signing_identity: SigningIdentity,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let new_signer = Some(signer);
+        let new_identity = Some(signing_identity);
+
+        let update = self.update_proposal(new_signer, new_identity).await?;
+
+        self.proposal_message(update, authenticated_data).await
+    }
+
+    /// Build an Update proposal with freshly generated key material for our leaf.
+    ///
+    /// The new leaf is signed with `signer` when given, otherwise with the group's
+    /// current signer. `signing_identity` optionally replaces the leaf's identity. The
+    /// new leaf secret and the optional signer are stored in `pending_updates`, keyed
+    /// by the new leaf's public key, until the update is committed.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn update_proposal(
+        &mut self,
+        signer: Option<SignatureSecretKey>,
+        signing_identity: Option<SigningIdentity>,
+    ) -> Result<Proposal, MlsError> {
+        // Start from a copy of our current leaf and re-key it.
+        let mut leaf_node = self.current_user_leaf_node()?.clone();
+        let signing_key = signer.as_ref().unwrap_or(&self.signer);
+
+        let leaf_secret = leaf_node
+            .update(
+                &self.cipher_suite_provider,
+                self.group_id(),
+                self.current_member_index(),
+                self.config.leaf_properties(),
+                signing_identity,
+                signing_key,
+            )
+            .await?;
+
+        let public_key = leaf_node.public_key.clone();
+
+        #[cfg(feature = "std")]
+        self.pending_updates
+            .insert(public_key, (leaf_secret, signer));
+
+        #[cfg(not(feature = "std"))]
+        self.pending_updates
+            .push((public_key, (leaf_secret, signer)));
+
+        Ok(Proposal::Update(UpdateProposal { leaf_node }))
+    }
+
+    /// Create a proposal message that removes the member at leaf `index` from the
+    /// group.
+    ///
+    /// `authenticated_data` will be sent unencrypted along with the contents
+    /// of the proposal message.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_remove(
+        &mut self,
+        index: u32,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let removal = self.remove_proposal(index)?;
+
+        self.proposal_message(removal, authenticated_data).await
+    }
+
+    /// Build a Remove proposal for leaf `index`.
+    ///
+    /// Fails with the tree lookup error if the current tree has no leaf at `index`.
+    fn remove_proposal(&self, index: u32) -> Result<Proposal, MlsError> {
+        let to_remove = LeafIndex(index);
+
+        // Only the existence check matters here; the leaf itself is not needed.
+        let _ = self.current_epoch_tree().get_leaf_node(to_remove)?;
+
+        Ok(Proposal::Remove(RemoveProposal { to_remove }))
+    }
+
+    /// Create a proposal message that adds an external pre-shared key to the group.
+    ///
+    /// Each group member will need to have the PSK associated with
+    /// [`ExternalPskId`](mls_rs_core::psk::ExternalPskId) installed within
+    /// the [`PreSharedKeyStorage`](mls_rs_core::psk::PreSharedKeyStorage)
+    /// in use by this group upon processing a [commit](Group::commit) that
+    /// contains this proposal.
+    ///
+    /// `authenticated_data` will be sent unencrypted along with the contents
+    /// of the proposal message.
+    #[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_external_psk(
+        &mut self,
+        psk: ExternalPskId,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let key_id = JustPreSharedKeyID::External(psk);
+        let psk_proposal = self.psk_proposal(key_id)?;
+
+        self.proposal_message(psk_proposal, authenticated_data).await
+    }
+
+    /// Build a PreSharedKey proposal for `key_id`, creating the PSK ID via
+    /// `PreSharedKeyID::new` with this group's cipher suite provider.
+    #[cfg(feature = "psk")]
+    fn psk_proposal(&self, key_id: JustPreSharedKeyID) -> Result<Proposal, MlsError> {
+        let psk = PreSharedKeyID::new(key_id, &self.cipher_suite_provider)?;
+
+        Ok(Proposal::Psk(PreSharedKeyProposal { psk }))
+    }
+
+    /// Create a proposal message that adds a pre-shared key from a previous
+    /// epoch of this group to the current group state.
+    ///
+    /// Each group member will need to have the secret state from `psk_epoch`.
+    /// In particular, the members who joined between `psk_epoch` and the
+    /// current epoch cannot process a commit containing this proposal.
+    ///
+    /// `authenticated_data` will be sent unencrypted along with the contents
+    /// of the proposal message.
+    #[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_resumption_psk(
+        &mut self,
+        psk_epoch: u64,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let psk_group_id = PskGroupId(self.group_id().to_vec());
+
+        let resumption = ResumptionPsk {
+            psk_epoch,
+            usage: ResumptionPSKUsage::Application,
+            psk_group_id,
+        };
+
+        let psk_proposal = self.psk_proposal(JustPreSharedKeyID::Resumption(resumption))?;
+
+        self.proposal_message(psk_proposal, authenticated_data).await
+    }
+
+    /// Create a proposal message that requests for this group to be
+    /// reinitialized.
+    ///
+    /// When `group_id` is `None`, a random group id is generated. Once a
+    /// [`ReInitProposal`](proposal::ReInitProposal) has been sent, another group
+    /// member can complete reinitialization of the group by calling
+    /// [`Group::get_reinit_client`].
+    ///
+    /// `authenticated_data` will be sent unencrypted along with the contents
+    /// of the proposal message.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_reinit(
+        &mut self,
+        group_id: Option<Vec<u8>>,
+        version: ProtocolVersion,
+        cipher_suite: CipherSuite,
+        extensions: ExtensionList,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let reinit = self.reinit_proposal(group_id, version, cipher_suite, extensions)?;
+
+        self.proposal_message(reinit, authenticated_data).await
+    }
+
+    /// Build a ReInit proposal, generating a random group id of `kdf_extract_size()`
+    /// bytes when `group_id` is `None`.
+    fn reinit_proposal(
+        &self,
+        group_id: Option<Vec<u8>>,
+        version: ProtocolVersion,
+        cipher_suite: CipherSuite,
+        extensions: ExtensionList,
+    ) -> Result<Proposal, MlsError> {
+        let group_id = match group_id {
+            Some(requested) => requested,
+            None => self
+                .cipher_suite_provider
+                .random_bytes_vec(self.cipher_suite_provider.kdf_extract_size())
+                .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?,
+        };
+
+        let reinit = ReInitProposal {
+            group_id,
+            version,
+            cipher_suite,
+            extensions,
+        };
+
+        Ok(Proposal::ReInit(reinit))
+    }
+
+    /// Create a proposal message that sets extensions stored in the group
+    /// state.
+    ///
+    /// # Warning
+    ///
+    /// This function does not create a diff that will be applied to the
+    /// current set of extensions that are in use. In order for an existing
+    /// extension to not be overwritten by this proposal, it must be included
+    /// in the new set of extensions being proposed.
+    ///
+    /// `authenticated_data` will be sent unencrypted along with the contents
+    /// of the proposal message.
+    #[cfg(feature = "by_ref_proposal")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_group_context_extensions(
+        &mut self,
+        extensions: ExtensionList,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let gce = self.group_context_extensions_proposal(extensions);
+
+        self.proposal_message(gce, authenticated_data).await
+    }
+
+    /// Wrap `extensions` in a GroupContextExtensions proposal (a full replacement set,
+    /// not a diff).
+    fn group_context_extensions_proposal(&self, extensions: ExtensionList) -> Proposal {
+        Proposal::GroupContextExtensions(extensions)
+    }
+
+    /// Create a proposal message carrying an application-defined custom proposal.
+    ///
+    /// `authenticated_data` will be sent unencrypted along with the contents
+    /// of the proposal message.
+    #[cfg(all(feature = "custom_proposal", feature = "by_ref_proposal"))]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn propose_custom(
+        &mut self,
+        proposal: CustomProposal,
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        let custom = Proposal::Custom(proposal);
+
+        self.proposal_message(custom, authenticated_data).await
+    }
+
+    /// Delete all sent and received proposals cached for commit.
+    ///
+    /// This empties the group state's proposal cache.
+    #[cfg(feature = "by_ref_proposal")]
+    pub fn clear_proposal_cache(&mut self) {
+        self.state.proposals.clear()
+    }
+
+    /// Frame signed content as an `MlsMessage` for the group's protocol version.
+    ///
+    /// Content whose wire format is `PrivateMessage` is encrypted (only possible with
+    /// the `private_message` feature); everything else is sent as a `PublicMessage`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn format_for_wire(
+        &mut self,
+        content: AuthenticatedContent,
+    ) -> Result<MlsMessage, MlsError> {
+        #[cfg(feature = "private_message")]
+        let payload = {
+            let encrypt = content.wire_format == WireFormat::PrivateMessage;
+
+            if encrypt {
+                let ciphertext = self.create_ciphertext(content).await?;
+                MlsMessagePayload::Cipher(ciphertext)
+            } else {
+                let plaintext = self.create_plaintext(content).await?;
+                MlsMessagePayload::Plain(plaintext)
+            }
+        };
+
+        #[cfg(not(feature = "private_message"))]
+        let payload = {
+            let plaintext = self.create_plaintext(content).await?;
+            MlsMessagePayload::Plain(plaintext)
+        };
+
+        let version = self.protocol_version();
+
+        Ok(MlsMessage::new(version, payload))
+    }
+
+    /// Frame `auth_content` as a `PublicMessage`.
+    ///
+    /// Member senders get a membership tag computed by the current key schedule;
+    /// any other sender type is sent without one.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn create_plaintext(
+        &self,
+        auth_content: AuthenticatedContent,
+    ) -> Result<PublicMessage, MlsError> {
+        let membership_tag = match &auth_content.content.sender {
+            Sender::Member(_) => Some(
+                self.key_schedule
+                    .get_membership_tag(&auth_content, self.context(), &self.cipher_suite_provider)
+                    .await?,
+            ),
+            _ => None,
+        };
+
+        let message = PublicMessage {
+            content: auth_content.content,
+            auth: auth_content.auth,
+            membership_tag,
+        };
+
+        Ok(message)
+    }
+
+    /// Encrypt `auth_content` into a `PrivateMessage`, using the padding mode from the
+    /// group's encryption options.
+    #[cfg(feature = "private_message")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn create_ciphertext(
+        &mut self,
+        auth_content: AuthenticatedContent,
+    ) -> Result<PrivateMessage, MlsError> {
+        let padding = self.encryption_options()?.padding_mode;
+        let provider = self.cipher_suite_provider.clone();
+
+        CiphertextProcessor::new(self, provider)
+            .seal(auth_content, padding)
+            .await
+    }
+
+    /// Encrypt an application message using the current group state.
+    ///
+    /// `authenticated_data` will be sent unencrypted along with the contents
+    /// of the application message.
+    ///
+    /// # Errors
+    ///
+    /// With the `by_ref_proposal` feature, returns `CommitRequired` while any
+    /// proposals are cached, since application data may not be sent before they are
+    /// committed.
+    #[cfg(feature = "private_message")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn encrypt_application_message(
+        &mut self,
+        message: &[u8],
+        authenticated_data: Vec<u8>,
+    ) -> Result<MlsMessage, MlsError> {
+        // A group member that has observed one or more proposals within an epoch MUST send a Commit message
+        // before sending application data
+        #[cfg(feature = "by_ref_proposal")]
+        if !self.state.proposals.is_empty() {
+            return Err(MlsError::CommitRequired);
+        }
+
+        let sender = Sender::Member(*self.private_tree.self_index);
+        let content = Content::Application(message.to_vec().into());
+
+        let signed = AuthenticatedContent::new_signed(
+            &self.cipher_suite_provider,
+            self.context(),
+            sender,
+            content,
+            &self.signer,
+            WireFormat::PrivateMessage,
+            authenticated_data,
+        )
+        .await?;
+
+        self.format_for_wire(signed).await
+    }
+
+    /// Decrypt and authenticate an incoming `PrivateMessage`.
+    ///
+    /// Messages tagged with the current epoch are opened with this group's own state
+    /// and their signatures checked against the current ratchet tree. Messages from any
+    /// other epoch are opened with that epoch's stored state when the `prior_epoch`
+    /// feature is enabled. Without that feature they are rejected with `EpochNotFound`.
+    ///
+    /// # Errors
+    ///
+    /// Returns `EpochNotFound` when no state exists for the message's epoch, and
+    /// propagates decryption and signature-verification errors.
+    #[cfg(feature = "private_message")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn decrypt_incoming_ciphertext(
+        &mut self,
+        message: &PrivateMessage,
+    ) -> Result<AuthenticatedContent, MlsError> {
+        let epoch_id = message.epoch;
+
+        let auth_content = if epoch_id == self.context().epoch {
+            // Current epoch: decrypt with this group's own state, then verify the
+            // signature against the current ratchet tree.
+            let content = CiphertextProcessor::new(self, self.cipher_suite_provider.clone())
+                .open(message)
+                .await?;
+
+            verify_auth_content_signature(
+                &self.cipher_suite_provider,
+                SignaturePublicKeysContainer::RatchetTree(&self.state.public_tree),
+                self.context(),
+                &content,
+                #[cfg(feature = "by_ref_proposal")]
+                &[],
+            )
+            .await?;
+
+            Ok::<_, MlsError>(content)
+        } else {
+            #[cfg(feature = "prior_epoch")]
+            {
+                // Other epoch: load that epoch's stored state mutably, since
+                // `CiphertextProcessor` takes its group state by `&mut`.
+                let epoch = self
+                    .state_repo
+                    .get_epoch_mut(epoch_id)
+                    .await?
+                    .ok_or(MlsError::EpochNotFound)?;
+
+                let content = CiphertextProcessor::new(epoch, self.cipher_suite_provider.clone())
+                    .open(message)
+                    .await?;
+
+                // Verify against the signature keys and context stored for that epoch.
+                verify_auth_content_signature(
+                    &self.cipher_suite_provider,
+                    SignaturePublicKeysContainer::List(&epoch.signature_public_keys),
+                    &epoch.context,
+                    &content,
+                    #[cfg(feature = "by_ref_proposal")]
+                    &[],
+                )
+                .await?;
+
+                Ok(content)
+            }
+
+            #[cfg(not(feature = "prior_epoch"))]
+            Err(MlsError::EpochNotFound)
+        }?;
+
+        Ok(auth_content)
+    }
+
+    /// Apply a pending commit that was created by [`Group::commit`] or
+    /// [`CommitBuilder::build`].
+    ///
+    /// # Errors
+    ///
+    /// Returns `PendingCommitNotFound` when there is no pending commit.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn apply_pending_commit(&mut self) -> Result<CommitMessageDescription, MlsError> {
+        let Some(pending) = self.pending_commit.clone() else {
+            return Err(MlsError::PendingCommitNotFound);
+        };
+
+        self.process_commit(pending.content, None).await
+    }
+
+    /// Returns true if a commit has been created but not yet applied
+    /// with [`Group::apply_pending_commit`] or cleared with [`Group::clear_pending_commit`].
+    pub fn has_pending_commit(&self) -> bool {
+        self.pending_commit.is_some()
+    }
+
+    /// Clear the currently pending commit.
+    ///
+    /// This function will automatically be called in the event that a
+    /// commit message is processed using [`Group::process_incoming_message`]
+    /// before [`Group::apply_pending_commit`] is called.
+    pub fn clear_pending_commit(&mut self) {
+        // Dropping the stored commit discards it; nothing else is reset here.
+        self.pending_commit = None
+    }
+
+    /// Process an inbound message for this group.
+    ///
+    /// If `message` hashes to the same value as our own pending commit, that pending
+    /// commit is applied and its description returned. Otherwise the message goes
+    /// through the regular message processor.
+    ///
+    /// # Warning
+    ///
+    /// Changes to the group's state as a result of processing `message` will
+    /// not be persisted by the
+    /// [`GroupStateStorage`](crate::GroupStateStorage)
+    /// in use by this group until [`Group::write_to_storage`] is called.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    #[inline(never)]
+    pub async fn process_incoming_message(
+        &mut self,
+        message: MlsMessage,
+    ) -> Result<ReceivedMessage, MlsError> {
+        let is_own_pending_commit = match &self.pending_commit {
+            Some(pending) => {
+                CommitHash::compute(&self.cipher_suite_provider, &message).await?
+                    == pending.commit_message_hash
+            }
+            None => false,
+        };
+
+        if is_own_pending_commit {
+            let description = self.apply_pending_commit().await?;
+            return Ok(ReceivedMessage::Commit(description));
+        }
+
+        MessageProcessor::process_incoming_message(
+            self,
+            message,
+            #[cfg(feature = "by_ref_proposal")]
+            true,
+        )
+        .await
+    }
+
+    /// Process an inbound message for this group, providing additional context
+    /// with a message timestamp.
+    ///
+    /// Providing a timestamp is useful when the
+    /// [`IdentityProvider`](crate::IdentityProvider)
+    /// in use by the group can determine validity based on a timestamp.
+    /// For example, this allows for checking X.509 certificate expiration
+    /// at the time when `message` was received by a server rather than when
+    /// a specific client asynchronously received `message`.
+    ///
+    /// As with [`Group::process_incoming_message`], if `message` is our own pending
+    /// commit (matched by commit hash), the pending commit is applied and its
+    /// description returned.
+    ///
+    /// # Warning
+    ///
+    /// Changes to the group's state as a result of processing `message` will
+    /// not be persisted by the
+    /// [`GroupStateStorage`](crate::GroupStateStorage)
+    /// in use by this group until [`Group::write_to_storage`] is called.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn process_incoming_message_with_time(
+        &mut self,
+        message: MlsMessage,
+        time: MlsTime,
+    ) -> Result<ReceivedMessage, MlsError> {
+        // Without this check our own commit, echoed back by the delivery service, would
+        // be sent to the generic message processor instead of being applied, unlike in
+        // `process_incoming_message`.
+        if let Some(pending) = &self.pending_commit {
+            let message_hash = CommitHash::compute(&self.cipher_suite_provider, &message).await?;
+
+            if message_hash == pending.commit_message_hash {
+                let message_description = self.apply_pending_commit().await?;
+
+                return Ok(ReceivedMessage::Commit(message_description));
+            }
+        }
+
+        MessageProcessor::process_incoming_message_with_time(
+            self,
+            message,
+            #[cfg(feature = "by_ref_proposal")]
+            true,
+            Some(time),
+        )
+        .await
+    }
+
+    /// Find a group member by
+    /// [identity](crate::IdentityProvider::identity).
+    ///
+    /// This function determines identity by calling the
+    /// [`IdentityProvider`](crate::IdentityProvider)
+    /// currently in use by the group (or, with the `tree_index` feature, the tree's
+    /// identity index).
+    ///
+    /// # Errors
+    ///
+    /// Returns `MemberNotFound` when no leaf matches `identity`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn member_with_identity(&self, identity: &[u8]) -> Result<Member, MlsError> {
+        let tree = &self.state.public_tree;
+
+        #[cfg(feature = "tree_index")]
+        let found = tree.get_leaf_node_with_identity(identity);
+
+        #[cfg(not(feature = "tree_index"))]
+        let found = tree
+            .get_leaf_node_with_identity(
+                identity,
+                &self.identity_provider(),
+                &self.state.context.extensions,
+            )
+            .await?;
+
+        let Some(leaf_index) = found else {
+            return Err(MlsError::MemberNotFound);
+        };
+
+        let leaf_node = tree.get_leaf_node(leaf_index)?;
+
+        Ok(member_from_leaf_node(leaf_node, leaf_index))
+    }
+
+    /// Create a group info message that can be used for external proposals and commits.
+    ///
+    /// The returned `GroupInfo` carries the external key pair extension and is suitable
+    /// for one external commit for the current epoch. If `with_tree_in_extension` is
+    /// true, the returned `GroupInfo` also contains the ratchet tree and therefore all
+    /// information needed to join the group. Otherwise, the ratchet tree must be obtained
+    /// separately, e.g. via
+    /// [`ExternalGroup::export_tree`](crate::external_client::ExternalGroup::export_tree).
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn group_info_message_allowing_ext_commit(
+        &self,
+        with_tree_in_extension: bool,
+    ) -> Result<MlsMessage, MlsError> {
+        let mut extensions = ExtensionList::new();
+
+        let external_key_pair = self
+            .key_schedule
+            .get_external_key_pair_ext(&self.cipher_suite_provider)
+            .await?;
+
+        extensions.set_from(external_key_pair)?;
+
+        self.group_info_message_internal(extensions, with_tree_in_extension)
+            .await
+    }
+
+    /// Create a group info message that can be used for external proposals.
+    ///
+    /// No extensions are added apart from the optional ratchet tree extension
+    /// controlled by `with_tree_in_extension`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn group_info_message(
+        &self,
+        with_tree_in_extension: bool,
+    ) -> Result<MlsMessage, MlsError> {
+        let no_extensions = ExtensionList::new();
+
+        self.group_info_message_internal(no_extensions, with_tree_in_extension)
+            .await
+    }
+
+ /// Shared implementation behind the public group-info builders.
+ ///
+ /// Starts from `initial_extensions`, optionally adds a [`RatchetTreeExt`] holding a copy of
+ /// the current public tree nodes, greases the `GroupInfo`, signs it with this member's
+ /// signer and wraps it in an [`MlsMessage`] for the group's protocol version.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn group_info_message_internal(
+ &self,
+ mut initial_extensions: ExtensionList,
+ with_tree_in_extension: bool,
+ ) -> Result<MlsMessage, MlsError> {
+ if with_tree_in_extension {
+ let tree_data = ExportedTree::new(self.state.public_tree.nodes.clone());
+ initial_extensions.set_from(RatchetTreeExt { tree_data })?;
+ }
+
+ let mut group_info = GroupInfo {
+ group_context: self.context().clone(),
+ extensions: initial_extensions,
+ confirmation_tag: self.state.confirmation_tag.clone(),
+ signer: self.private_tree.self_index,
+ signature: Vec::new(),
+ };
+
+ // Grease is applied before signing, so the signature is computed over the greased content.
+ group_info.grease(self.cipher_suite_provider())?;
+ group_info
+ .sign(&self.cipher_suite_provider, &self.signer, &())
+ .await?;
+
+ let payload = MlsMessagePayload::GroupInfo(group_info);
+
+ Ok(MlsMessage::new(self.protocol_version(), payload))
+ }
+
+ /// The current [`GroupContext`]: the per-epoch summary of the group (group id, epoch,
+ /// protocol version, cipher suite, tree hash, transcript hash and extensions).
+ #[inline(always)]
+ pub fn context(&self) -> &GroupContext {
+ let state = self.group_state();
+ &state.context
+ }
+
+ /// The `epoch_authenticator` secret of the current epoch, as derived by the MLS key
+ /// schedule ([RFC 9420, Section 8.7](https://www.rfc-editor.org/rfc/rfc9420.html#section-8.7)).
+ ///
+ /// Per the RFC, members can compare this value out of band to check that they share the
+ /// same view of the group. The `Result` is part of the API; this implementation always
+ /// returns `Ok`.
+ pub fn epoch_authenticator(&self) -> Result<Secret, MlsError> {
+ let authenticator: Secret = self.key_schedule.authentication_secret.clone().into();
+ Ok(authenticator)
+ }
+
+ /// Derive an exported secret of `len` bytes from the current epoch's key schedule,
+ /// bound to the caller-supplied `label` and `context` (the MLS exporter,
+ /// [RFC 9420, Section 8.5](https://www.rfc-editor.org/rfc/rfc9420.html#section-8.5)).
+ ///
+ /// # Errors
+ /// Propagates any failure from the key schedule derivation.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn export_secret(
+ &self,
+ label: &[u8],
+ context: &[u8],
+ len: usize,
+ ) -> Result<Secret, MlsError> {
+ let exported = self
+ .key_schedule
+ .export_secret(label, context, len, &self.cipher_suite_provider)
+ .await?;
+
+ Ok(exported.into())
+ }
+
+ /// Export the current epoch's ratchet tree.
+ ///
+ /// New members need this tree when it is not delivered inside the `ratchet_tree`
+ /// extension (see [`MlsRules::commit_options`]). The returned value borrows the tree
+ /// nodes rather than copying them.
+ pub fn export_tree(&self) -> ExportedTree<'_> {
+ let nodes = &self.current_epoch_tree().nodes;
+ ExportedTree::new_borrowed(nodes)
+ }
+
+ /// The MLS protocol version this group runs, as recorded in its group context.
+ pub fn protocol_version(&self) -> ProtocolVersion {
+ let group_context = self.context();
+ group_context.protocol_version
+ }
+
+ /// The cipher suite this group runs, as recorded in its group context.
+ pub fn cipher_suite(&self) -> CipherSuite {
+ let group_context = self.context();
+ group_context.cipher_suite
+ }
+
+ /// Borrowed view of the group's current members, taken from the public ratchet tree.
+ pub fn roster(&self) -> Roster<'_> {
+ let public_tree = &self.group_state().public_tree;
+ public_tree.roster()
+ }
+
+ /// Returns `true` when two groups hold the same group state, key schedule and epoch
+ /// secrets.
+ ///
+ /// Intended for tests that check two members converged on the same epoch.
+ pub fn equal_group_state(a: &Group<C>, b: &Group<C>) -> bool {
+ let same_state = a.state == b.state;
+ let same_key_schedule = a.key_schedule == b.key_schedule;
+ let same_secrets = a.epoch_secrets == b.epoch_secrets;
+
+ same_state && same_key_schedule && same_secrets
+ }
+
+ /// Resolve the pre-shared-key secret to feed into the next epoch's key schedule.
+ ///
+ /// If `self.previous_psk` is set, it is used on its own and `psks` is ignored (see the
+ /// TODO below). Otherwise the PSK id of every proposal in `psks` is resolved through a
+ /// [`PskResolver`] that can consult the current group context and epoch secrets, the
+ /// stored prior epochs (`state_repo`) and the configured external PSK store.
+ ///
+ /// Returns the combined secret together with the PSK ids that went into it.
+ #[cfg(feature = "psk")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn get_psk(
+ &self,
+ psks: &[ProposalInfo<PreSharedKeyProposal>],
+ ) -> Result<(PskSecret, Vec<PreSharedKeyID>), MlsError> {
+ if let Some(psk) = self.previous_psk.clone() {
+ // TODO consider throwing error if psks not empty
+ let psk_id = vec![psk.id.clone()];
+ let psk = PskSecret::calculate(&[psk], self.cipher_suite_provider()).await?;
+
+ Ok((psk, psk_id))
+ } else {
+ // Only the PSK ids are needed from the proposals.
+ let psks = psks
+ .iter()
+ .map(|psk| psk.proposal.psk.clone())
+ .collect::<Vec<_>>();
+
+ let psk = PskResolver {
+ group_context: Some(self.context()),
+ current_epoch: Some(&self.epoch_secrets),
+ prior_epochs: Some(&self.state_repo),
+ psk_store: &self.config.secret_store(),
+ }
+ .resolve_to_secret(&psks, self.cipher_suite_provider())
+ .await?;
+
+ Ok((psk, psks))
+ }
+ }
+
+ /// Ask the configured MLS rules how outgoing private messages should be encrypted,
+ /// given the current roster and group-context extensions.
+ ///
+ /// Errors raised by the rules are wrapped in [`MlsError::MlsRulesError`].
+ #[cfg(feature = "private_message")]
+ pub(crate) fn encryption_options(&self) -> Result<EncryptionOptions, MlsError> {
+ let rules = self.config.mls_rules();
+ let roster = self.roster();
+
+ rules
+ .encryption_options(&roster, self.group_context().extensions())
+ .map_err(|err| MlsError::MlsRulesError(err.into_any_error()))
+ }
+
+ /// Without the `psk` feature no pre-shared keys can be injected; the key schedule
+ /// always receives the value produced by `PskSecret::new` for this cipher suite.
+ #[cfg(not(feature = "psk"))]
+ fn get_psk(&self) -> PskSecret {
+ let provider = self.cipher_suite_provider();
+ PskSecret::new(provider)
+ }
+
+ /// Obtain the next application-message key for this member's own leaf from the current
+ /// epoch's secret tree.
+ ///
+ /// Takes `&mut self` because the secret tree is updated; presumably each call advances
+ /// the sender ratchet and yields a fresh key — confirm against `next_message_key`.
+ #[cfg(feature = "secret_tree_access")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[inline(never)]
+ pub async fn next_encryption_key(&mut self) -> Result<MessageKey, MlsError> {
+ let own_node = crate::tree_kem::node::NodeIndex::from(self.private_tree.self_index);
+
+ self.epoch_secrets
+ .secret_tree
+ .next_message_key(&self.cipher_suite_provider, own_node, KeyType::Application)
+ .await
+ }
+
+ /// Derive the application-message key for `sender` at ratchet `generation` from the
+ /// current epoch's secret tree.
+ ///
+ /// NOTE(review): `sender` is converted with `NodeIndex::from(u32)`, whereas
+ /// `next_encryption_key` converts from a `LeafIndex`. If `NodeIndex` is a plain `u32`
+ /// alias this passes `sender` through unchanged (treating it as a node index rather than
+ /// a leaf index) — confirm which one callers are expected to pass.
+ #[cfg(feature = "secret_tree_access")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn derive_decryption_key(
+ &mut self,
+ sender: u32,
+ generation: u32,
+ ) -> Result<MessageKey, MlsError> {
+ let sender_node = crate::tree_kem::node::NodeIndex::from(sender);
+ let secret_tree = &mut self.epoch_secrets.secret_tree;
+
+ secret_tree
+ .message_key_generation(
+ &self.cipher_suite_provider,
+ sender_node,
+ KeyType::Application,
+ generation,
+ )
+ .await
+ }
+}
+
+// Gives the ciphertext processor (`ciphertext_processor.rs`) access to the parts of a
+// `Group` it needs to encrypt and decrypt private messages.
+#[cfg(feature = "private_message")]
+impl<C> GroupStateProvider for Group<C>
+where
+ C: ClientConfig + Clone,
+{
+ fn group_context(&self) -> &GroupContext {
+ &self.group_state().context
+ }
+
+ fn self_index(&self) -> LeafIndex {
+ let own_leaf: LeafIndex = self.private_tree.self_index;
+ own_leaf
+ }
+
+ fn epoch_secrets(&self) -> &EpochSecrets {
+ &self.epoch_secrets
+ }
+
+ fn epoch_secrets_mut(&mut self) -> &mut EpochSecrets {
+ &mut self.epoch_secrets
+ }
+}
+
+// Message-processing hooks for a member's own `Group`. The associated types bind the
+// generic processor to this client's configuration; the methods supply group-specific
+// work (decryption, update-path application, key-schedule advancement).
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
+#[cfg_attr(
+ all(not(target_arch = "wasm32"), mls_build_async),
+ maybe_async::must_be_async
+)]
+impl<C> MessageProcessor for Group<C>
+where
+ C: ClientConfig + Clone,
+{
+ type MlsRules = C::MlsRules;
+ type IdentityProvider = C::IdentityProvider;
+ type PreSharedKeyStorage = C::PskStore;
+ type OutputType = ReceivedMessage;
+ type CipherSuiteProvider = <C::CryptoProvider as CryptoProvider>::CipherSuiteProvider;
+
+ // A `Group` always has its own leaf, so this never returns `None`.
+ #[cfg(feature = "private_message")]
+ fn self_index(&self) -> Option<LeafIndex> {
+ Some(self.private_tree.self_index)
+ }
+
+ // Decrypt a private message with this member's epoch secrets; a successfully
+ // decrypted message is always surfaced as content.
+ #[cfg(feature = "private_message")]
+ async fn process_ciphertext(
+ &mut self,
+ cipher_text: &PrivateMessage,
+ ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+ self.decrypt_incoming_ciphertext(cipher_text)
+ .await
+ .map(EventOrContent::Content)
+ }
+
+ // Authenticate a public message against the current key schedule, this member's
+ // leaf index and the current group state.
+ async fn verify_plaintext_authentication(
+ &self,
+ message: PublicMessage,
+ ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+ let auth_content = verify_plaintext_authentication(
+ &self.cipher_suite_provider,
+ message,
+ Some(&self.key_schedule),
+ Some(self.private_tree.self_index),
+ &self.state,
+ )
+ .await?;
+
+ Ok(EventOrContent::Content(auth_content))
+ }
+
+ // Apply a commit's update path to the provisional public tree and return the private
+ // tree and commit secret to use for the next epoch.
+ async fn apply_update_path(
+ &mut self,
+ sender: LeafIndex,
+ update_path: &ValidatedUpdatePath,
+ provisional_state: &mut ProvisionalState,
+ ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
+ // Update the private tree to create a provisional private tree
+ let (mut provisional_private_tree, new_signer) =
+ self.provisional_private_tree(provisional_state)?;
+
+ // NOTE(review): the signer is replaced before the update path is applied and before
+ // `update_key_schedule` verifies the confirmation tag, so a commit rejected later
+ // still leaves `self.signer` changed. Deferring this needs a trait signature change.
+ if let Some(signer) = new_signer {
+ self.signer = signer;
+ }
+
+ provisional_state
+ .public_tree
+ .apply_update_path(
+ sender,
+ update_path,
+ &provisional_state.group_context.extensions,
+ self.identity_provider(),
+ self.cipher_suite_provider(),
+ )
+ .await?;
+
+ if let Some(pending) = &self.pending_commit {
+ // With a commit of ours pending, reuse the private tree and commit secret computed
+ // when it was created instead of decapsulating the path (presumably the incoming
+ // commit is that pending commit — confirm in the caller).
+ Ok(Some((
+ pending.pending_private_tree.clone(),
+ pending.pending_commit_secret.clone(),
+ )))
+ } else {
+ // Update the tree hash to get context for decryption
+ provisional_state.group_context.tree_hash = provisional_state
+ .public_tree
+ .tree_hash(&self.cipher_suite_provider)
+ .await?;
+
+ let context_bytes = provisional_state.group_context.mls_encode_to_vec()?;
+
+ TreeKem::new(
+ &mut provisional_state.public_tree,
+ &mut provisional_private_tree,
+ )
+ .decap(
+ sender,
+ update_path,
+ &provisional_state.indexes_of_added_kpkgs,
+ &context_bytes,
+ &self.cipher_suite_provider,
+ )
+ .await
+ .map(|root_secret| Some((provisional_private_tree, root_secret)))
+ }
+ }
+
+ // Advance to the next epoch: derive the new key schedule, verify the commit's
+ // confirmation tag, archive the old epoch (with `prior_epoch`) and install the new state.
+ // All `self` mutations happen only after the confirmation tag has been verified.
+ async fn update_key_schedule(
+ &mut self,
+ secrets: Option<(TreeKemPrivate, PathSecret)>,
+ interim_transcript_hash: InterimTranscriptHash,
+ confirmation_tag: &ConfirmationTag,
+ provisional_state: ProvisionalState,
+ ) -> Result<(), MlsError> {
+ // Hold the new private tree aside until the confirmation tag is verified, so that a
+ // rejected commit leaves `self.private_tree` consistent with the unchanged public tree.
+ let (new_private_tree, commit_secret) = match secrets {
+ Some((private_tree, commit_secret)) => (Some(private_tree), commit_secret),
+ None => (None, PathSecret::empty(&self.cipher_suite_provider)),
+ };
+
+ // Use the commit_secret, the psk_secret, the provisional GroupContext, and the init secret
+ // from the previous epoch (or from the external init) to compute the epoch secret and
+ // derived secrets for the new epoch
+
+ let key_schedule = match provisional_state
+ .applied_proposals
+ .external_initializations
+ .first()
+ .cloned()
+ {
+ Some(ext_init) if self.pending_commit.is_none() => {
+ self.key_schedule
+ .derive_for_external(&ext_init.proposal.kem_output, &self.cipher_suite_provider)
+ .await?
+ }
+ _ => self.key_schedule.clone(),
+ };
+
+ #[cfg(feature = "psk")]
+ let (psk, _) = self
+ .get_psk(&provisional_state.applied_proposals.psks)
+ .await?;
+
+ #[cfg(not(feature = "psk"))]
+ let psk = self.get_psk();
+
+ let key_schedule_result = KeySchedule::from_key_schedule(
+ &key_schedule,
+ &commit_secret,
+ &provisional_state.group_context,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ provisional_state.public_tree.total_leaf_count(),
+ &psk,
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ // Use the confirmation_key for the new epoch to compute the confirmation tag for
+ // this message, as described below, and verify that it is the same as the
+ // confirmation_tag field in the MlsPlaintext object.
+ let new_confirmation_tag = ConfirmationTag::create(
+ &key_schedule_result.confirmation_key,
+ &provisional_state.group_context.confirmed_transcript_hash,
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ if &new_confirmation_tag != confirmation_tag {
+ return Err(MlsError::InvalidConfirmationTag);
+ }
+
+ // The commit is confirmed; only now adopt the new private tree.
+ if let Some(private_tree) = new_private_tree {
+ self.private_tree = private_tree;
+ }
+
+ // Snapshot the outgoing epoch so later lookups (e.g. the `prior_epochs` resolver used
+ // by `get_psk`) can still reach it.
+ #[cfg(feature = "prior_epoch")]
+ let signature_public_keys = self
+ .state
+ .public_tree
+ .leaves()
+ .map(|l| l.map(|n| n.signing_identity.signature_key.clone()))
+ .collect();
+
+ #[cfg(feature = "prior_epoch")]
+ let past_epoch = PriorEpoch {
+ context: self.context().clone(),
+ self_index: self.private_tree.self_index,
+ secrets: self.epoch_secrets.clone(),
+ signature_public_keys,
+ };
+
+ #[cfg(feature = "prior_epoch")]
+ self.state_repo.insert(past_epoch).await?;
+
+ self.epoch_secrets = key_schedule_result.epoch_secrets;
+ self.state.context = provisional_state.group_context;
+ self.state.interim_transcript_hash = interim_transcript_hash;
+ self.key_schedule = key_schedule_result.key_schedule;
+ self.state.public_tree = provisional_state.public_tree;
+ self.state.confirmation_tag = new_confirmation_tag;
+
+ // Clear the proposals list
+ #[cfg(feature = "by_ref_proposal")]
+ self.state.proposals.clear();
+
+ // Clear the pending updates list
+ #[cfg(feature = "by_ref_proposal")]
+ {
+ self.pending_updates = Default::default();
+ }
+
+ self.pending_commit = None;
+
+ Ok(())
+ }
+
+ fn mls_rules(&self) -> Self::MlsRules {
+ self.config.mls_rules()
+ }
+
+ fn identity_provider(&self) -> Self::IdentityProvider {
+ self.config.identity_provider()
+ }
+
+ fn psk_storage(&self) -> Self::PreSharedKeyStorage {
+ self.config.secret_store()
+ }
+
+ fn group_state(&self) -> &GroupState {
+ &self.state
+ }
+
+ fn group_state_mut(&mut self) -> &mut GroupState {
+ &mut self.state
+ }
+
+ // Processing stops only when the applied proposals remove this member and no commit
+ // of ours is pending.
+ fn can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool {
+ !(provisional_state
+ .applied_proposals
+ .removals
+ .iter()
+ .any(|p| p.proposal.to_remove == self.private_tree.self_index)
+ && self.pending_commit.is_none())
+ }
+
+ // This implementation reports no minimum available epoch.
+ #[cfg(feature = "private_message")]
+ fn min_epoch_available(&self) -> Option<u64> {
+ None
+ }
+
+ fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider {
+ &self.cipher_suite_provider
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils;
+
+#[cfg(test)]
+mod tests {
+ use crate::{
+ client::test_utils::{
+ test_client_with_key_pkg, TestClientBuilder, TEST_CIPHER_SUITE,
+ TEST_CUSTOM_PROPOSAL_TYPE, TEST_PROTOCOL_VERSION,
+ },
+ client_builder::{test_utils::TestClientConfig, ClientBuilder, MlsConfig},
+ crypto::test_utils::TestCryptoProvider,
+ group::{
+ mls_rules::{CommitDirection, CommitSource},
+ proposal_filter::ProposalBundle,
+ },
+ identity::{
+ basic::BasicIdentityProvider,
+ test_utils::{get_test_signing_identity, BasicWithCustomProvider},
+ },
+ key_package::test_utils::test_key_package_message,
+ mls_rules::CommitOptions,
+ tree_kem::{
+ leaf_node::{test_utils::get_test_capabilities, LeafNodeSource},
+ UpdatePathNode,
+ },
+ };
+
+ #[cfg(any(feature = "private_message", feature = "custom_proposal"))]
+ use crate::group::mls_rules::DefaultMlsRules;
+
+ #[cfg(feature = "prior_epoch")]
+ use crate::group::padding::PaddingMode;
+
+ use crate::{extension::RequiredCapabilitiesExt, key_package::test_utils::test_key_package};
+
+ #[cfg(all(feature = "by_ref_proposal", feature = "custom_proposal"))]
+ use super::test_utils::test_group_custom_config;
+
+ #[cfg(feature = "psk")]
+ use crate::{client::Client, psk::PreSharedKey};
+
+ #[cfg(any(feature = "by_ref_proposal", feature = "private_message"))]
+ use crate::group::test_utils::random_bytes;
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::{
+ extension::test_utils::TestExtension, identity::test_utils::get_test_basic_credential,
+ time::MlsTime,
+ };
+
+ use super::{
+ test_utils::{
+ get_test_25519_key, get_test_groups_with_features, group_extensions, process_commit,
+ test_group, test_group_custom, test_n_member_group, TestGroup, TEST_GROUP,
+ },
+ *,
+ };
+
+ use assert_matches::assert_matches;
+
+ use mls_rs_core::extension::{Extension, ExtensionType};
+ use mls_rs_core::identity::{Credential, CredentialType, CustomCredential};
+
+ #[cfg(feature = "by_ref_proposal")]
+ use mls_rs_core::identity::CertificateChain;
+
+ #[cfg(feature = "state_update")]
+ use itertools::Itertools;
+
+ #[cfg(feature = "state_update")]
+ use alloc::format;
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::{crypto::test_utils::test_cipher_suite_provider, extension::ExternalSendersExt};
+
+ #[cfg(any(feature = "private_message", feature = "state_update"))]
+ use super::test_utils::test_member;
+
+ use mls_rs_core::extension::MlsExtension;
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_create_group() {
+ // Every supported (protocol version, cipher suite) combination.
+ let combinations = ProtocolVersion::all().flat_map(|p| {
+ TestCryptoProvider::all_supported_cipher_suites()
+ .into_iter()
+ .map(move |cs| (p, cs))
+ });
+
+ for (protocol_version, cipher_suite) in combinations {
+ let group = test_group(protocol_version, cipher_suite).await.group;
+
+ // A freshly created group sits at epoch 0 with the test id and extensions.
+ assert_eq!(group.cipher_suite(), cipher_suite);
+ assert_eq!(group.state.context.epoch, 0);
+ assert_eq!(group.state.context.group_id, TEST_GROUP.to_vec());
+ assert_eq!(group.state.context.extensions, group_extensions());
+
+ assert_eq!(
+ group.state.context.confirmed_transcript_hash,
+ ConfirmedTranscriptHash::from(vec![])
+ );
+
+ // `state.proposals` is only touched under `by_ref_proposal` elsewhere in this module,
+ // so gate the check on that same feature rather than `private_message`.
+ #[cfg(feature = "by_ref_proposal")]
+ assert!(group.state.proposals.is_empty());
+
+ #[cfg(feature = "by_ref_proposal")]
+ assert!(group.pending_updates.is_empty());
+
+ assert!(!group.has_pending_commit());
+
+ assert_eq!(
+ group.private_tree.self_index.0,
+ group.current_member_index()
+ );
+ }
+ }
+
+ #[cfg(feature = "private_message")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_pending_proposals_application_data() {
+ let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ // Queue a proposal to add bob.
+ let (bob_key_package, _) =
+ test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+
+ let add_bob = alice
+ .group
+ .add_proposal(bob_key_package.key_package_message())
+ .unwrap();
+
+ alice.group.proposal_message(add_bob, vec![]).await.unwrap();
+
+ // While the proposal is outstanding, application messages are refused.
+ let blocked = alice
+ .group
+ .encrypt_application_message(b"test", vec![])
+ .await;
+
+ assert_matches!(blocked, Err(MlsError::CommitRequired));
+
+ // Committing and applying the commit unblocks application messages.
+ alice.group.commit(vec![]).await.unwrap();
+ assert!(alice.group.has_pending_commit());
+ alice.group.apply_pending_commit().await.unwrap();
+
+ let unblocked = alice
+ .group
+ .encrypt_application_message(b"test", vec![])
+ .await;
+
+ assert!(unblocked.is_ok());
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_update_proposals() {
+ let mut custom_extensions = ExtensionList::default();
+ custom_extensions.set_from(TestExtension { foo: 10 }).unwrap();
+
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ vec![42.into()],
+ Some(custom_extensions.clone()),
+ None,
+ )
+ .await;
+
+ let leaf_before = group.group.current_user_leaf_node().unwrap().clone();
+
+ let update = if let Proposal::Update(update) = group.update_proposal().await {
+ update
+ } else {
+ panic!("non update proposal found")
+ };
+
+ // The update must change the leaf's public key ...
+ assert_ne!(update.leaf_node.public_key, leaf_before.public_key);
+
+ // ... while keeping the signing identity, leaf extensions and capabilities.
+ assert_eq!(
+ update.leaf_node.signing_identity,
+ leaf_before.signing_identity
+ );
+
+ assert_eq!(update.leaf_node.ungreased_extensions(), custom_extensions);
+
+ let expected_capabilities = Capabilities {
+ extensions: vec![42.into()],
+ ..get_test_capabilities()
+ };
+
+ assert_eq!(
+ update.leaf_node.ungreased_capabilities().sorted(),
+ expected_capabilities.sorted()
+ );
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_invalid_commit_self_update() {
+ let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ // Propose a self-update and pull the proposed leaf node back out of the message.
+ let update_msg = group.group.propose_update(vec![]).await.unwrap();
+ let framed = update_msg.into_plaintext().unwrap().content.content;
+
+ let proposed_leaf = match framed {
+ Content::Proposal(p) => match *p {
+ Proposal::Update(u) => u.leaf_node,
+ _ => panic!("found proposal message that isn't an update"),
+ },
+ _ => panic!("found non-proposal message"),
+ };
+
+ group.group.commit(vec![]).await.unwrap();
+ group.group.apply_pending_commit().await.unwrap();
+
+ // The committer rejects its own update proposal, so its leaf must differ from it.
+ let current_leaf = group.group.current_user_leaf_node().unwrap();
+ assert_ne!(&proposed_leaf, current_leaf);
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn update_proposal_with_bad_key_package_is_ignored_when_committing() {
+ // Alice builds an update proposal whose leaf-node signature is replaced with random
+ // bytes; Bob is then made to hold it and commits. The commit must succeed with the
+ // invalid proposal dropped rather than failing.
+ let (mut alice_group, mut bob_group) =
+ test_two_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, true).await;
+
+ let mut proposal = alice_group.update_proposal().await;
+
+ // Corrupt the leaf node signature so the proposal no longer verifies.
+ if let Proposal::Update(ref mut update) = proposal {
+ update.leaf_node.signature = random_bytes(32);
+ } else {
+ panic!("Invalid update proposal")
+ }
+
+ let proposal_message = alice_group
+ .group
+ .proposal_message(proposal.clone(), vec![])
+ .await
+ .unwrap();
+
+ let proposal_plaintext = match proposal_message.payload {
+ MlsMessagePayload::Plain(p) => p,
+ _ => panic!("Unexpected non-plaintext message"),
+ };
+
+ // Compute the reference Bob would derive for this proposal.
+ let proposal_ref = ProposalRef::from_content(
+ &bob_group.group.cipher_suite_provider,
+ &proposal_plaintext.clone().into(),
+ )
+ .await
+ .unwrap();
+
+ // Hack bob's receipt of the proposal
+ bob_group.group.state.proposals.insert(
+ proposal_ref,
+ proposal,
+ proposal_plaintext.content.sender,
+ );
+
+ let commit_output = bob_group.group.commit(vec![]).await.unwrap();
+
+ // The resulting commit must not reference any proposal.
+ assert_matches!(
+ commit_output.commit_message,
+ MlsMessage {
+ payload: MlsMessagePayload::Plain(
+ PublicMessage {
+ content: FramedContent {
+ content: Content::Commit(c),
+ ..
+ },
+ ..
+ }),
+ ..
+ } if c.proposals.is_empty()
+ );
+ }
+
+ // Create a group, add "bob" through a welcome, and assert both sides agree on the full
+ // epoch state. `tree_ext` controls whether the ratchet tree travels in the welcome.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_two_member_group(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ tree_ext: bool,
+ ) -> (TestGroup, TestGroup) {
+ let commit_options = CommitOptions::new().with_ratchet_tree_extension(tree_ext);
+
+ let mut alice = test_group_custom(
+ protocol_version,
+ cipher_suite,
+ Default::default(),
+ None,
+ Some(commit_options),
+ )
+ .await;
+
+ let (bob, _) = alice.join("bob").await;
+
+ let converged = Group::equal_group_state(&alice.group, &bob.group);
+ assert!(converged);
+
+ (alice, bob)
+ }
+
+ // Bob joins without the ratchet tree in the welcome; the helper asserts convergence.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_welcome_processing_exported_tree() {
+ let tree_in_welcome = false;
+ let (_alice, _bob) =
+ test_two_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, tree_in_welcome).await;
+ }
+
+ // Bob joins with the ratchet tree carried in the welcome; the helper asserts convergence.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_welcome_processing_tree_extension() {
+ let tree_in_welcome = true;
+ let (_alice, _bob) =
+ test_two_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, tree_in_welcome).await;
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_welcome_processing_missing_tree() {
+ // Alice's commits do not embed the ratchet tree in the welcome.
+ let without_tree = CommitOptions::new().with_ratchet_tree_extension(false);
+
+ let mut alice = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(without_tree),
+ )
+ .await;
+
+ let (bob_client, bob_key_package) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let add_bob = alice
+ .group
+ .commit_builder()
+ .add_member(bob_key_package)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ // Join with `None` as the second argument (presumably the out-of-band tree), so no
+ // tree is available at all.
+ let welcome = &add_bob.welcome_messages[0];
+
+ let join_result = Group::join(welcome, None, bob_client.config, bob_client.signer.unwrap())
+ .await
+ .map(|_| ());
+
+ assert_matches!(join_result, Err(MlsError::RatchetTreeNotFound));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_group_context_ext_proposal_create() {
+ let alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let required = RequiredCapabilitiesExt {
+ extensions: vec![42.into()],
+ proposals: vec![],
+ credentials: vec![],
+ };
+
+ let mut expected = ExtensionList::new();
+ expected.set_from(required).unwrap();
+
+ // The proposal must wrap exactly the supplied extension list.
+ let proposal = alice
+ .group
+ .group_context_extensions_proposal(expected.clone());
+
+ assert_matches!(proposal, Proposal::GroupContextExtensions(ext) if ext == expected);
+ }
+
+ // Build a group whose creator supports extension type 42, then try to commit a group
+ // context extensions change to `ext_list`. Returns the group and the commit attempt.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn group_context_extension_proposal_test(
+ ext_list: ExtensionList,
+ ) -> (TestGroup, Result<MlsMessage, MlsError>) {
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ vec![42.into()],
+ None,
+ None,
+ )
+ .await;
+
+ let commit_output = group
+ .group
+ .commit_builder()
+ .set_group_context_ext(ext_list)
+ .unwrap()
+ .build()
+ .await;
+
+ let commit = commit_output.map(|output| output.commit_message);
+
+ (group, commit)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_group_context_ext_proposal_commit() {
+ let supported = RequiredCapabilitiesExt {
+ extensions: vec![42.into()],
+ proposals: vec![],
+ credentials: vec![],
+ };
+
+ let mut extension_list = ExtensionList::new();
+ extension_list.set_from(supported).unwrap();
+
+ let (mut group, _) = group_context_extension_proposal_test(extension_list.clone()).await;
+
+ // Applying the commit must keep this member active and install the new extensions.
+ let applied = group.group.apply_pending_commit().await.unwrap();
+
+ #[cfg(feature = "state_update")]
+ assert!(applied.state_update.active);
+
+ #[cfg(not(feature = "state_update"))]
+ let _ = applied;
+
+ assert_eq!(group.group.state.context.extensions, extension_list)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_group_context_ext_proposal_invalid() {
+ // Require extension type 999, which the helper's creator (supporting only 42) lacks.
+ let unsupported = RequiredCapabilitiesExt {
+ extensions: vec![999.into()],
+ proposals: vec![],
+ credentials: vec![],
+ };
+
+ let mut extension_list = ExtensionList::new();
+ extension_list.set_from(unsupported).unwrap();
+
+ let (_, commit) = group_context_extension_proposal_test(extension_list).await;
+
+ assert_matches!(
+ commit,
+ Err(MlsError::RequiredExtensionNotFound(missing)) if missing == 999.into()
+ );
+ }
+
+ // Have a fresh "alice" client create a group whose only extension is `required_caps`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn make_group_with_required_capabilities(
+ required_caps: RequiredCapabilitiesExt,
+ ) -> Result<Group<TestClientConfig>, MlsError> {
+ let (alice_client, _) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
+
+ let initial_extensions = core::iter::once(required_caps.into_extension().unwrap()).collect();
+
+ alice_client.create_group(initial_extensions).await
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn creating_group_with_member_not_supporting_required_credential_type_fails() {
+ // Requiring X.509 in addition to basic credentials must be rejected for the creator.
+ let required_caps = RequiredCapabilitiesExt {
+ credentials: vec![CredentialType::BASIC, CredentialType::X509],
+ ..Default::default()
+ };
+
+ let outcome = make_group_with_required_capabilities(required_caps)
+ .await
+ .map(|_| ());
+
+ assert_matches!(
+ outcome,
+ Err(MlsError::RequiredCredentialNotFound(CredentialType::X509))
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn creating_group_with_member_not_supporting_required_extension_type_fails() {
+ const EXTENSION_TYPE: ExtensionType = ExtensionType::new(33);
+
+ // The creator does not advertise extension type 33, so requiring it must fail.
+ let required_caps = RequiredCapabilitiesExt {
+ extensions: vec![EXTENSION_TYPE],
+ ..Default::default()
+ };
+
+ let outcome = make_group_with_required_capabilities(required_caps)
+ .await
+ .map(|_| ());
+
+ assert_matches!(
+ outcome,
+ Err(MlsError::RequiredExtensionNotFound(EXTENSION_TYPE))
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn creating_group_with_member_not_supporting_required_proposal_type_fails() {
+ const PROPOSAL_TYPE: ProposalType = ProposalType::new(33);
+
+ // The creator does not advertise proposal type 33, so requiring it must fail.
+ let required_caps = RequiredCapabilitiesExt {
+ proposals: vec![PROPOSAL_TYPE],
+ ..Default::default()
+ };
+
+ let outcome = make_group_with_required_capabilities(required_caps)
+ .await
+ .map(|_| ());
+
+ assert_matches!(
+ outcome,
+ Err(MlsError::RequiredProposalNotFound(PROPOSAL_TYPE))
+ );
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn creating_group_with_member_not_supporting_external_sender_credential_fails() {
+ // An external sender with an X.509 credential forces X.509 support on members.
+ let ext_senders = make_x509_external_senders_ext()
+ .await
+ .into_extension()
+ .unwrap();
+
+ let (alice_client, _) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
+
+ let initial_extensions = core::iter::once(ext_senders).collect();
+
+ let outcome = alice_client
+ .create_group(initial_extensions)
+ .await
+ .map(|_| ());
+
+ assert_matches!(
+ outcome,
+ Err(MlsError::RequiredCredentialNotFound(CredentialType::X509))
+ );
+ }
+
+ #[cfg(all(not(target_arch = "wasm32"), feature = "private_message"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_group_encrypt_plaintext_padding() {
+ let protocol_version = TEST_PROTOCOL_VERSION;
+ // A cipher suite with fixed-length signatures, so only padding affects message size.
+ let cipher_suite = CipherSuite::CURVE25519_AES128;
+
+ // Encrypt a 150-byte payload without padding ...
+ let mut unpadded_group = test_group_custom_config(protocol_version, cipher_suite, |b| {
+ b.mls_rules(
+ DefaultMlsRules::default()
+ .with_encryption_options(EncryptionOptions::new(true, PaddingMode::None)),
+ )
+ })
+ .await;
+
+ let without_padding = unpadded_group
+ .group
+ .encrypt_application_message(&random_bytes(150), vec![])
+ .await
+ .unwrap();
+
+ // ... and the same size with step-function padding.
+ let mut padded_group = test_group_custom_config(protocol_version, cipher_suite, |b| {
+ b.mls_rules(DefaultMlsRules::default().with_encryption_options(
+ EncryptionOptions::new(true, PaddingMode::StepFunction),
+ ))
+ })
+ .await;
+
+ let with_padding = padded_group
+ .group
+ .encrypt_application_message(&random_bytes(150), vec![])
+ .await
+ .unwrap();
+
+ assert!(without_padding.mls_encoded_len() < with_padding.mls_encoded_len());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_commit_requires_external_pub_extension() {
+ let protocol_version = TEST_PROTOCOL_VERSION;
+ let alice = test_group(protocol_version, TEST_CIPHER_SUITE).await;
+
+ // `group_info_message` omits the external public key extension.
+ let group_info = alice
+ .group
+ .group_info_message(false)
+ .await
+ .unwrap()
+ .into_group_info()
+ .unwrap();
+
+ let group_info_msg =
+ MlsMessage::new(protocol_version, MlsMessagePayload::GroupInfo(group_info));
+
+ let signing_identity = alice
+ .group
+ .current_member_signing_identity()
+ .unwrap()
+ .clone();
+
+ let attempt = external_commit::ExternalCommitBuilder::new(
+ alice.group.signer,
+ signing_identity,
+ alice.group.config,
+ )
+ .build(group_info_msg)
+ .await
+ .map(|_| ());
+
+ assert_matches!(attempt, Err(MlsError::MissingExternalPubExtension));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_commit_via_commit_options_round_trip() {
+ // Commits made with `allow_external_commit` must expose a group info that a new
+ // client can use to join by external commit.
+ let options = CommitOptions::default().with_allow_external_commit(true);
+
+ let mut alice = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ vec![],
+ None,
+ Some(options),
+ )
+ .await;
+
+ let commit_output = alice.group.commit(vec![]).await.unwrap();
+ let group_info = commit_output.external_commit_group_info.unwrap();
+
+ let (bob_client, _) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ bob_client
+ .external_commit_builder()
+ .unwrap()
+ .build(group_info)
+ .await
+ .unwrap();
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_path_update_preference() {
+ let protocol_version = TEST_PROTOCOL_VERSION;
+ let cipher_suite = TEST_CIPHER_SUITE;
+
+ // With default options an add-only commit carries no path: its commit secret is all zeros.
+ let mut default_group = test_group_custom(
+ protocol_version,
+ cipher_suite,
+ Default::default(),
+ None,
+ Some(CommitOptions::new()),
+ )
+ .await;
+
+ let new_member = test_key_package_message(protocol_version, cipher_suite, "alice").await;
+
+ default_group
+ .group
+ .commit_builder()
+ .add_member(new_member.clone())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ let default_secret_is_zero = default_group
+ .group
+ .pending_commit
+ .unwrap()
+ .pending_commit_secret
+ .iter()
+ .all(|byte| byte == &0);
+
+ assert!(default_secret_is_zero);
+
+ // With `path_required` the same commit must include a path, so the secret is non-zero.
+ let mut path_group = test_group_custom(
+ protocol_version,
+ cipher_suite,
+ Default::default(),
+ None,
+ Some(CommitOptions::new().with_path_required(true)),
+ )
+ .await;
+
+ path_group
+ .group
+ .commit_builder()
+ .add_member(new_member)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ let path_secret_is_zero = path_group
+ .group
+ .pending_commit
+ .unwrap()
+ .pending_commit_secret
+ .iter()
+ .all(|byte| byte == &0);
+
+ assert!(!path_secret_is_zero);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_path_update_preference_override() {
+ // Even with default options, a commit with no proposals must carry a path
+ // (non-zero commit secret).
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(CommitOptions::new()),
+ )
+ .await;
+
+ group.group.commit(vec![]).await.unwrap();
+
+ let secret_is_zero = group
+ .group
+ .pending_commit
+ .unwrap()
+ .pending_commit_secret
+ .iter()
+ .all(|byte| byte == &0);
+
+ assert!(!secret_is_zero);
+ }
+
+ #[cfg(feature = "private_message")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn group_rejects_unencrypted_application_message() {
+ let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (mut bob, _) = alice.join("bob").await;
+
+ // Hand-craft a public (plaintext) message carrying application data.
+ let application_data = Content::Application(b"hello".to_vec().into());
+ let plaintext = alice.make_plaintext(application_data).await;
+
+ let outcome = bob.group.process_incoming_message(plaintext).await;
+
+ assert_matches!(outcome, Err(MlsError::UnencryptedApplicationMessage));
+ }
+
+ // Alice commits a mix of proposals (Bob's update, five external PSKs, three
+ // removals and five adds). The resulting state update must report the expected
+ // added/removed/updated members and PSKs, and Bob must derive an identical
+ // commit description when processing the commit message.
+ #[cfg(feature = "state_update")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_state_update() {
+ let version = TEST_PROTOCOL_VERSION;
+ let suite = TEST_CIPHER_SUITE;
+
+ // Build a 10-member group: Alice, Bob and charlie0..charlie7. Bob is kept in
+ // sync with every add; the charlies' leaf nodes are remembered so removed
+ // members can be reconstructed for comparison.
+ let mut alice = test_group(version, suite).await;
+ let (mut bob, _) = alice.join("bob").await;
+ let mut charlie_leaves = Vec::new();
+
+ for n in 0..8 {
+ let (charlie, add_commit) = alice.join(&format!("charlie{n}")).await;
+ let leaf = charlie.group.current_user_leaf_node().unwrap().clone();
+ charlie_leaves.push(leaf);
+ bob.process_message(add_commit).await.unwrap();
+ }
+
+ // Bob proposes an update that Alice's commit will pick up.
+ let bob_update = bob.group.propose_update(vec![]).await.unwrap();
+ alice.process_message(bob_update).await.unwrap();
+
+ // Register five external PSKs with both Alice and Bob.
+ let mut external_psk_ids: Vec<ExternalPskId> = Vec::new();
+
+ for n in 0..5 {
+ let psk_id = ExternalPskId::new(vec![n]);
+
+ alice
+ .group
+ .config
+ .secret_store()
+ .insert(psk_id.clone(), PreSharedKey::from(vec![n]));
+
+ bob.group
+ .config
+ .secret_store()
+ .insert(psk_id.clone(), PreSharedKey::from(vec![n]));
+
+ external_psk_ids.push(psk_id);
+ }
+
+ // Alice's commit: the PSKs, removal of leaves 2, 5 and 6, and five "dave" adds.
+ let mut builder = alice.group.commit_builder();
+
+ for psk_id in external_psk_ids {
+ builder = builder.add_external_psk(psk_id).unwrap();
+ }
+
+ for removed in [2, 5, 6] {
+ builder = builder.remove_member(removed).unwrap();
+ }
+
+ for n in 0..5 {
+ let (dave_key_package, _) = test_member(version, suite, format!("dave{n}").as_bytes()).await;
+
+ builder = builder
+ .add_member(dave_key_package.key_package_message())
+ .unwrap();
+ }
+
+ let commit_output = builder.build().await.unwrap();
+
+ let alice_description = alice.process_pending_commit().await.unwrap();
+
+ assert!(!alice_description.is_external);
+
+ assert_eq!(
+ alice_description.committer,
+ alice.group.current_member_index()
+ );
+
+ // The state update Alice computed when applying her own pending commit.
+ let update = alice_description.state_update.clone();
+
+ let added_indices: Vec<_> = update
+ .roster_update
+ .added()
+ .iter()
+ .map(|m| m.index)
+ .collect();
+
+ assert_eq!(added_indices, vec![2, 5, 6, 10, 11]);
+
+ // charlie0 sits at leaf 2, hence the `- 2` offset into `charlie_leaves`.
+ let expected_removed: Vec<_> = vec![2, 5, 6]
+ .into_iter()
+ .map(|i| member_from_leaf_node(&charlie_leaves[i as usize - 2], LeafIndex(i)))
+ .collect();
+
+ assert_eq!(update.roster_update.removed(), expected_removed);
+
+ let updated_members = update
+ .roster_update
+ .updated()
+ .iter()
+ .map(|u| u.new.clone())
+ .collect_vec();
+
+ assert_eq!(
+ updated_members.as_slice(),
+ &alice.group.roster().members()[0..2]
+ );
+
+ let expected_psks: Vec<_> = (0..5).map(|i| ExternalPskId::new(vec![i])).collect();
+
+ assert_eq!(update.added_psks, expected_psks);
+
+ // Bob must derive exactly the same description from the commit message.
+ let bob_received = bob
+ .process_message(commit_output.commit_message)
+ .await
+ .unwrap();
+
+ let ReceivedMessage::Commit(bob_description) = bob_received else {
+ panic!("expected commit");
+ };
+
+ assert_eq!(alice_description, bob_description);
+ }
+
+ // When Bob joins through an external commit, Alice sees a commit description
+ // flagged as external, with committer index 1 and Bob added to the roster;
+ // afterwards both rosters are identical.
+ #[cfg(feature = "state_update")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_description_external_commit() {
+ use crate::client::test_utils::TestClientBuilder;
+
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let (bob_identity, bob_secret_key) =
+ get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
+
+ let bob = TestClientBuilder::new_for_test()
+ .signing_identity(bob_identity, bob_secret_key, TEST_CIPHER_SUITE)
+ .build();
+
+ let group_info = alice_group
+ .group
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap();
+
+ let (bob_group, external_commit) = bob
+ .external_commit_builder()
+ .unwrap()
+ .build(group_info)
+ .await
+ .unwrap();
+
+ let received = alice_group.process_message(external_commit).await.unwrap();
+
+ let ReceivedMessage::Commit(description) = received else {
+ panic!("expected commit");
+ };
+
+ assert!(description.is_external);
+ assert_eq!(description.committer, 1);
+
+ assert_eq!(
+ description.state_update.roster_update.added(),
+ &bob_group.roster().members()[1..2]
+ );
+
+ itertools::assert_equal(
+ bob_group.roster().members_iter(),
+ alice_group.group.roster().members_iter(),
+ );
+ }
+
+ // Bob joins by external commit using group info produced with `false`
+ // (presumably: ratchet tree not embedded), supplying Alice's exported tree
+ // out of band; Alice must accept the resulting commit.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn can_join_new_group_externally() {
+ use crate::client::test_utils::TestClientBuilder;
+
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let (bob_identity, bob_secret_key) =
+ get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
+
+ let bob = TestClientBuilder::new_for_test()
+ .signing_identity(bob_identity, bob_secret_key, TEST_CIPHER_SUITE)
+ .build();
+
+ let exported_tree = alice_group.group.export_tree().into_owned();
+
+ let group_info = alice_group
+ .group
+ .group_info_message_allowing_ext_commit(false)
+ .await
+ .unwrap();
+
+ let (_, external_commit) = bob
+ .external_commit_builder()
+ .unwrap()
+ .with_tree_data(exported_tree)
+ .build(group_info)
+ .await
+ .unwrap();
+
+ alice_group.process_message(external_commit).await.unwrap();
+ }
+
+ // A member's plaintext commit whose sender is rewritten to `NewMemberCommit`
+ // must be rejected with `MembershipTagForNonMember`.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_membership_tag_from_non_member() {
+ let (mut alice_group, mut bob_group) =
+ test_two_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, true).await;
+
+ let mut commit_output = alice_group.group.commit(vec![]).await.unwrap();
+
+ let MlsMessagePayload::Plain(plaintext) = &mut commit_output.commit_message.payload else {
+ panic!("Non plaintext message");
+ };
+
+ plaintext.content.sender = Sender::NewMemberCommit;
+
+ let res = bob_group
+ .process_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::MembershipTagForNonMember));
+ }
+
+ // Charlie adds Dave with a commit that Alice and Bob can process, and that
+ // commit is expected to carry no update path.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_partial_commits() {
+ let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (mut bob, _) = alice.join("bob").await;
+ let (mut charlie, add_charlie) = alice.join("charlie").await;
+ bob.process_message(add_charlie).await.unwrap();
+
+ let (_, add_dave) = charlie.join("dave").await;
+
+ alice.process_message(add_dave.clone()).await.unwrap();
+ bob.process_message(add_dave.clone()).await.unwrap();
+
+ let public_message = add_dave.into_plaintext().unwrap();
+
+ let Content::Commit(commit) = public_message.content.content else {
+ panic!("Expected commit")
+ };
+
+ assert!(commit.path.is_none());
+ }
+
+ // Creates the standard test group and turns on `path_required` in its MLS
+ // rules' commit options.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn group_with_path_required() -> TestGroup {
+ let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let commit_options = &mut group.group.config.0.mls_rules.commit_options;
+ commit_options.path_required = true;
+
+ group
+ }
+
+ // When Alice commits the removal of member 1 herself, the entry at index 1 of
+ // her private tree's `secret_keys` is present before applying the pending
+ // commit and cleared afterwards.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn old_hpke_secrets_are_removed() {
+ let mut alice = group_with_path_required().await;
+
+ for name in ["bob", "charlie"] {
+ alice.join(name).await;
+ }
+
+ alice
+ .group
+ .commit_builder()
+ .remove_member(1)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ // The secret survives until the pending commit is applied.
+ assert!(alice.group.private_tree.secret_keys[1].is_some());
+ alice.process_pending_commit().await.unwrap();
+ assert!(alice.group.private_tree.secret_keys[1].is_none());
+ }
+
+ // When Alice processes Charlie's commit removing member 1, the entry at index
+ // 1 of her private tree's `secret_keys` is cleared.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn old_hpke_secrets_of_removed_are_removed() {
+ let mut alice = group_with_path_required().await;
+ let _ = alice.join("bob").await;
+ let (mut charlie, _) = alice.join("charlie").await;
+
+ let removal = charlie
+ .group
+ .commit_builder()
+ .remove_member(1)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ assert!(alice.group.private_tree.secret_keys[1].is_some());
+ alice.process_message(removal.commit_message).await.unwrap();
+ assert!(alice.group.private_tree.secret_keys[1].is_none());
+ }
+
+ // After Charlie commits Bob's update proposal, the entry at index 1 of Alice's
+ // private tree's `secret_keys` is cleared when she processes that commit.
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn old_hpke_secrets_of_updated_are_removed() {
+ let mut alice = group_with_path_required().await;
+ let (mut bob, _) = alice.join("bob").await;
+ let (mut charlie, add_charlie) = alice.join("charlie").await;
+ bob.process_message(add_charlie).await.unwrap();
+
+ // Bob proposes an update; Charlie and Alice both receive the proposal.
+ let bob_update = bob.group.propose_update(vec![]).await.unwrap();
+ charlie.process_message(bob_update.clone()).await.unwrap();
+ alice.process_message(bob_update).await.unwrap();
+
+ let charlie_commit = charlie.group.commit(vec![]).await.unwrap();
+
+ assert!(alice.group.private_tree.secret_keys[1].is_some());
+ alice
+ .process_message(charlie_commit.commit_message)
+ .await
+ .unwrap();
+ assert!(alice.group.private_tree.secret_keys[1].is_none());
+ }
+
+ // Alice branches a subgroup containing only Bob. Bob can join from the
+ // welcome and exchange commits with Alice; Carol, who was not selected,
+ // cannot join from the same welcome.
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn only_selected_members_of_the_original_group_can_join_subgroup() {
+ let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (mut bob, _) = alice.join("bob").await;
+ let (carol, add_carol) = alice.join("carol").await;
+
+ // Bring Bob up to date with Carol's addition.
+ bob.group.process_incoming_message(add_carol).await.unwrap();
+
+ // A fresh key package for Bob, reusing his current identity and signer.
+ let bob_signing_identity = bob.group.current_member_signing_identity().unwrap().clone();
+ let bob_signer = bob.group.signer.clone();
+
+ let bob_client = Client::new(
+ bob.group.config.clone(),
+ Some(bob_signer),
+ Some((bob_signing_identity, TEST_CIPHER_SUITE)),
+ TEST_PROTOCOL_VERSION,
+ );
+
+ let bob_key_package = bob_client.generate_key_package_message().await.unwrap();
+
+ let (mut alice_subgroup, welcome_messages) = alice
+ .group
+ .branch(b"subgroup".to_vec(), vec![bob_key_package])
+ .await
+ .unwrap();
+
+ let welcome = &welcome_messages[0];
+
+ let (mut bob_subgroup, _) = bob.group.join_subgroup(welcome, None).await.unwrap();
+
+ // Carol was not invited, so joining must fail.
+ let carol_attempt = carol.group.join_subgroup(welcome, None).await.map(|_| ());
+ assert_matches!(carol_attempt, Err(_));
+
+ // Alice and Bob keep communicating inside the subgroup.
+ let subgroup_commit = alice_subgroup.commit(vec![]).await.unwrap();
+
+ bob_subgroup
+ .process_incoming_message(subgroup_commit.commit_message)
+ .await
+ .unwrap();
+ }
+
+ // Creates a fresh test group and returns the result of joining it with a
+ // client whose `TestClientConfig` has first been modified by `f`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn joining_group_fails_if_unsupported<F>(
+ f: F,
+ ) -> Result<(TestGroup, MlsMessage), MlsError>
+ where
+ F: FnMut(&mut TestClientConfig),
+ {
+ let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ group.join_with_custom_config("alice", false, f).await
+ }
+
+ // A joiner whose config supports no protocol versions is rejected with
+ // `UnsupportedProtocolVersion` for the group's version.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn joining_group_fails_if_protocol_version_is_not_supported() {
+ let res = joining_group_fails_if_unsupported(|config| config.0.settings.protocol_versions.clear())
+ .await
+ .map(|_| ());
+
+ assert_matches!(
+ res,
+ Err(MlsError::UnsupportedProtocolVersion(v)) if v == TEST_PROTOCOL_VERSION
+ );
+ }
+
+ // WebCrypto does not support disabling ciphersuites
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn joining_group_fails_if_cipher_suite_is_not_supported() {
+ // The joiner's crypto provider has the group's cipher suite disabled.
+ let res = joining_group_fails_if_unsupported(|config| {
+ let enabled = &mut config.0.crypto_provider.enabled_cipher_suites;
+ enabled.retain(|&suite| suite != TEST_CIPHER_SUITE);
+ })
+ .await
+ .map(|_| ());
+
+ assert_matches!(res, Err(MlsError::UnsupportedCipherSuite(TEST_CIPHER_SUITE)));
+ }
+
+ // When Alice decrypts Bob's application message, the description reports
+ // Bob's member index as the sender.
+ #[cfg(feature = "private_message")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn member_can_see_sender_creds() {
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (mut bob_group, _) = alice_group.join("bob").await;
+
+ let ciphertext = bob_group
+ .group
+ .encrypt_application_message(b"I'm Bob", vec![])
+ .await
+ .unwrap();
+
+ let received = alice_group
+ .group
+ .process_incoming_message(ciphertext)
+ .await
+ .unwrap();
+
+ let bob_index = bob_group.group.current_member_index();
+
+ assert_matches!(
+ received,
+ ReceivedMessage::ApplicationMessage(ApplicationMessageDescription { sender_index, .. })
+ if sender_index == bob_index
+ );
+ }
+
+ // Two members of the same epoch compute the same epoch authenticator.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn members_of_a_group_have_identical_authentication_secrets() {
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (bob_group, _) = alice_group.join("bob").await;
+
+ let alice_authenticator = alice_group.group.epoch_authenticator().unwrap();
+ let bob_authenticator = bob_group.group.epoch_authenticator().unwrap();
+
+ assert_eq!(alice_authenticator, bob_authenticator);
+ }
+
+ // Bob decrypts Alice's application message once; a second attempt with the
+ // same ciphertext fails with `KeyMissing(0)`.
+ #[cfg(feature = "private_message")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn member_cannot_decrypt_same_message_twice() {
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (mut bob_group, _) = alice_group.join("bob").await;
+
+ let ciphertext = alice_group
+ .group
+ .encrypt_application_message(b"foobar", Vec::new())
+ .await
+ .unwrap();
+
+ let first_attempt = bob_group
+ .group
+ .process_incoming_message(ciphertext.clone())
+ .await
+ .unwrap();
+
+ assert_matches!(
+ first_attempt,
+ ReceivedMessage::ApplicationMessage(m) if m.data() == b"foobar"
+ );
+
+ let second_attempt = bob_group.group.process_incoming_message(ciphertext).await;
+
+ assert_matches!(second_attempt, Err(MlsError::KeyMissing(0)));
+ }
+
+ // Alice first makes extension 17 required through `RequiredCapabilitiesExt`.
+ // A single commit that adds Bob and resets the group context extensions must
+ // then succeed, leaving a two-member roster.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn removing_requirements_allows_to_add() {
+ let mut alice_group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ vec![17.into()],
+ None,
+ None,
+ )
+ .await;
+
+ // Step 1: require extension 17 in the group context.
+ let required_caps = RequiredCapabilitiesExt {
+ extensions: vec![17.into()],
+ ..Default::default()
+ };
+
+ alice_group
+ .group
+ .commit_builder()
+ .set_group_context_ext(
+ vec![required_caps.into_extension().unwrap()]
+ .try_into()
+ .unwrap(),
+ )
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ alice_group.process_pending_commit().await.unwrap();
+
+ // Step 2: add Bob while clearing the context extensions in the same commit.
+ let bob_key_package =
+ test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let bob_key_package = MlsMessage::new(
+ TEST_PROTOCOL_VERSION,
+ MlsMessagePayload::KeyPackage(bob_key_package),
+ );
+
+ alice_group
+ .group
+ .commit_builder()
+ .add_member(bob_key_package)
+ .unwrap()
+ .set_group_context_ext(Default::default())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ let state_update = alice_group
+ .process_pending_commit()
+ .await
+ .unwrap()
+ .state_update;
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(
+ state_update
+ .roster_update
+ .added()
+ .iter()
+ .map(|member| member.index)
+ .collect::<Vec<_>>(),
+ vec![1]
+ );
+
+ #[cfg(not(feature = "state_update"))]
+ assert!(state_update == StateUpdate {});
+
+ assert_eq!(alice_group.group.roster().members_iter().count(), 2);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_wrong_source() {
+ // RFC, 13.4.2. "The leaf_node_source field MUST be set to commit."
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 3).await;
+
+ // Member 0's committed leaf claims to come from an Update instead of a Commit.
+ groups[0].group.commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.leaf_node_source = LeafNodeSource::Update;
+ Some(sk.clone())
+ };
+
+ let bad_commit = groups[0].group.commit(vec![]).await.unwrap();
+
+ let outcome = groups[2].process_message(bad_commit.commit_message).await;
+
+ assert_matches!(outcome, Err(MlsError::InvalidLeafNodeSource));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_same_hpke_key() {
+ // RFC 13.4.2. "Verify that the encryption_key value in the LeafNode is different from the committer's current leaf node"
+
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 3).await;
+
+ // Group 0 starts using fixed key
+ groups[0].group.commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.public_key = get_test_25519_key(1u8);
+ Some(sk.clone())
+ };
+
+ // The first commit with the fixed key is accepted: it differs from the
+ // committer's previous leaf key.
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+ groups[0].process_pending_commit().await.unwrap();
+ groups[2]
+ .process_message(commit_output.commit_message)
+ .await
+ .unwrap();
+
+ // Group 0 tries to use the fixed key again, so the new leaf's encryption key
+ // equals its current one and member 2 must reject the commit.
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+
+ let res = groups[2]
+ .process_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::SameHpkeKey(0)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_duplicate_hpke_key() {
+ // RFC 8.3 "Verify that the following fields are unique among the members of the group: `encryption_key`"
+
+ // The fixed key below is an X25519 key, so only run for Curve25519 suites.
+ let is_curve25519 = TEST_CIPHER_SUITE == CipherSuite::CURVE25519_AES128
+ || TEST_CIPHER_SUITE == CipherSuite::CURVE25519_CHACHA;
+
+ if !is_curve25519 {
+ return;
+ }
+
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 10).await;
+
+ // Group 1 uses the fixed key
+ groups[1].group.commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.public_key = get_test_25519_key(1u8);
+ Some(sk.clone())
+ };
+
+ // `modify_leaf` is a plain fn pointer, so member 0 can later reuse the
+ // exact same modifier.
+ let use_fixed_hpke_key = groups[1].group.commit_modifiers.modify_leaf;
+
+ let first_commit = groups[1].group.commit(vec![]).await.unwrap();
+
+ process_commit(&mut groups, first_commit.commit_message, 1).await;
+
+ // Group 0 tries to use the fixed key too
+ groups[0].group.commit_modifiers.modify_leaf = use_fixed_hpke_key;
+
+ let second_commit = groups[0].group.commit(vec![]).await.unwrap();
+
+ let outcome = groups[7]
+ .process_message(second_commit.commit_message)
+ .await;
+
+ assert_matches!(outcome, Err(MlsError::DuplicateLeafData(_)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_duplicate_signature_key() {
+ // RFC 8.3 "Verify that the following fields are unique among the members of the group: `signature_key`"
+
+ // The fixed key pair below only makes sense for Curve25519 suites.
+ let is_curve25519 = TEST_CIPHER_SUITE == CipherSuite::CURVE25519_AES128
+ || TEST_CIPHER_SUITE == CipherSuite::CURVE25519_CHACHA;
+
+ if !is_curve25519 {
+ return;
+ }
+
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 10).await;
+
+ // Group 1 uses the fixed key. The last 32 bytes of the secret key equal the
+ // public key (presumably a seed || public-key encoding).
+ groups[1].group.commit_modifiers.modify_leaf = |leaf, _| {
+ let sk = hex!(
+ "3468b4c890255c983e3d5cbf5cb64c1ef7f6433a518f2f3151d6672f839a06ebcad4fc381fe61822af45135c82921a348e6f46643d66ddefc70483565433714b"
+ )
+ .into();
+
+ leaf.signing_identity.signature_key =
+ hex!("cad4fc381fe61822af45135c82921a348e6f46643d66ddefc70483565433714b").into();
+
+ Some(sk)
+ };
+
+ // `modify_leaf` is a plain fn pointer; keep a copy so member 0 can reuse the
+ // very same key pair without duplicating the literal.
+ let use_fixed_signature_key = groups[1].group.commit_modifiers.modify_leaf;
+
+ let first_commit = groups[1].group.commit(vec![]).await.unwrap();
+
+ process_commit(&mut groups, first_commit.commit_message, 1).await;
+
+ // Group 0 tries to use the fixed key too
+ groups[0].group.commit_modifiers.modify_leaf = use_fixed_signature_key;
+
+ let second_commit = groups[0].group.commit(vec![]).await.unwrap();
+
+ let outcome = groups[7]
+ .process_message(second_commit.commit_message)
+ .await;
+
+ assert_matches!(outcome, Err(MlsError::DuplicateLeafData(_)));
+ }
+
+ // Flipping the lowest bit of the first byte of the committed leaf's signature
+ // must make receivers reject the commit with `InvalidSignature`.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_incorrect_signature() {
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 3).await;
+
+ groups[0].group.commit_modifiers.modify_leaf = |leaf, _| {
+ leaf.signature[0] ^= 1;
+ None
+ };
+
+ let tampered = groups[0].group.commit(vec![]).await.unwrap();
+
+ let outcome = groups[2].process_message(tampered.commit_message).await;
+
+ assert_matches!(outcome, Err(MlsError::InvalidSignature));
+ }
+
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_not_supporting_used_context_extension() {
+ const EXT_TYPE: ExtensionType = ExtensionType::new(999);
+
+ // The new leaf of the committer doesn't support an extension set in group context
+ let context_extensions = vec![Extension::new(EXT_TYPE, vec![])].into();
+
+ let mut groups =
+ get_test_groups_with_features(3, context_extensions, Default::default()).await;
+
+ // Reset the committer's capabilities to the test defaults.
+ groups[0].commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.capabilities = get_test_capabilities();
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].commit(vec![]).await.unwrap();
+
+ let outcome = groups[1]
+ .process_incoming_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(outcome, Err(MlsError::UnsupportedGroupExtension(EXT_TYPE)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_not_supporting_required_extension() {
+ // The new leaf of the committer doesn't support an extension required by group context
+
+ let required = RequiredCapabilitiesExt {
+ extensions: vec![999.into()],
+ proposals: vec![],
+ credentials: vec![],
+ };
+
+ let context_extensions = vec![required.into_extension().unwrap()];
+
+ let mut groups =
+ get_test_groups_with_features(3, context_extensions.into(), Default::default()).await;
+
+ // The committer advertises only default capabilities.
+ groups[0].commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.capabilities = Capabilities::default();
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].commit(vec![]).await.unwrap();
+
+ let outcome = groups[2]
+ .process_incoming_message(commit_output.commit_message)
+ .await;
+
+ assert!(outcome.is_err());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_has_unsupported_credential() {
+ // The new leaf of the committer has a credential unsupported by another leaf
+ let mut groups =
+ get_test_groups_with_features(3, Default::default(), Default::default()).await;
+
+ // Let the identity provider accept custom credentials so that only the
+ // capability check can reject the commit.
+ groups
+ .iter_mut()
+ .for_each(|group| group.config.0.identity_provider.allow_any_custom = true);
+
+ // Replace the basic credential with a custom credential of type 43 that
+ // carries the same identifier.
+ groups[0].commit_modifiers.modify_leaf = |leaf, sk| {
+ let identifier = leaf
+ .signing_identity
+ .credential
+ .as_basic()
+ .unwrap()
+ .identifier
+ .to_vec();
+
+ leaf.signing_identity.credential =
+ Credential::Custom(CustomCredential::new(CredentialType::new(43), identifier));
+
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].commit(vec![]).await.unwrap();
+
+ let outcome = groups[2]
+ .process_incoming_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(outcome, Err(MlsError::CredentialTypeOfNewLeafIsUnsupported));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_not_supporting_credential_used_in_another_leaf() {
+ // The new leaf of the committer doesn't support another leaf's credential
+
+ let mut groups =
+ get_test_groups_with_features(3, Default::default(), Default::default()).await;
+
+ // The committer now only advertises credential type 2.
+ groups[0].commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.capabilities.credentials = vec![2.into()];
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].commit(vec![]).await.unwrap();
+
+ let outcome = groups[2]
+ .process_incoming_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(outcome, Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_not_supporting_required_credential() {
+ // The new leaf of the committer doesn't support a credential required by group context
+
+ let required = RequiredCapabilitiesExt {
+ extensions: vec![],
+ proposals: vec![],
+ credentials: vec![1.into()],
+ };
+
+ let context_extensions = vec![required.into_extension().unwrap()];
+
+ let mut groups =
+ get_test_groups_with_features(3, context_extensions.into(), Default::default()).await;
+
+ // The committer advertises credential type 2 only, not the required type 1.
+ groups[0].commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.capabilities.credentials = vec![2.into()];
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].commit(vec![]).await.unwrap();
+
+ let outcome = groups[2]
+ .process_incoming_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(outcome, Err(MlsError::RequiredCredentialNotFound(_)));
+ }
+
+ // Builds an `ExternalSendersExt` with a single external sender: a freshly
+ // generated signature public key paired with an X.509 credential whose chain
+ // holds one random 32-byte entry.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg(not(target_arch = "wasm32"))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn make_x509_external_senders_ext() -> ExternalSendersExt {
+ let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let (_, sender_public_key) = provider.signature_key_generate().await.unwrap();
+
+ let certificate_chain = CertificateChain::from(vec![random_bytes(32)]);
+
+ let sender = SigningIdentity {
+ signature_key: sender_public_key,
+ credential: Credential::X509(certificate_chain),
+ };
+
+ ExternalSendersExt::new(vec![sender])
+ }
+
+ // A group whose context lists an X.509 external sender rejects a commit in
+ // which the committer's new leaf only supports basic credentials.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_not_supporting_external_sender_credential_leads_to_rejected_commit() {
+ let external_senders = make_x509_external_senders_ext()
+ .await
+ .into_extension()
+ .unwrap();
+
+ let identity_provider =
+ BasicWithCustomProvider::default().with_credential_type(CredentialType::X509);
+
+ let mut alice = ClientBuilder::new()
+ .crypto_provider(TestCryptoProvider::new())
+ .identity_provider(identity_provider)
+ .with_random_signing_identity("alice", TEST_CIPHER_SUITE)
+ .await
+ .build()
+ .create_group(core::iter::once(external_senders).collect())
+ .await
+ .unwrap();
+
+ // New leaf supports only basic credentials (used by the group) but not X509 used by external sender
+ alice.commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.capabilities.credentials = vec![CredentialType::BASIC];
+ Some(sk.clone())
+ };
+
+ alice.commit(vec![]).await.unwrap();
+
+ let outcome = alice.apply_pending_commit().await;
+
+ assert_matches!(
+ outcome,
+ Err(MlsError::RequiredCredentialNotFound(CredentialType::X509))
+ );
+ }
+
+ // With an X.509 external sender configured, adding Bob (whose test key
+ // package presumably lacks X.509 support) fails with
+ // `RequiredCredentialNotFound(X509)`.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn node_not_supporting_external_sender_credential_cannot_join_group() {
+ let external_senders = make_x509_external_senders_ext()
+ .await
+ .into_extension()
+ .unwrap();
+
+ let identity_provider =
+ BasicWithCustomProvider::default().with_credential_type(CredentialType::X509);
+
+ let mut alice = ClientBuilder::new()
+ .crypto_provider(TestCryptoProvider::new())
+ .identity_provider(identity_provider)
+ .with_random_signing_identity("alice", TEST_CIPHER_SUITE)
+ .await
+ .build()
+ .create_group(core::iter::once(external_senders).collect())
+ .await
+ .unwrap();
+
+ let (_, bob_key_package) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let outcome = alice
+ .commit_builder()
+ .add_member(bob_key_package)
+ .unwrap()
+ .build()
+ .await;
+
+ assert_matches!(
+ outcome,
+ Err(MlsError::RequiredCredentialNotFound(CredentialType::X509))
+ );
+ }
+
+ // After Bob has joined, Alice's attempt to set an X.509 external sender in the
+ // group context fails with `RequiredCredentialNotFound(X509)`.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_senders_extension_is_rejected_if_member_does_not_support_credential_type() {
+ let identity_provider =
+ BasicWithCustomProvider::default().with_credential_type(CredentialType::X509);
+
+ let mut alice = ClientBuilder::new()
+ .crypto_provider(TestCryptoProvider::new())
+ .identity_provider(identity_provider)
+ .with_random_signing_identity("alice", TEST_CIPHER_SUITE)
+ .await
+ .build()
+ .create_group(Default::default())
+ .await
+ .unwrap();
+
+ // Add Bob first; this must succeed because no external sender is set yet.
+ let (_, bob_key_package) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ alice
+ .commit_builder()
+ .add_member(bob_key_package)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ alice.apply_pending_commit().await.unwrap();
+ assert_eq!(alice.roster().members_iter().count(), 2);
+
+ // Now try to introduce the X.509 external sender.
+ let external_senders = make_x509_external_senders_ext()
+ .await
+ .into_extension()
+ .unwrap();
+
+ let outcome = alice
+ .commit_builder()
+ .set_group_context_ext(core::iter::once(external_senders).collect())
+ .unwrap()
+ .build()
+ .await;
+
+ assert_matches!(
+ outcome,
+ Err(MlsError::RequiredCredentialNotFound(CredentialType::X509))
+ );
+ }
+
+ /*
+ * Edge case paths
+ */
+
+ // The committer's leaf and the tree nodes at indices 1 and 3 are all given the
+ // same fixed HPKE key; a receiver must still accept the resulting commit.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn committing_degenerate_path_succeeds() {
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 10).await;
+
+ groups[0].group.commit_modifiers.modify_tree = |tree: &mut TreeKemPublic| {
+ tree.update_node(get_test_25519_key(1u8), 1).unwrap();
+ tree.update_node(get_test_25519_key(1u8), 3).unwrap();
+ };
+
+ groups[0].group.commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.public_key = get_test_25519_key(1u8);
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+
+ // `expect` (rather than `assert!(res.is_ok())`) reports the actual error
+ // if processing fails.
+ groups[7]
+ .process_message(commit_output.commit_message)
+ .await
+ .expect("member 7 should accept a commit over a degenerate path");
+ }
+
+ // After member 1 is removed, member 0 commits with an extra node prepended
+ // to its update path; receivers must reject it with `WrongPathLen`.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn inserting_key_in_filtered_node_fails() {
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 10).await;
+
+ // Remove member 1 and bring every remaining member (index 2 and up) in sync.
+ let removal = groups[0]
+ .group
+ .commit_builder()
+ .remove_member(1)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ groups[0].process_pending_commit().await.unwrap();
+
+ for member in groups.iter_mut().skip(2) {
+ member
+ .process_message(removal.commit_message.clone())
+ .await
+ .unwrap();
+ }
+
+ groups[0].group.commit_modifiers.modify_tree = |tree: &mut TreeKemPublic| {
+ tree.update_node(get_test_25519_key(1u8), 1).unwrap();
+ };
+
+ // Prepend a copy of the first path node carrying the fixed key.
+ groups[0].group.commit_modifiers.modify_path = |mut path: Vec<UpdatePathNode>| {
+ let mut extra = path[0].clone();
+ extra.public_key = get_test_25519_key(1u8);
+ path.insert(0, extra);
+ path
+ };
+
+ let bad_commit = groups[0].group.commit(vec![]).await.unwrap();
+
+ let outcome = groups[7].process_message(bad_commit.commit_message).await;
+
+ // We should get a path validation error, since the path is too long
+ assert_matches!(outcome, Err(MlsError::WrongPathLen));
+ }
+
+ // After member 1 is removed, member 0 commits with the last node dropped from
+ // its update path; receivers must reject it.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_with_too_short_path_fails() {
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 10).await;
+
+ // Remove member 1 and bring every remaining member (index 2 and up) in sync.
+ let removal = groups[0]
+ .group
+ .commit_builder()
+ .remove_member(1)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ groups[0].process_pending_commit().await.unwrap();
+
+ for member in groups.iter_mut().skip(2) {
+ member
+ .process_message(removal.commit_message.clone())
+ .await
+ .unwrap();
+ }
+
+ groups[0].group.commit_modifiers.modify_path = |mut path: Vec<UpdatePathNode>| {
+ path.pop();
+ path
+ };
+
+ let bad_commit = groups[0].group.commit(vec![]).await.unwrap();
+
+ let outcome = groups[7].process_message(bad_commit.commit_message).await;
+
+ assert!(outcome.is_err());
+ }
+
+ // An update proposal that carries a new signing identity replaces member 0's
+ // credential and signature key. The change must be visible both to the
+ // committer (member 1) and to the proposer (member 0) after the commit.
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn update_proposal_can_change_credential() {
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 3).await;
+ let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, b"member").await;
+
+ let update = groups[0]
+ .group
+ .propose_update_with_identity(secret_key, identity.clone(), vec![])
+ .await
+ .unwrap();
+
+ groups[1].process_message(update).await.unwrap();
+ let commit_output = groups[1].group.commit(vec![]).await.unwrap();
+
+ // Check that the credential was updated in the committer's state.
+ groups[1].process_pending_commit().await.unwrap();
+ let new_member = groups[1].group.roster().member_with_index(0).unwrap();
+
+ assert_eq!(
+ new_member.signing_identity.credential,
+ get_test_basic_credential(b"member".to_vec())
+ );
+
+ assert_eq!(
+ new_member.signing_identity.signature_key,
+ identity.signature_key
+ );
+
+ // Check that the credential was updated in the updater's state.
+ groups[0]
+ .process_message(commit_output.commit_message)
+ .await
+ .unwrap();
+ let new_member = groups[0].group.roster().member_with_index(0).unwrap();
+
+ assert_eq!(
+ new_member.signing_identity.credential,
+ get_test_basic_credential(b"member".to_vec())
+ );
+
+ assert_eq!(
+ new_member.signing_identity.signature_key,
+ identity.signature_key
+ );
+ }
+
+ /// A commit adding a member must be rejected with `InvalidLifetime` when it
+ /// is processed with a clock set far in the future.
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_commit_with_old_adds_fails() {
+ // Ten years, expressed in seconds.
+ const TEN_YEARS: u64 = 10 * 365 * 24 * 3600;
+
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 2).await;
+
+ let key_package =
+ test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "foobar").await;
+
+ // Member 0 proposes the add by reference and then commits it.
+ let sender = &mut groups[0].group;
+ let proposal = sender.propose_add(key_package, vec![]).await.unwrap();
+ let commit = sender.commit(vec![]).await.unwrap().commit_message;
+
+ let far_future = MlsTime::from_duration_since_epoch(core::time::Duration::from_secs(
+ MlsTime::now().seconds_since_epoch() + TEN_YEARS,
+ ));
+
+ // Member 1 receives the proposal normally, then the commit "ten years later".
+ let receiver = &mut groups[1].group;
+ receiver.process_incoming_message(proposal).await.unwrap();
+ let res = receiver
+ .process_incoming_message_with_time(commit, far_future)
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidLifetime));
+ }
+
+ /// Creates Alice's group and has Bob join it. Both are configured to accept
+ /// `TEST_CUSTOM_PROPOSAL_TYPE`. Returns `(alice, bob)`.
+ #[cfg(feature = "custom_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn custom_proposal_setup() -> (TestGroup, TestGroup) {
+ let mut alice = test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |builder| {
+ builder.custom_proposal_type(TEST_CUSTOM_PROPOSAL_TYPE)
+ })
+ .await;
+
+ let bob = alice
+ .join_with_custom_config("bob", true, |config| {
+ config
+ .0
+ .settings
+ .custom_proposal_types
+ .push(TEST_CUSTOM_PROPOSAL_TYPE)
+ })
+ .await
+ .unwrap()
+ .0;
+
+ (alice, bob)
+ }
+
+ /// A custom proposal committed by value reaches the receiver intact.
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_by_value() {
+ let (mut alice, mut bob) = custom_proposal_setup().await;
+ let custom_proposal = CustomProposal::new(TEST_CUSTOM_PROPOSAL_TYPE, vec![0, 1, 2]);
+
+ // Alice embeds the custom proposal directly in her commit.
+ let commit_output = alice
+ .group
+ .commit_builder()
+ .custom_proposal(custom_proposal.clone())
+ .build()
+ .await
+ .unwrap();
+
+ let received = bob
+ .group
+ .process_incoming_message(commit_output.commit_message)
+ .await
+ .unwrap();
+
+ #[cfg(feature = "state_update")]
+ assert_matches!(received, ReceivedMessage::Commit(CommitMessageDescription { state_update: StateUpdate { custom_proposals, .. }, .. })
+ if custom_proposals.len() == 1 && custom_proposals[0].proposal == custom_proposal);
+
+ #[cfg(not(feature = "state_update"))]
+ assert_matches!(received, ReceivedMessage::Commit(_));
+ }
+
+ /// A custom proposal sent by reference is delivered to the receiver. It is
+ /// then committed by that receiver and reported back to the proposer.
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_by_reference() {
+ let (mut alice, mut bob) = custom_proposal_setup().await;
+ let custom_proposal = CustomProposal::new(TEST_CUSTOM_PROPOSAL_TYPE, vec![0, 1, 2]);
+
+ let proposal_message = alice
+ .group
+ .propose_custom(custom_proposal.clone(), vec![])
+ .await
+ .unwrap();
+
+ let bob_view = bob
+ .group
+ .process_incoming_message(proposal_message)
+ .await
+ .unwrap();
+
+ assert_matches!(bob_view, ReceivedMessage::Proposal(ProposalMessageDescription { proposal: Proposal::Custom(c), ..})
+ if c == custom_proposal);
+
+ // Bob commits the pending proposal and Alice processes that commit.
+ let bob_commit = bob.group.commit(vec![]).await.unwrap().commit_message;
+ let alice_view = alice
+ .group
+ .process_incoming_message(bob_commit)
+ .await
+ .unwrap();
+
+ #[cfg(feature = "state_update")]
+ assert_matches!(alice_view, ReceivedMessage::Commit(CommitMessageDescription { state_update: StateUpdate { custom_proposals, .. }, .. })
+ if custom_proposals.len() == 1 && custom_proposals[0].proposal == custom_proposal);
+
+ #[cfg(not(feature = "state_update"))]
+ assert_matches!(alice_view, ReceivedMessage::Commit(_));
+ }
+
+ /// Bob can join from a welcome whose commit also injects an external PSK
+ /// that both Alice and Bob hold in their secret stores.
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn can_join_with_psk() {
+ let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE)
+ .await
+ .group;
+
+ let (bob, key_pkg) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let psk_id = ExternalPskId::new(vec![0]);
+ let psk = PreSharedKey::from(vec![0]);
+
+ // Register the same PSK on both sides.
+ alice
+ .config
+ .secret_store()
+ .insert(psk_id.clone(), psk.clone());
+ bob.config.secret_store().insert(psk_id.clone(), psk);
+
+ let welcome_messages = alice
+ .commit_builder()
+ .add_member(key_pkg)
+ .unwrap()
+ .add_external_psk(psk_id)
+ .unwrap()
+ .build()
+ .await
+ .unwrap()
+ .welcome_messages;
+
+ bob.join_group(None, &welcome_messages[0]).await.unwrap();
+ }
+
+ /// Bob's update proposal produces a leaf that does not support the group's
+ /// required extension. Alice must still commit Carol's and Dave's updates,
+ /// leave Bob's out, and keep every member in the group.
+ ///
+ /// The test inspects `commit_desc.state_update`, so it is gated on the
+ /// `state_update` feature like the other tests that read that field.
+ #[cfg(all(feature = "by_ref_proposal", feature = "state_update"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn invalid_update_does_not_prevent_other_updates() {
+ const EXTENSION_TYPE: ExtensionType = ExtensionType::new(33);
+
+ let group_extensions = ExtensionList::from(vec![RequiredCapabilitiesExt {
+ extensions: vec![EXTENSION_TYPE],
+ ..Default::default()
+ }
+ .into_extension()
+ .unwrap()]);
+
+ // Alice creates a group requiring support for an extension
+ let mut alice = TestClientBuilder::new_for_test()
+ .with_random_signing_identity("alice", TEST_CIPHER_SUITE)
+ .await
+ .extension_type(EXTENSION_TYPE)
+ .build()
+ .create_group(group_extensions.clone())
+ .await
+ .unwrap();
+
+ let (bob_signing_identity, bob_secret_key) =
+ get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
+
+ let bob_client = TestClientBuilder::new_for_test()
+ .signing_identity(
+ bob_signing_identity.clone(),
+ bob_secret_key.clone(),
+ TEST_CIPHER_SUITE,
+ )
+ .extension_type(EXTENSION_TYPE)
+ .build();
+
+ let carol_client = TestClientBuilder::new_for_test()
+ .with_random_signing_identity("carol", TEST_CIPHER_SUITE)
+ .await
+ .extension_type(EXTENSION_TYPE)
+ .build();
+
+ let dave_client = TestClientBuilder::new_for_test()
+ .with_random_signing_identity("dave", TEST_CIPHER_SUITE)
+ .await
+ .extension_type(EXTENSION_TYPE)
+ .build();
+
+ // Alice adds Bob, Carol and Dave to the group. They all support the mandatory extension.
+ let commit = alice
+ .commit_builder()
+ .add_member(bob_client.generate_key_package_message().await.unwrap())
+ .unwrap()
+ .add_member(carol_client.generate_key_package_message().await.unwrap())
+ .unwrap()
+ .add_member(dave_client.generate_key_package_message().await.unwrap())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ alice.apply_pending_commit().await.unwrap();
+
+ let mut bob = bob_client
+ .join_group(None, &commit.welcome_messages[0])
+ .await
+ .unwrap()
+ .0;
+
+ bob.write_to_storage().await.unwrap();
+
+ // Bob reloads his group data, but with parameters that will cause his generated leaves to
+ // not support the mandatory extension.
+ let mut bob = TestClientBuilder::new_for_test()
+ .signing_identity(bob_signing_identity, bob_secret_key, TEST_CIPHER_SUITE)
+ .key_package_repo(bob.config.key_package_repo())
+ .group_state_storage(bob.config.group_state_storage())
+ .build()
+ .load_group(alice.group_id())
+ .await
+ .unwrap();
+
+ let mut carol = carol_client
+ .join_group(None, &commit.welcome_messages[0])
+ .await
+ .unwrap()
+ .0;
+
+ let mut dave = dave_client
+ .join_group(None, &commit.welcome_messages[0])
+ .await
+ .unwrap()
+ .0;
+
+ // Bob's updated leaf does not support the mandatory extension.
+ let bob_update = bob.propose_update(Vec::new()).await.unwrap();
+ let carol_update = carol.propose_update(Vec::new()).await.unwrap();
+ let dave_update = dave.propose_update(Vec::new()).await.unwrap();
+
+ // Alice receives the update proposals to be committed.
+ alice.process_incoming_message(bob_update).await.unwrap();
+ alice.process_incoming_message(carol_update).await.unwrap();
+ alice.process_incoming_message(dave_update).await.unwrap();
+
+ // Alice commits the update proposals.
+ alice.commit(Vec::new()).await.unwrap();
+ let commit_desc = alice.apply_pending_commit().await.unwrap();
+
+ // True if the commit updated a member whose prior basic identifier is `id`.
+ let find_update_for = |id: &str| {
+ commit_desc
+ .state_update
+ .roster_update
+ .updated()
+ .iter()
+ .filter_map(|u| u.prior.signing_identity.credential.as_basic())
+ .any(|c| c.identifier == id.as_bytes())
+ };
+
+ // Check that all updates preserve identities.
+ let identities_are_preserved = commit_desc
+ .state_update
+ .roster_update
+ .updated()
+ .iter()
+ .filter_map(|u| {
+ let before = &u.prior.signing_identity.credential.as_basic()?.identifier;
+ let after = &u.new.signing_identity.credential.as_basic()?.identifier;
+ Some((before, after))
+ })
+ .all(|(before, after)| before == after);
+
+ assert!(identities_are_preserved);
+
+ // Carol's and Dave's updates should be part of the commit.
+ assert!(find_update_for("carol"));
+ assert!(find_update_for("dave"));
+
+ // Bob's update should be rejected.
+ assert!(!find_update_for("bob"));
+
+ // Check that all members are still in the group. The count check matters
+ // because `zip` below stops at the shorter side and would otherwise pass
+ // vacuously if members were missing.
+ assert_eq!(alice.roster().members_iter().count(), 4);
+
+ let all_members_are_in = alice
+ .roster()
+ .members_iter()
+ .zip(["alice", "bob", "carol", "dave"])
+ .all(|(member, id)| {
+ member
+ .signing_identity
+ .credential
+ .as_basic()
+ .unwrap()
+ .identifier
+ == id.as_bytes()
+ });
+
+ assert!(all_members_are_in);
+ }
+
+ /// A commit containing the custom proposal rotates the committer's leaf key
+ /// when the rules require a path for it.
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_may_enforce_path() {
+ test_custom_proposal_mls_rules(true).await;
+ }
+
+ /// A commit containing the custom proposal keeps the committer's leaf key
+ /// when the rules do not require a path for it.
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_need_not_enforce_path() {
+ test_custom_proposal_mls_rules(false).await;
+ }
+
+ /// Alice commits the test custom proposal together with an add of Bob.
+ /// Asserts that her leaf public key changes exactly when
+ /// `path_required_for_custom` is set.
+ #[cfg(feature = "custom_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_custom_proposal_mls_rules(path_required_for_custom: bool) {
+ let rules = CustomMlsRules {
+ path_required_for_custom,
+ external_joiner_can_send_custom: true,
+ };
+
+ let mut alice = client_with_custom_rules(b"alice", rules.clone())
+ .await
+ .create_group(Default::default())
+ .await
+ .unwrap();
+
+ let key_before = alice.current_user_leaf_node().unwrap().public_key.clone();
+
+ let bob_key_package = client_with_custom_rules(b"bob", rules)
+ .await
+ .generate_key_package_message()
+ .await
+ .unwrap();
+
+ alice
+ .commit_builder()
+ .custom_proposal(CustomProposal::new(TEST_CUSTOM_PROPOSAL_TYPE, vec![]))
+ .add_member(bob_key_package)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ alice.apply_pending_commit().await.unwrap();
+
+ let key_changed = alice.current_user_leaf_node().unwrap().public_key != key_before;
+ assert_eq!(key_changed, path_required_for_custom);
+ }
+
+ /// An external joiner may include a custom proposal by value when the rules allow it.
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_by_value_in_external_join_may_be_allowed() {
+ test_custom_proposal_by_value_in_external_join(true).await
+ }
+
+ /// An external joiner's by-value custom proposal is refused when the rules forbid it.
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_by_value_in_external_join_may_not_be_allowed() {
+ test_custom_proposal_by_value_in_external_join(false).await
+ }
+
+ /// Bob builds an external commit carrying a by-value custom proposal.
+ /// If the rules forbid it, building must fail with `MlsRulesError`.
+ /// Otherwise Alice must accept the resulting commit.
+ #[cfg(feature = "custom_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_custom_proposal_by_value_in_external_join(external_joiner_can_send_custom: bool) {
+ let rules = CustomMlsRules {
+ path_required_for_custom: true,
+ external_joiner_can_send_custom,
+ };
+
+ let mut alice = client_with_custom_rules(b"alice", rules.clone())
+ .await
+ .create_group(Default::default())
+ .await
+ .unwrap();
+
+ let group_info = alice
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap();
+
+ let outcome = client_with_custom_rules(b"bob", rules)
+ .await
+ .external_commit_builder()
+ .unwrap()
+ .with_custom_proposal(CustomProposal::new(TEST_CUSTOM_PROPOSAL_TYPE, vec![]))
+ .build(group_info)
+ .await;
+
+ if !external_joiner_can_send_custom {
+ assert_matches!(outcome.map(|_| ()), Err(MlsError::MlsRulesError(_)));
+ return;
+ }
+
+ let (_, external_commit) = outcome.unwrap();
+ alice.process_incoming_message(external_commit).await.unwrap();
+ }
+
+ /// An external joiner can commit a custom proposal that an existing member
+ /// previously sent by reference.
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_by_ref_in_external_join() {
+ let rules = CustomMlsRules {
+ path_required_for_custom: true,
+ external_joiner_can_send_custom: true,
+ };
+
+ let mut alice = client_with_custom_rules(b"alice", rules.clone())
+ .await
+ .create_group(Default::default())
+ .await
+ .unwrap();
+
+ // Alice proposes the custom proposal before exporting group info.
+ let proposal_message = alice
+ .propose_custom(CustomProposal::new(TEST_CUSTOM_PROPOSAL_TYPE, vec![]), vec![])
+ .await
+ .unwrap();
+
+ let group_info = alice
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap();
+
+ let external_commit = client_with_custom_rules(b"bob", rules)
+ .await
+ .external_commit_builder()
+ .unwrap()
+ .with_received_custom_proposal(proposal_message)
+ .build(group_info)
+ .await
+ .unwrap()
+ .1;
+
+ alice.process_incoming_message(external_commit).await.unwrap();
+ }
+
+ /// Builds a test client named `name` with these settings:
+ /// - signs with a test identity;
+ /// - accepts `TEST_CUSTOM_PROPOSAL_TYPE`;
+ /// - applies the supplied `mls_rules`.
+ #[cfg(feature = "custom_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn client_with_custom_rules(
+ name: &[u8],
+ mls_rules: CustomMlsRules,
+ ) -> Client<impl MlsConfig> {
+ let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, name).await;
+ let identity_provider = BasicWithCustomProvider::new(BasicIdentityProvider::new());
+
+ ClientBuilder::new()
+ .crypto_provider(TestCryptoProvider::new())
+ .identity_provider(identity_provider)
+ .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
+ .custom_proposal_type(TEST_CUSTOM_PROPOSAL_TYPE)
+ .mls_rules(mls_rules)
+ .build()
+ }
+
+ /// Test [`MlsRules`](crate::MlsRules) whose decisions depend on whether a
+ /// proposal bundle contains `TEST_CUSTOM_PROPOSAL_TYPE`.
+ ///
+ /// Gated on `custom_proposal` like every item that uses it, so builds without
+ /// that feature do not warn about a never-constructed struct.
+ #[cfg(feature = "custom_proposal")]
+ #[derive(Debug, Clone)]
+ struct CustomMlsRules {
+ /// Whether a commit carrying the test custom proposal must include a path.
+ path_required_for_custom: bool,
+ /// Whether a new member committing externally may include the test
+ /// custom proposal.
+ external_joiner_can_send_custom: bool,
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ impl ProposalBundle {
+ /// Returns `true` if any custom proposal in the bundle has type
+ /// `TEST_CUSTOM_PROPOSAL_TYPE`.
+ fn has_test_custom_proposal(&self) -> bool {
+ self.custom_proposal_types()
+ .any(|t| t == TEST_CUSTOM_PROPOSAL_TYPE)
+ }
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl crate::MlsRules for CustomMlsRules {
+ type Error = MlsError;
+
+ /// A path is required unless the bundle contains the test custom
+ /// proposal and `path_required_for_custom` is `false`.
+ fn commit_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ proposals: &ProposalBundle,
+ ) -> Result<CommitOptions, MlsError> {
+ Ok(CommitOptions::default().with_path_required(
+ !proposals.has_test_custom_proposal() || self.path_required_for_custom,
+ ))
+ }
+
+ /// Uses the default encryption options.
+ fn encryption_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ ) -> Result<crate::mls_rules::EncryptionOptions, MlsError> {
+ Ok(Default::default())
+ }
+
+ /// Fails with `MlsError::InvalidSender` in one case only: a
+ /// `CommitSource::NewMember` sender includes the test custom proposal
+ /// while `external_joiner_can_send_custom` is `false`. Otherwise the
+ /// bundle is returned unchanged.
+ async fn filter_proposals(
+ &self,
+ _: CommitDirection,
+ sender: CommitSource,
+ _: &Roster,
+ _: &ExtensionList,
+ proposals: ProposalBundle,
+ ) -> Result<ProposalBundle, MlsError> {
+ let is_external = matches!(sender, CommitSource::NewMember(_));
+ let has_custom = proposals.has_test_custom_proposal();
+ let allowed = !has_custom || !is_external || self.external_joiner_can_send_custom;
+
+ allowed.then_some(proposals).ok_or(MlsError::InvalidSender)
+ }
+ }
+
+ /// A group can process its own commit message. The resulting description
+ /// names the group's own leaf index as the committer.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn group_can_receive_commit_from_self() {
+ let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE)
+ .await
+ .group;
+
+ let commit_message = group.commit(vec![]).await.unwrap().commit_message;
+
+ let received = group
+ .process_incoming_message(commit_message)
+ .await
+ .unwrap();
+
+ match received {
+ ReceivedMessage::Commit(description) => {
+ assert_eq!(description.committer, *group.private_tree.self_index)
+ }
+ _ => panic!("expected commit message"),
+ }
+ }
+}
diff --git a/src/group/padding.rs b/src/group/padding.rs
new file mode 100644
index 0000000..6320ccf
--- /dev/null
+++ b/src/group/padding.rs
@@ -0,0 +1,109 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+/// Padding used when sending an encrypted group message.
+///
+/// The `Default` value is [`PaddingMode::StepFunction`] (see `#[default]`).
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
+// With `#[repr(u8)]`, discriminants follow declaration order
+// (`StepFunction` = 0, `None` = 1). Reordering the variants changes them.
+#[repr(u8)]
+pub enum PaddingMode {
+ /// Step function based on the size of the message being sent.
+ /// The amount of padding used will increase with the size of the original
+ /// message. The exact rounding rule is implemented by `padded_size`.
+ #[default]
+ StepFunction,
+ /// No padding.
+ None,
+}
+
+impl PaddingMode {
+ /// Returns the size, in bytes, that `content_size` bytes of content are
+ /// padded to under this mode.
+ ///
+ /// * [`PaddingMode::None`] returns `content_size` unchanged.
+ /// * [`PaddingMode::StepFunction`] rounds up to the next multiple of a bucket
+ /// width that grows with the content size. The result is always strictly
+ /// greater than `content_size`; for example 0 -> 32, 64 -> 96, 260 -> 320.
+ ///
+ /// NOTE(review): `content_size + 1` and `next_power_of_two` overflow for
+ /// sizes near `usize::MAX`, which panics in debug builds.
+ pub(super) fn padded_size(&self, content_size: usize) -> usize {
+ match self {
+ PaddingMode::StepFunction => {
+ // Let 2^k = max(256, next_power_of_two(content_size + 1)), so that
+ // content_size < 2^k. The bucket width `blind` is 2^(k-3).
+ // Each range [2^(k-1), 2^k) is therefore split into four equal
+ // buckets, revealing only the leading one bit of `content_size` and
+ // the two bits after it. Sizes below 256 use a fixed width of 32.
+ // Setting the low bits to ones and adding one rounds up to the next
+ // multiple of the width.
+ let blind = 1
+ << ((content_size + 1)
+ .next_power_of_two()
+ .max(256)
+ .trailing_zeros()
+ - 3);
+
+ (content_size | (blind - 1)) + 1
+ }
+ PaddingMode::None => content_size,
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::PaddingMode;
+
+ use alloc::vec;
+ use alloc::vec::Vec;
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+
+ /// One `(input, output)` entry of the `message_padding_test_vector` vector.
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct TestCase {
+ input: usize,
+ output: usize,
+ }
+
+ /// Computes the `StepFunction` padded size for every input in `1..1024`.
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_message_padding_test_vector() -> Vec<TestCase> {
+ let mut cases = vec![];
+ cases.extend((1..1024).map(|input| TestCase {
+ input,
+ output: PaddingMode::StepFunction.padded_size(input),
+ }));
+ cases
+ }
+
+ fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(
+ message_padding_test_vector,
+ generate_message_padding_test_vector()
+ )
+ }
+
+ #[test]
+ fn test_no_padding() {
+ for size in [0, 100, 1000, 10000] {
+ assert_eq!(PaddingMode::None.padded_size(size), size)
+ }
+ }
+
+ #[test]
+ fn test_padding_length() {
+ let spot_checks = [
+ (0, 32),
+ // Short
+ (63, 64),
+ (64, 96),
+ (65, 96),
+ // Almost long and almost short
+ (127, 128),
+ (128, 160),
+ (129, 160),
+ // One length from each of the 4 buckets between 256 and 512
+ (260, 320),
+ (330, 384),
+ (390, 448),
+ (490, 512),
+ ];
+
+ for (input, expected) in spot_checks {
+ assert_eq!(PaddingMode::StepFunction.padded_size(input), expected);
+ }
+
+ // All test cases
+ for case in load_test_cases() {
+ assert_eq!(
+ case.output,
+ PaddingMode::StepFunction.padded_size(case.input)
+ );
+ }
+ }
+}
diff --git a/src/group/proposal.rs b/src/group/proposal.rs
new file mode 100644
index 0000000..a31be29
--- /dev/null
+++ b/src/group/proposal.rs
@@ -0,0 +1,578 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::{boxed::Box, vec::Vec};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::tree_kem::leaf_node::LeafNode;
+
+use crate::{
+ client::MlsError, tree_kem::node::LeafIndex, CipherSuite, KeyPackage, MlsMessage,
+ ProtocolVersion,
+};
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{group::Capabilities, identity::SigningIdentity};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::group::proposal_ref::ProposalRef;
+
+pub use mls_rs_core::extension::ExtensionList;
+pub use mls_rs_core::group::ProposalType;
+
+#[cfg(feature = "psk")]
+use crate::psk::{ExternalPskId, JustPreSharedKeyID, PreSharedKeyID};
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal that adds a member to a [`Group`](crate::group::Group).
+pub struct AddProposal {
+ pub(crate) key_package: KeyPackage,
+}
+
+impl AddProposal {
+ /// The [`SigningIdentity`] of the [`Member`](mls_rs_core::group::Member)
+ /// that will be added by this proposal.
+ pub fn signing_identity(&self) -> &SigningIdentity {
+ self.key_package.signing_identity()
+ }
+
+ /// Client [`Capabilities`] of the [`Member`](mls_rs_core::group::Member)
+ /// that will be added by this proposal.
+ pub fn capabilities(&self) -> Capabilities {
+ self.key_package.leaf_node.ungreased_capabilities()
+ }
+
+ /// Key package extensions that are associated with the
+ /// [`Member`](mls_rs_core::group::Member) that will be added by this proposal.
+ pub fn key_package_extensions(&self) -> ExtensionList {
+ self.key_package.ungreased_extensions()
+ }
+
+ /// Leaf node extensions that will be entered into the group state for the
+ /// [`Member`](mls_rs_core::group::Member) that will be added.
+ pub fn leaf_node_extensions(&self) -> ExtensionList {
+ self.key_package.leaf_node.ungreased_extensions()
+ }
+}
+
+impl From<KeyPackage> for AddProposal {
+ /// Wraps `key_package` in an add proposal.
+ fn from(key_package: KeyPackage) -> Self {
+ Self { key_package }
+ }
+}
+
+impl TryFrom<MlsMessage> for AddProposal {
+ type Error = MlsError;
+
+ /// Builds an add proposal from the key package carried by `value`.
+ ///
+ /// # Errors
+ /// Returns [`MlsError::UnexpectedMessageType`] if `value` does not contain
+ /// a key package.
+ fn try_from(value: MlsMessage) -> Result<Self, Self::Error> {
+ value
+ .into_key_package()
+ .ok_or(MlsError::UnexpectedMessageType)
+ .map(Into::into)
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal that will update an existing [`Member`](mls_rs_core::group::Member) of a
+/// [`Group`](crate::group::Group).
+pub struct UpdateProposal {
+ pub(crate) leaf_node: LeafNode,
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl UpdateProposal {
+ /// Signing identity the updated [`Member`](mls_rs_core::group::Member)
+ /// will use once this proposal is applied.
+ pub fn signing_identity(&self) -> &SigningIdentity {
+ let Self { leaf_node } = self;
+ &leaf_node.signing_identity
+ }
+
+ /// Client [`Capabilities`] advertised by the proposed leaf node of the
+ /// [`Member`](mls_rs_core::group::Member) being updated.
+ pub fn capabilities(&self) -> Capabilities {
+ let Self { leaf_node } = self;
+ leaf_node.ungreased_capabilities()
+ }
+
+ /// Leaf node extensions that the updated
+ /// [`Member`](mls_rs_core::group::Member) will have in the group state.
+ pub fn leaf_node_extensions(&self) -> ExtensionList {
+ let Self { leaf_node } = self;
+ leaf_node.ungreased_extensions()
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal to remove an existing [`Member`](mls_rs_core::group::Member) of a
+/// [`Group`](crate::group::Group).
+pub struct RemoveProposal {
+ pub(crate) to_remove: LeafIndex,
+}
+
+impl RemoveProposal {
+ /// Leaf index of the [`Member`](mls_rs_core::group::Member) this proposal
+ /// removes, as a plain `u32`.
+ pub fn to_remove(&self) -> u32 {
+ let index: u32 = *self.to_remove;
+ index
+ }
+}
+
+impl From<u32> for RemoveProposal {
+ /// Builds a proposal that removes the member at leaf index `value`.
+ fn from(value: u32) -> Self {
+ let to_remove = LeafIndex(value);
+ Self { to_remove }
+ }
+}
+
+#[cfg(feature = "psk")]
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal to add a pre-shared key to a group.
+pub struct PreSharedKeyProposal {
+ pub(crate) psk: PreSharedKeyID,
+}
+
+#[cfg(feature = "psk")]
+impl PreSharedKeyProposal {
+ /// The external pre-shared key id referenced by this proposal.
+ ///
+ /// MLS requires a PreSharedKeyProposal to reference an `External` key.
+ ///
+ /// Returns `None` when the underlying psk is a resumption psk instead.
+ pub fn external_psk_id(&self) -> Option<&ExternalPskId> {
+ match &self.psk.key_id {
+ JustPreSharedKeyID::External(ext) => Some(ext),
+ JustPreSharedKeyID::Resumption(_) => None,
+ }
+ }
+}
+
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal to reinitialize a group using new parameters.
+pub struct ReInitProposal {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub(crate) group_id: Vec<u8>,
+ pub(crate) version: ProtocolVersion,
+ pub(crate) cipher_suite: CipherSuite,
+ pub(crate) extensions: ExtensionList,
+}
+
+impl Debug for ReInitProposal {
+ /// Formats the group id with `pretty_group_id` instead of raw bytes.
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let group_id = mls_rs_core::debug::pretty_group_id(&self.group_id);
+
+ f.debug_struct("ReInitProposal")
+ .field("group_id", &group_id)
+ .field("version", &self.version)
+ .field("cipher_suite", &self.cipher_suite)
+ .field("extensions", &self.extensions)
+ .finish()
+ }
+}
+
+impl ReInitProposal {
+ /// Identifier of the group created by the reinitialization.
+ pub fn group_id(&self) -> &[u8] {
+ self.group_id.as_slice()
+ }
+
+ /// Protocol version the reinitialized group will use.
+ pub fn new_version(&self) -> ProtocolVersion {
+ self.version
+ }
+
+ /// Ciphersuite the reinitialized group will use.
+ pub fn new_cipher_suite(&self) -> CipherSuite {
+ self.cipher_suite
+ }
+
+ /// Group context extensions for the reinitialized group.
+ pub fn new_group_context_extensions(&self) -> &ExtensionList {
+ &self.extensions
+ }
+}
+
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal used for external commits.
+pub struct ExternalInit {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub(crate) kem_output: Vec<u8>,
+}
+
+impl Debug for ExternalInit {
+ /// Formats `kem_output` with `pretty_bytes` instead of a raw byte list.
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let kem_output = mls_rs_core::debug::pretty_bytes(&self.kem_output);
+
+ f.debug_struct("ExternalInit")
+ .field("kem_output", &kem_output)
+ .finish()
+ }
+}
+
+#[cfg(feature = "custom_proposal")]
+#[derive(Clone, PartialEq, Eq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A user defined custom proposal.
+///
+/// User defined proposals are passed through the protocol as an opaque value.
+pub struct CustomProposal {
+ proposal_type: ProposalType,
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ data: Vec<u8>,
+}
+
+#[cfg(feature = "custom_proposal")]
+impl Debug for CustomProposal {
+ /// Formats the payload with `pretty_bytes` instead of a raw byte list.
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let data = mls_rs_core::debug::pretty_bytes(&self.data);
+
+ f.debug_struct("CustomProposal")
+ .field("proposal_type", &self.proposal_type)
+ .field("data", &data)
+ .finish()
+ }
+}
+
+#[cfg(feature = "custom_proposal")]
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl CustomProposal {
+ /// Create a custom proposal with the given type and opaque payload.
+ ///
+ /// # Warning
+ ///
+ /// Avoid using the [`ProposalType`] values that have constants already
+ /// defined by this crate. Using existing constants in a custom proposal
+ /// has unspecified behavior.
+ pub fn new(proposal_type: ProposalType, data: Vec<u8>) -> Self {
+ CustomProposal {
+ data,
+ proposal_type,
+ }
+ }
+
+ /// Proposal type tagging this custom proposal.
+ pub fn proposal_type(&self) -> ProposalType {
+ self.proposal_type
+ }
+
+ /// Opaque payload carried by this custom proposal.
+ pub fn data(&self) -> &[u8] {
+ self.data.as_slice()
+ }
+}
+
+/// Trait to simplify creating custom proposals that are serialized with MLS
+/// encoding.
+///
+/// Implementors supply a [`ProposalType`]. The trait converts between `Self`
+/// and a [`CustomProposal`] whose data is the MLS encoding of `Self`.
+#[cfg(feature = "custom_proposal")]
+pub trait MlsCustomProposal: MlsSize + MlsEncode + MlsDecode + Sized {
+ /// Proposal type used to tag proposals produced from this type.
+ fn proposal_type() -> ProposalType;
+
+ /// MLS-encodes `self` into the data of a new [`CustomProposal`].
+ ///
+ /// # Errors
+ /// Propagates any error from [`MlsEncode::mls_encode_to_vec`].
+ fn to_custom_proposal(&self) -> Result<CustomProposal, mls_rs_codec::Error> {
+ Ok(CustomProposal::new(
+ Self::proposal_type(),
+ self.mls_encode_to_vec()?,
+ ))
+ }
+
+ /// Decodes `Self` from the data of `proposal`.
+ ///
+ /// # Errors
+ /// Returns `mls_rs_codec::Error::Custom(4)` if `proposal` carries a
+ /// proposal type other than [`Self::proposal_type`]. Otherwise it
+ /// propagates decoding errors.
+ ///
+ /// NOTE(review): bytes left over after decoding `Self` are not checked here.
+ fn from_custom_proposal(proposal: &CustomProposal) -> Result<Self, mls_rs_codec::Error> {
+ if proposal.proposal_type() != Self::proposal_type() {
+ // Error code 4: the proposal type does not match this implementor.
+ return Err(mls_rs_codec::Error::Custom(4));
+ }
+
+ Self::mls_decode(&mut proposal.data())
+ }
+}
+
+// `Add` is already boxed. The lint is presumably still allowed because some
+// unboxed variants (e.g. `Update`, `ReInit`) are much larger than `Remove`.
+// TODO confirm.
+#[allow(clippy::large_enum_variant)]
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u16)]
+#[non_exhaustive]
+/// An enum that represents all possible types of proposals.
+///
+/// On the MLS wire each proposal is tagged with its [`ProposalType`] (see
+/// `Proposal::proposal_type`), not with the `repr(u16)` discriminant.
+pub enum Proposal {
+ /// Adds a member to the group (see [`AddProposal`]).
+ Add(alloc::boxed::Box<AddProposal>),
+ /// Updates an existing member (see [`UpdateProposal`]).
+ #[cfg(feature = "by_ref_proposal")]
+ Update(UpdateProposal),
+ /// Removes an existing member (see [`RemoveProposal`]).
+ Remove(RemoveProposal),
+ /// Adds a pre-shared key to the group (see [`PreSharedKeyProposal`]).
+ #[cfg(feature = "psk")]
+ Psk(PreSharedKeyProposal),
+ /// Reinitializes the group with new parameters (see [`ReInitProposal`]).
+ ReInit(ReInitProposal),
+ /// Used for external commits (see [`ExternalInit`]).
+ ExternalInit(ExternalInit),
+ /// Carries a group context [`ExtensionList`].
+ GroupContextExtensions(ExtensionList),
+ /// A user defined proposal with an opaque payload (see [`CustomProposal`]).
+ #[cfg(feature = "custom_proposal")]
+ Custom(CustomProposal),
+}
+
+impl MlsSize for Proposal {
+ /// Length of the proposal type tag plus the length of the variant body.
+ /// Custom proposals are sized as a length-prefixed byte vector.
+ fn mls_encoded_len(&self) -> usize {
+ let type_len = self.proposal_type().mls_encoded_len();
+
+ let body_len = match self {
+ Proposal::Add(add) => add.mls_encoded_len(),
+ #[cfg(feature = "by_ref_proposal")]
+ Proposal::Update(update) => update.mls_encoded_len(),
+ Proposal::Remove(remove) => remove.mls_encoded_len(),
+ #[cfg(feature = "psk")]
+ Proposal::Psk(psk) => psk.mls_encoded_len(),
+ Proposal::ReInit(reinit) => reinit.mls_encoded_len(),
+ Proposal::ExternalInit(init) => init.mls_encoded_len(),
+ Proposal::GroupContextExtensions(extensions) => extensions.mls_encoded_len(),
+ #[cfg(feature = "custom_proposal")]
+ Proposal::Custom(custom) => mls_rs_codec::byte_vec::mls_encoded_len(&custom.data),
+ };
+
+ type_len + body_len
+ }
+}
+
+/// Rejects a [`Proposal::Custom`] whose proposal type is in `0..=7`. Those
+/// values may not be used by custom proposals. Error code 2 identifies this
+/// failure.
+#[cfg(feature = "custom_proposal")]
+fn ensure_proposal_encodable(proposal: &Proposal) -> Result<(), mls_rs_codec::Error> {
+ match proposal {
+ Proposal::Custom(p) if p.proposal_type.raw_value() <= 7 => {
+ Err(mls_rs_codec::Error::Custom(2))
+ }
+ _ => Ok(()),
+ }
+}
+
+/// Without custom proposals every variant is encodable.
+#[cfg(not(feature = "custom_proposal"))]
+fn ensure_proposal_encodable(_: &Proposal) -> Result<(), mls_rs_codec::Error> {
+ Ok(())
+}
+
+impl MlsEncode for Proposal {
+ /// Writes the [`ProposalType`] tag followed by the variant's body.
+ /// Custom proposals are written as a length-prefixed byte vector.
+ ///
+ /// # Errors
+ /// * `mls_rs_codec::Error::Custom(2)` if a custom proposal uses a type in
+ /// `0..=7`. This check runs before anything is written, so this error
+ /// leaves `writer` unchanged.
+ /// * Any error from encoding the tag or the body.
+ fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+ ensure_proposal_encodable(self)?;
+ self.proposal_type().mls_encode(writer)?;
+
+ match self {
+ Proposal::Add(p) => p.mls_encode(writer),
+ #[cfg(feature = "by_ref_proposal")]
+ Proposal::Update(p) => p.mls_encode(writer),
+ Proposal::Remove(p) => p.mls_encode(writer),
+ #[cfg(feature = "psk")]
+ Proposal::Psk(p) => p.mls_encode(writer),
+ Proposal::ReInit(p) => p.mls_encode(writer),
+ Proposal::ExternalInit(p) => p.mls_encode(writer),
+ Proposal::GroupContextExtensions(p) => p.mls_encode(writer),
+ #[cfg(feature = "custom_proposal")]
+ Proposal::Custom(p) => mls_rs_codec::byte_vec::mls_encode(&p.data, writer),
+ }
+ }
+}
+
+impl MlsDecode for Proposal {
+    /// Reads a `ProposalType` tag, then decodes the payload type that tag selects.
+    ///
+    /// A tag that names no compiled-in variant becomes a custom proposal when the
+    /// `custom_proposal` feature is enabled, and `mls_rs_codec::Error::Custom(3)` otherwise.
+    fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
+        let tag = ProposalType::mls_decode(reader)?;
+
+        let decoded = match tag {
+            ProposalType::ADD => {
+                let add = AddProposal::mls_decode(reader)?;
+                Self::Add(alloc::boxed::Box::new(add))
+            }
+            #[cfg(feature = "by_ref_proposal")]
+            ProposalType::UPDATE => Self::Update(UpdateProposal::mls_decode(reader)?),
+            ProposalType::REMOVE => Self::Remove(RemoveProposal::mls_decode(reader)?),
+            #[cfg(feature = "psk")]
+            ProposalType::PSK => Self::Psk(PreSharedKeyProposal::mls_decode(reader)?),
+            ProposalType::RE_INIT => Self::ReInit(ReInitProposal::mls_decode(reader)?),
+            ProposalType::EXTERNAL_INIT => Self::ExternalInit(ExternalInit::mls_decode(reader)?),
+            ProposalType::GROUP_CONTEXT_EXTENSIONS => {
+                Self::GroupContextExtensions(ExtensionList::mls_decode(reader)?)
+            }
+            #[cfg(feature = "custom_proposal")]
+            other => Self::Custom(CustomProposal {
+                proposal_type: other,
+                data: mls_rs_codec::byte_vec::mls_decode(reader)?,
+            }),
+            // TODO fix test dependency on openssl loading codec with default features
+            #[cfg(not(feature = "custom_proposal"))]
+            _ => return Err(mls_rs_codec::Error::Custom(3)),
+        };
+
+        Ok(decoded)
+    }
+}
+
+impl Proposal {
+    /// Returns the `ProposalType` written as this proposal's wire tag.
+    ///
+    /// Built-in variants map to their fixed constants; a custom proposal reports the type it
+    /// carries.
+    pub fn proposal_type(&self) -> ProposalType {
+        match self {
+            Self::Add(_) => ProposalType::ADD,
+            #[cfg(feature = "by_ref_proposal")]
+            Self::Update(_) => ProposalType::UPDATE,
+            Self::Remove(_) => ProposalType::REMOVE,
+            #[cfg(feature = "psk")]
+            Self::Psk(_) => ProposalType::PSK,
+            Self::ReInit(_) => ProposalType::RE_INIT,
+            Self::ExternalInit(_) => ProposalType::EXTERNAL_INIT,
+            Self::GroupContextExtensions(_) => ProposalType::GROUP_CONTEXT_EXTENSIONS,
+            #[cfg(feature = "custom_proposal")]
+            Self::Custom(CustomProposal { proposal_type, .. }) => *proposal_type,
+        }
+    }
+}
+
+#[derive(Clone, Debug, PartialEq)]
+/// An enum that represents a borrowed version of [`Proposal`].
+///
+/// Each variant holds a reference to the same payload type the matching `Proposal` variant
+/// owns. `Add` borrows the `AddProposal` directly rather than the `Box` that
+/// `Proposal::Add` stores.
+pub enum BorrowedProposal<'a> {
+ /// Borrowed payload of `Proposal::Add`.
+ Add(&'a AddProposal),
+ /// Borrowed payload of `Proposal::Update` (requires `by_ref_proposal`).
+ #[cfg(feature = "by_ref_proposal")]
+ Update(&'a UpdateProposal),
+ /// Borrowed payload of `Proposal::Remove`.
+ Remove(&'a RemoveProposal),
+ /// Borrowed payload of `Proposal::Psk` (requires `psk`).
+ #[cfg(feature = "psk")]
+ Psk(&'a PreSharedKeyProposal),
+ /// Borrowed payload of `Proposal::ReInit`.
+ ReInit(&'a ReInitProposal),
+ /// Borrowed payload of `Proposal::ExternalInit`.
+ ExternalInit(&'a ExternalInit),
+ /// Borrowed payload of `Proposal::GroupContextExtensions`.
+ GroupContextExtensions(&'a ExtensionList),
+ /// Borrowed payload of `Proposal::Custom` (requires `custom_proposal`).
+ #[cfg(feature = "custom_proposal")]
+ Custom(&'a CustomProposal),
+}
+
+impl<'a> From<BorrowedProposal<'a>> for Proposal {
+    /// Clones the borrowed payload into an owned `Proposal`.
+    ///
+    /// The `Add` payload is boxed to match `Proposal::Add`.
+    fn from(borrowed: BorrowedProposal<'a>) -> Self {
+        match borrowed {
+            BorrowedProposal::Add(inner) => Self::Add(alloc::boxed::Box::new(inner.clone())),
+            #[cfg(feature = "by_ref_proposal")]
+            BorrowedProposal::Update(inner) => Self::Update(inner.clone()),
+            BorrowedProposal::Remove(inner) => Self::Remove(inner.clone()),
+            #[cfg(feature = "psk")]
+            BorrowedProposal::Psk(inner) => Self::Psk(inner.clone()),
+            BorrowedProposal::ReInit(inner) => Self::ReInit(inner.clone()),
+            BorrowedProposal::ExternalInit(inner) => Self::ExternalInit(inner.clone()),
+            BorrowedProposal::GroupContextExtensions(inner) => {
+                Self::GroupContextExtensions(inner.clone())
+            }
+            #[cfg(feature = "custom_proposal")]
+            BorrowedProposal::Custom(inner) => Self::Custom(inner.clone()),
+        }
+    }
+}
+
+impl BorrowedProposal<'_> {
+    /// Returns the `ProposalType` of the borrowed proposal.
+    ///
+    /// This is the same value `Proposal::proposal_type` would report for the owned form.
+    pub fn proposal_type(&self) -> ProposalType {
+        match self {
+            Self::Add(_) => ProposalType::ADD,
+            #[cfg(feature = "by_ref_proposal")]
+            Self::Update(_) => ProposalType::UPDATE,
+            Self::Remove(_) => ProposalType::REMOVE,
+            #[cfg(feature = "psk")]
+            Self::Psk(_) => ProposalType::PSK,
+            Self::ReInit(_) => ProposalType::RE_INIT,
+            Self::ExternalInit(_) => ProposalType::EXTERNAL_INIT,
+            Self::GroupContextExtensions(_) => ProposalType::GROUP_CONTEXT_EXTENSIONS,
+            #[cfg(feature = "custom_proposal")]
+            Self::Custom(CustomProposal { proposal_type, .. }) => *proposal_type,
+        }
+    }
+}
+
+impl<'a> From<&'a Proposal> for BorrowedProposal<'a> {
+    /// Borrows the payload of `proposal` without cloning it.
+    ///
+    /// The boxed `Add` payload is dereferenced so the borrow points at the `AddProposal`
+    /// itself.
+    fn from(proposal: &'a Proposal) -> Self {
+        match proposal {
+            Proposal::Add(add) => Self::Add(&**add),
+            #[cfg(feature = "by_ref_proposal")]
+            Proposal::Update(update) => Self::Update(update),
+            Proposal::Remove(remove) => Self::Remove(remove),
+            #[cfg(feature = "psk")]
+            Proposal::Psk(psk) => Self::Psk(psk),
+            Proposal::ReInit(reinit) => Self::ReInit(reinit),
+            Proposal::ExternalInit(init) => Self::ExternalInit(init),
+            Proposal::GroupContextExtensions(extensions) => Self::GroupContextExtensions(extensions),
+            #[cfg(feature = "custom_proposal")]
+            Proposal::Custom(custom) => Self::Custom(custom),
+        }
+    }
+}
+
+// Generates `impl<'a> From<&'a Payload> for BorrowedProposal<'a>` for each
+// `Payload => Variant;` entry, wrapping the reference in `BorrowedProposal::Variant`.
+// Attributes written before an entry (e.g. `#[cfg(...)]`) are copied onto its impl.
+macro_rules! impl_from_payload_ref_for_borrowed_proposal {
+    ($($(#[$attr:meta])* $payload:ty => $variant:ident;)*) => {
+        $(
+            $(#[$attr])*
+            impl<'a> From<&'a $payload> for BorrowedProposal<'a> {
+                fn from(payload: &'a $payload) -> Self {
+                    Self::$variant(payload)
+                }
+            }
+        )*
+    };
+}
+
+impl_from_payload_ref_for_borrowed_proposal! {
+    AddProposal => Add;
+    #[cfg(feature = "by_ref_proposal")]
+    UpdateProposal => Update;
+    RemoveProposal => Remove;
+    #[cfg(feature = "psk")]
+    PreSharedKeyProposal => Psk;
+    ReInitProposal => ReInit;
+    ExternalInit => ExternalInit;
+    ExtensionList => GroupContextExtensions;
+    #[cfg(feature = "custom_proposal")]
+    CustomProposal => Custom;
+}
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+/// One entry of a commit's proposal list, as consumed by `resolve_for_commit`: either a
+/// proposal included by value or a reference to a cached proposal.
+///
+/// NOTE(review): the explicit `u8` discriminants (1 and 2) are presumably the tag written
+/// by the derived `MlsEncode` — confirm against the codec derive.
+pub(crate) enum ProposalOrRef {
+ /// A proposal included in full (boxed).
+ Proposal(Box<Proposal>) = 1u8,
+ /// A reference to a proposal that must be found in the proposal cache.
+ #[cfg(feature = "by_ref_proposal")]
+ Reference(ProposalRef) = 2u8,
+}
+
+impl From<Proposal> for ProposalOrRef {
+    /// Wraps `proposal` so it is carried by value.
+    fn from(proposal: Proposal) -> Self {
+        let boxed = Box::new(proposal);
+        ProposalOrRef::Proposal(boxed)
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl From<ProposalRef> for ProposalOrRef {
+    /// Wraps `reference` so the proposal is carried by reference.
+    fn from(reference: ProposalRef) -> Self {
+        ProposalOrRef::Reference(reference)
+    }
+}
diff --git a/src/group/proposal_cache.rs b/src/group/proposal_cache.rs
new file mode 100644
index 0000000..17acf79
--- /dev/null
+++ b/src/group/proposal_cache.rs
@@ -0,0 +1,4216 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+
+use super::{
+ message_processor::ProvisionalState,
+ mls_rules::{CommitDirection, CommitSource, MlsRules},
+ GroupState, ProposalOrRef,
+};
+use crate::{
+ client::MlsError,
+ group::{
+ proposal_filter::{ProposalApplier, ProposalBundle, ProposalSource},
+ Proposal, Sender,
+ },
+ time::MlsTime,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::group::{proposal_filter::FilterStrategy, ProposalRef, ProtocolVersion};
+
+use crate::tree_kem::leaf_node::LeafNode;
+
+#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+use std::collections::HashMap;
+
+#[cfg(feature = "by_ref_proposal")]
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+use mls_rs_core::{
+ crypto::CipherSuiteProvider, error::IntoAnyError, identity::IdentityProvider,
+ psk::PreSharedKeyStorage,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use core::fmt::{self, Debug};
+
+#[cfg(feature = "by_ref_proposal")]
+#[derive(Debug, Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal stored in the proposal cache together with the sender that proposed it.
+pub struct CachedProposal {
+ /// The cached proposal.
+ pub(crate) proposal: Proposal,
+ /// The proposal's sender; carried into the bundle when the proposal is committed by
+ /// reference.
+ pub(crate) sender: Sender,
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[derive(Clone, PartialEq)]
+/// Proposals available to be committed by reference, keyed by their `ProposalRef`.
+pub(crate) struct ProposalCache {
+ /// Protocol version this cache was created for.
+ protocol_version: ProtocolVersion,
+ /// Id of the group this cache belongs to (printed via `pretty_group_id` in `Debug`).
+ group_id: Vec<u8>,
+ /// With `std`: a map from reference to cached proposal.
+ #[cfg(feature = "std")]
+ pub(crate) proposals: HashMap<ProposalRef, CachedProposal>,
+ /// Without `std`: a list of `(reference, proposal)` pairs, searched linearly.
+ #[cfg(not(feature = "std"))]
+ pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>,
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl Debug for ProposalCache {
+    // Written by hand so the group id goes through `pretty_group_id` rather than the raw
+    // `Vec<u8>` formatting a derive would produce.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let group_id = mls_rs_core::debug::pretty_group_id(&self.group_id);
+
+        let mut out = f.debug_struct("ProposalCache");
+        out.field("protocol_version", &self.protocol_version);
+        out.field("group_id", &group_id);
+        out.field("proposals", &self.proposals);
+        out.finish()
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl ProposalCache {
+    /// Creates an empty cache for `group_id` under `protocol_version`.
+    pub fn new(protocol_version: ProtocolVersion, group_id: Vec<u8>) -> Self {
+        Self {
+            protocol_version,
+            group_id,
+            proposals: Default::default(),
+        }
+    }
+
+    /// Creates a cache that starts out holding `proposals`.
+    pub fn import(
+        protocol_version: ProtocolVersion,
+        group_id: Vec<u8>,
+        #[cfg(feature = "std")] proposals: HashMap<ProposalRef, CachedProposal>,
+        #[cfg(not(feature = "std"))] proposals: Vec<(ProposalRef, CachedProposal)>,
+    ) -> Self {
+        Self {
+            protocol_version,
+            group_id,
+            proposals,
+        }
+    }
+
+    /// Removes every cached proposal.
+    #[inline]
+    pub fn clear(&mut self) {
+        self.proposals.clear();
+    }
+
+    /// Returns `true` when no proposals are cached.
+    #[cfg(feature = "private_message")]
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.proposals.is_empty()
+    }
+
+    /// Caches `proposal`, sent by `sender`, under `proposal_ref`.
+    ///
+    /// An existing entry with the same reference is replaced. Both the `std` and the
+    /// `no_std` build therefore hold at most one entry per reference.
+    pub fn insert(&mut self, proposal_ref: ProposalRef, proposal: Proposal, sender: Sender) {
+        let cached_proposal = CachedProposal { proposal, sender };
+
+        #[cfg(feature = "std")]
+        self.proposals.insert(proposal_ref, cached_proposal);
+
+        // Mirror `HashMap::insert`: overwrite a matching entry rather than appending a
+        // duplicate. `prepare_commit` walks every entry, so a duplicate would be committed
+        // twice by reference, and the list would otherwise grow without bound.
+        #[cfg(not(feature = "std"))]
+        {
+            match self.proposals.iter_mut().find(|(r, _)| r == &proposal_ref) {
+                Some(existing) => existing.1 = cached_proposal,
+                None => self.proposals.push((proposal_ref, cached_proposal)),
+            }
+        }
+    }
+
+    /// Builds the bundle for a commit authored by `sender`.
+    ///
+    /// Every cached proposal is included by reference with its original sender, followed
+    /// by `additional_proposals` included by value and attributed to `sender`.
+    pub fn prepare_commit(
+        &self,
+        sender: Sender,
+        additional_proposals: Vec<Proposal>,
+    ) -> ProposalBundle {
+        self.proposals
+            .iter()
+            .map(|(r, p)| {
+                (
+                    p.proposal.clone(),
+                    p.sender,
+                    ProposalSource::ByReference(r.clone()),
+                )
+            })
+            .chain(
+                additional_proposals
+                    .into_iter()
+                    .map(|p| (p, sender, ProposalSource::ByValue)),
+            )
+            .collect()
+    }
+
+    /// Turns a received commit's proposal list into a bundle.
+    ///
+    /// By-value entries are attributed to `sender`. By-reference entries are looked up in
+    /// the cache and keep the sender stored with them.
+    ///
+    /// # Errors
+    ///
+    /// Returns `MlsError::ProposalNotFound` if a reference is not in the cache.
+    pub fn resolve_for_commit(
+        &self,
+        sender: Sender,
+        proposal_list: Vec<ProposalOrRef>,
+    ) -> Result<ProposalBundle, MlsError> {
+        let mut proposals = ProposalBundle::default();
+
+        for p in proposal_list {
+            match p {
+                ProposalOrRef::Proposal(p) => proposals.add(*p, sender, ProposalSource::ByValue),
+                ProposalOrRef::Reference(r) => {
+                    #[cfg(feature = "std")]
+                    let cached = self.proposals.get(&r);
+                    #[cfg(not(feature = "std"))]
+                    let cached = self
+                        .proposals
+                        .iter()
+                        .find_map(|(rr, p)| (rr == &r).then_some(p));
+
+                    // Only the proposal needs cloning: the sender is copied and `r` is
+                    // moved into the source.
+                    let cached = cached.ok_or(MlsError::ProposalNotFound)?;
+                    proposals.add(
+                        cached.proposal.clone(),
+                        cached.sender,
+                        ProposalSource::ByReference(r),
+                    );
+                }
+            }
+        }
+
+        Ok(proposals)
+    }
+}
+
+#[cfg(not(feature = "by_ref_proposal"))]
+pub(crate) fn prepare_commit(
+    sender: Sender,
+    additional_proposals: Vec<Proposal>,
+) -> ProposalBundle {
+    // Without by-reference support there is no cache: the bundle is exactly the given
+    // proposals, in order, each included by value and attributed to `sender`.
+    let mut bundle = ProposalBundle::default();
+
+    additional_proposals
+        .into_iter()
+        .for_each(|proposal| bundle.add(proposal, sender, ProposalSource::ByValue));
+
+    bundle
+}
+
+#[cfg(not(feature = "by_ref_proposal"))]
+pub(crate) fn resolve_for_commit(
+    sender: Sender,
+    proposal_list: Vec<ProposalOrRef>,
+) -> Result<ProposalBundle, MlsError> {
+    // With `by_ref_proposal` disabled, `ProposalOrRef` has only the `Proposal` variant, so
+    // the pattern below is irrefutable and every entry is a by-value proposal.
+    let mut bundle = ProposalBundle::default();
+
+    for ProposalOrRef::Proposal(boxed) in proposal_list {
+        bundle.add(*boxed, sender, ProposalSource::ByValue);
+    }
+
+    Ok(bundle)
+}
+
+impl GroupState {
+ /// Runs the proposal pipeline for a commit from `sender` and returns the provisional
+ /// state of the next epoch.
+ ///
+ /// Steps:
+ /// 1. Map `sender` to a `CommitSource`.
+ /// 2. Let `user_rules` filter the bundle.
+ /// 3. Apply the filtered bundle with a `ProposalApplier`.
+ /// 4. Build a `ProvisionalState` whose group context has `epoch + 1` and any new
+ ///    context extensions.
+ ///
+ /// # Errors
+ ///
+ /// - `MlsError::InvalidSender` for `NewMemberProposal` and `External` senders.
+ /// - `MlsError::ExternalCommitMustHaveNewLeaf` for a `NewMemberCommit` with no
+ ///   `external_leaf`.
+ /// - `MlsError::MlsRulesError` wrapping any error from `user_rules.filter_proposals`.
+ /// - Errors from the roster lookup and from `apply_proposals` are propagated.
+ #[inline(never)]
+ #[allow(clippy::too_many_arguments)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn apply_resolved<C, F, P, CSP>(
+ &self,
+ sender: Sender,
+ mut proposals: ProposalBundle,
+ external_leaf: Option<&LeafNode>,
+ identity_provider: &C,
+ cipher_suite_provider: &CSP,
+ psk_storage: &P,
+ user_rules: &F,
+ commit_time: Option<MlsTime>,
+ direction: CommitDirection,
+ ) -> Result<ProvisionalState, MlsError>
+ where
+ C: IdentityProvider,
+ F: MlsRules,
+ P: PreSharedKeyStorage,
+ CSP: CipherSuiteProvider,
+ {
+ let roster = self.public_tree.roster();
+ let group_extensions = &self.context.extensions;
+
+ // Snapshot taken before user-rule filtering. On the send path, unused proposals are
+ // computed against this snapshot.
+ #[cfg(feature = "by_ref_proposal")]
+ let all_proposals = proposals.clone();
+
+ // Only existing members and new members committing externally may author a commit;
+ // an external commit must supply the joiner's new leaf.
+ let origin = match sender {
+ Sender::Member(index) => Ok::<_, MlsError>(CommitSource::ExistingMember(
+ roster.member_with_index(index)?,
+ )),
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::NewMemberProposal => Err(MlsError::InvalidSender),
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::External(_) => Err(MlsError::InvalidSender),
+ Sender::NewMemberCommit => Ok(CommitSource::NewMember(
+ external_leaf
+ .map(|l| l.signing_identity.clone())
+ .ok_or(MlsError::ExternalCommitMustHaveNewLeaf)?,
+ )),
+ }?;
+
+ proposals = user_rules
+ .filter_proposals(direction, origin, &roster, group_extensions, proposals)
+ .await
+ .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
+
+ let applier = ProposalApplier::new(
+ &self.public_tree,
+ self.context.protocol_version,
+ cipher_suite_provider,
+ group_extensions,
+ external_leaf,
+ identity_provider,
+ psk_storage,
+ #[cfg(feature = "by_ref_proposal")]
+ &self.context.group_id,
+ );
+
+ // Sending uses `FilterStrategy::IgnoreByRef`; receiving uses `FilterStrategy::IgnoreNone`.
+ // NOTE(review): presumably a sender may drop invalid by-reference proposals, while a
+ // receiver must reject the commit instead — confirm against `FilterStrategy`.
+ #[cfg(feature = "by_ref_proposal")]
+ let applier_output = match direction {
+ CommitDirection::Send => {
+ applier
+ .apply_proposals(FilterStrategy::IgnoreByRef, &sender, proposals, commit_time)
+ .await?
+ }
+ CommitDirection::Receive => {
+ applier
+ .apply_proposals(FilterStrategy::IgnoreNone, &sender, proposals, commit_time)
+ .await?
+ }
+ };
+
+ // Without by-reference support the bundle is only borrowed, because the filtered
+ // `proposals` is reported below as the applied set.
+ #[cfg(not(feature = "by_ref_proposal"))]
+ let applier_output = applier
+ .apply_proposals(&sender, &proposals, commit_time)
+ .await?;
+
+ // By-reference proposals that were not applied. Sending checks the pre-filter
+ // snapshot; receiving checks every proposal in this group's cache.
+ #[cfg(feature = "by_ref_proposal")]
+ let unused_proposals = unused_proposals(
+ match direction {
+ CommitDirection::Send => all_proposals,
+ CommitDirection::Receive => self.proposals.proposals.iter().collect(),
+ },
+ &applier_output.applied_proposals,
+ );
+
+ // The provisional context is the current one advanced by one epoch.
+ let mut group_context = self.context.clone();
+ group_context.epoch += 1;
+
+ if let Some(ext) = applier_output.new_context_extensions {
+ group_context.extensions = ext;
+ }
+
+ // With by-reference support, report what the applier actually applied.
+ #[cfg(feature = "by_ref_proposal")]
+ let proposals = applier_output.applied_proposals;
+
+ Ok(ProvisionalState {
+ public_tree: applier_output.new_tree,
+ group_context,
+ applied_proposals: proposals,
+ external_init_index: applier_output.external_init_index,
+ indexes_of_added_kpkgs: applier_output.indexes_of_added_kpkgs,
+ #[cfg(feature = "by_ref_proposal")]
+ unused_proposals,
+ })
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl Extend<(ProposalRef, CachedProposal)> for ProposalCache {
+    /// Adds every `(reference, proposal)` pair from `iter` to the cache.
+    ///
+    /// A pair whose reference is already cached replaces the existing entry. This matches
+    /// `HashMap::extend` in the `std` build, so the `no_std` list never accumulates
+    /// duplicate references.
+    fn extend<T>(&mut self, iter: T)
+    where
+        T: IntoIterator<Item = (ProposalRef, CachedProposal)>,
+    {
+        #[cfg(feature = "std")]
+        self.proposals.extend(iter);
+
+        #[cfg(not(feature = "std"))]
+        {
+            for (proposal_ref, cached_proposal) in iter {
+                match self.proposals.iter_mut().find(|(r, _)| r == &proposal_ref) {
+                    Some(existing) => existing.1 = cached_proposal,
+                    None => self.proposals.push((proposal_ref, cached_proposal)),
+                }
+            }
+        }
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+fn has_ref(proposals: &ProposalBundle, reference: &ProposalRef) -> bool {
+    // True when some proposal in the bundle was included by exactly this reference.
+    proposals.iter_proposals().any(|info| match &info.source {
+        ProposalSource::ByReference(r) => r == reference,
+        _ => false,
+    })
+}
+
+#[cfg(feature = "by_ref_proposal")]
+fn unused_proposals(
+    all_proposals: ProposalBundle,
+    accepted_proposals: &ProposalBundle,
+) -> Vec<crate::mls_rules::ProposalInfo<Proposal>> {
+    // Only by-reference proposals can be left over. They count as unused when no accepted
+    // proposal carries the same reference; by-value proposals are never reported.
+    all_proposals
+        .into_proposals()
+        .filter(|info| match &info.source {
+            ProposalSource::ByReference(r) => !has_ref(accepted_proposals, r),
+            _ => false,
+        })
+        .collect()
+}
+
+// TODO add tests for lite version of filtering
+/// Test helpers for driving the proposal cache through `GroupState::apply_resolved`.
+#[cfg(all(feature = "by_ref_proposal", test))]
+pub(crate) mod test_utils {
+ use mls_rs_core::{
+ crypto::CipherSuiteProvider, extension::ExtensionList, identity::IdentityProvider,
+ psk::PreSharedKeyStorage,
+ };
+
+ use crate::{
+ client::test_utils::TEST_PROTOCOL_VERSION,
+ group::{
+ confirmation_tag::ConfirmationTag,
+ mls_rules::{CommitDirection, DefaultMlsRules, MlsRules},
+ proposal::{Proposal, ProposalOrRef},
+ proposal_ref::ProposalRef,
+ state::GroupState,
+ test_utils::{get_test_group_context, TEST_GROUP},
+ GroupContext, LeafIndex, LeafNode, ProvisionalState, Sender, TreeKemPublic,
+ },
+ identity::{basic::BasicIdentityProvider, test_utils::BasicWithCustomProvider},
+ psk::AlwaysFoundPskStorage,
+ };
+
+ use super::{CachedProposal, MlsError, ProposalCache};
+
+ use alloc::vec::Vec;
+
+ impl CachedProposal {
+ /// Test-only constructor pairing a proposal with its sender.
+ pub fn new(proposal: Proposal, sender: Sender) -> Self {
+ Self { proposal, sender }
+ }
+ }
+
+ /// Builder that simulates receiving a commit from `sender`.
+ ///
+ /// `receive` resolves a proposal list against `cache` and applies it to `tree` with the
+ /// configured providers, rules and group-context extensions.
+ #[derive(Debug)]
+ pub(crate) struct CommitReceiver<'a, C, F, P, CSP> {
+ tree: &'a TreeKemPublic,
+ sender: Sender,
+ receiver: LeafIndex,
+ cache: ProposalCache,
+ identity_provider: C,
+ cipher_suite_provider: CSP,
+ group_context_extensions: ExtensionList,
+ user_rules: F,
+ with_psk_storage: P,
+ }
+
+ impl<'a, CSP>
+ CommitReceiver<'a, BasicWithCustomProvider, DefaultMlsRules, AlwaysFoundPskStorage, CSP>
+ {
+ /// Creates a receiver with an empty test cache, no group-context extensions,
+ /// `BasicWithCustomProvider` identity checks, `DefaultMlsRules` and
+ /// `AlwaysFoundPskStorage`.
+ pub fn new<S>(
+ tree: &'a TreeKemPublic,
+ sender: S,
+ receiver: LeafIndex,
+ cipher_suite_provider: CSP,
+ ) -> Self
+ where
+ S: Into<Sender>,
+ {
+ Self {
+ tree,
+ sender: sender.into(),
+ receiver,
+ cache: make_proposal_cache(),
+ identity_provider: BasicWithCustomProvider::new(BasicIdentityProvider),
+ group_context_extensions: Default::default(),
+ user_rules: pass_through_rules(),
+ with_psk_storage: AlwaysFoundPskStorage,
+ cipher_suite_provider,
+ }
+ }
+ }
+
+ impl<'a, C, F, P, CSP> CommitReceiver<'a, C, F, P, CSP>
+ where
+ C: IdentityProvider,
+ F: MlsRules,
+ P: PreSharedKeyStorage,
+ CSP: CipherSuiteProvider,
+ {
+ /// Replaces the identity provider, keeping every other setting.
+ #[cfg(feature = "by_ref_proposal")]
+ pub fn with_identity_provider<V>(self, validator: V) -> CommitReceiver<'a, V, F, P, CSP>
+ where
+ V: IdentityProvider,
+ {
+ CommitReceiver {
+ tree: self.tree,
+ sender: self.sender,
+ receiver: self.receiver,
+ cache: self.cache,
+ identity_provider: validator,
+ group_context_extensions: self.group_context_extensions,
+ user_rules: self.user_rules,
+ with_psk_storage: self.with_psk_storage,
+ cipher_suite_provider: self.cipher_suite_provider,
+ }
+ }
+
+ /// Replaces the MLS rules used to filter proposals, keeping every other setting.
+ pub fn with_user_rules<G>(self, f: G) -> CommitReceiver<'a, C, G, P, CSP>
+ where
+ G: MlsRules,
+ {
+ CommitReceiver {
+ tree: self.tree,
+ sender: self.sender,
+ receiver: self.receiver,
+ cache: self.cache,
+ identity_provider: self.identity_provider,
+ group_context_extensions: self.group_context_extensions,
+ user_rules: f,
+ with_psk_storage: self.with_psk_storage,
+ cipher_suite_provider: self.cipher_suite_provider,
+ }
+ }
+
+ /// Replaces the PSK storage, keeping every other setting.
+ pub fn with_psk_storage<V>(self, v: V) -> CommitReceiver<'a, C, F, V, CSP>
+ where
+ V: PreSharedKeyStorage,
+ {
+ CommitReceiver {
+ tree: self.tree,
+ sender: self.sender,
+ receiver: self.receiver,
+ cache: self.cache,
+ identity_provider: self.identity_provider,
+ group_context_extensions: self.group_context_extensions,
+ user_rules: self.user_rules,
+ with_psk_storage: v,
+ cipher_suite_provider: self.cipher_suite_provider,
+ }
+ }
+
+ /// Sets the group-context extensions used when applying proposals.
+ #[cfg(feature = "by_ref_proposal")]
+ pub fn with_extensions(self, extensions: ExtensionList) -> Self {
+ Self {
+ group_context_extensions: extensions,
+ ..self
+ }
+ }
+
+ /// Adds `p`, proposed by `proposer`, to the cache under reference `r`.
+ pub fn cache<S>(mut self, r: ProposalRef, p: Proposal, proposer: S) -> Self
+ where
+ S: Into<Sender>,
+ {
+ self.cache.insert(r, p, proposer.into());
+ self
+ }
+
+ /// Resolves `proposals` against the cache and applies them as a received commit,
+ /// with no external leaf.
+ ///
+ /// NOTE(review): `&self.user_rules` is passed where `resolve_for_commit_default`
+ /// takes its rules by value, so this relies on `MlsRules` being implemented for
+ /// `&F` — confirm.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn receive<I>(&self, proposals: I) -> Result<ProvisionalState, MlsError>
+ where
+ I: IntoIterator,
+ I::Item: Into<ProposalOrRef>,
+ {
+ self.cache
+ .resolve_for_commit_default(
+ self.sender,
+ proposals.into_iter().map(Into::into).collect(),
+ None,
+ &self.group_context_extensions,
+ &self.identity_provider,
+ &self.cipher_suite_provider,
+ self.tree,
+ &self.with_psk_storage,
+ &self.user_rules,
+ )
+ .await
+ }
+ }
+
+ /// Returns an empty cache for `TEST_GROUP` at `TEST_PROTOCOL_VERSION`.
+ pub(crate) fn make_proposal_cache() -> ProposalCache {
+ ProposalCache::new(TEST_PROTOCOL_VERSION, TEST_GROUP.to_vec())
+ }
+
+ /// Returns the default MLS rules used by these tests.
+ pub fn pass_through_rules() -> DefaultMlsRules {
+ DefaultMlsRules::new()
+ }
+
+ impl ProposalCache {
+ /// Simulates receiving a commit.
+ ///
+ /// Builds a throwaway `GroupState` from a test group context (created with
+ /// `get_test_group_context(123, ..)`, presumably epoch 123) and sets its extensions to
+ /// `group_extensions`. It then copies this cache into that state, resolves
+ /// `proposal_list`, and applies it with `CommitDirection::Receive` and no commit time.
+ #[allow(clippy::too_many_arguments)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn resolve_for_commit_default<C, F, P, CSP>(
+ &self,
+ sender: Sender,
+ proposal_list: Vec<ProposalOrRef>,
+ external_leaf: Option<&LeafNode>,
+ group_extensions: &ExtensionList,
+ identity_provider: &C,
+ cipher_suite_provider: &CSP,
+ public_tree: &TreeKemPublic,
+ psk_storage: &P,
+ user_rules: F,
+ ) -> Result<ProvisionalState, MlsError>
+ where
+ C: IdentityProvider,
+ F: MlsRules,
+ P: PreSharedKeyStorage,
+ CSP: CipherSuiteProvider,
+ {
+ let mut context =
+ get_test_group_context(123, cipher_suite_provider.cipher_suite()).await;
+
+ context.extensions = group_extensions.clone();
+
+ let mut state = GroupState::new(
+ context,
+ public_tree.clone(),
+ Vec::new().into(),
+ ConfirmationTag::empty(cipher_suite_provider).await,
+ );
+
+ // The receive path reads the state's own cache to compute unused proposals.
+ state.proposals.proposals = self.proposals.clone();
+ let proposals = self.resolve_for_commit(sender, proposal_list)?;
+
+ state
+ .apply_resolved(
+ sender,
+ proposals,
+ external_leaf,
+ identity_provider,
+ cipher_suite_provider,
+ psk_storage,
+ &user_rules,
+ None,
+ CommitDirection::Receive,
+ )
+ .await
+ }
+
+ /// Simulates preparing a commit.
+ ///
+ /// Builds a throwaway `GroupState` from `context` and `public_tree`, bundles every
+ /// cached proposal plus `additional_proposals`, and applies them with
+ /// `CommitDirection::Send` and no commit time.
+ #[allow(clippy::too_many_arguments)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn prepare_commit_default<C, F, P, CSP>(
+ &self,
+ sender: Sender,
+ additional_proposals: Vec<Proposal>,
+ context: &GroupContext,
+ identity_provider: &C,
+ cipher_suite_provider: &CSP,
+ public_tree: &TreeKemPublic,
+ external_leaf: Option<&LeafNode>,
+ psk_storage: &P,
+ user_rules: F,
+ ) -> Result<ProvisionalState, MlsError>
+ where
+ C: IdentityProvider,
+ F: MlsRules,
+ P: PreSharedKeyStorage,
+ CSP: CipherSuiteProvider,
+ {
+ let state = GroupState::new(
+ context.clone(),
+ public_tree.clone(),
+ Vec::new().into(),
+ ConfirmationTag::empty(cipher_suite_provider).await,
+ );
+
+ let proposals = self.prepare_commit(sender, additional_proposals);
+
+ state
+ .apply_resolved(
+ sender,
+ proposals,
+ external_leaf,
+ identity_provider,
+ cipher_suite_provider,
+ psk_storage,
+ &user_rules,
+ None,
+ CommitDirection::Send,
+ )
+ .await
+ }
+ }
+}
+
+// TODO add tests for lite version of filtering
+#[cfg(all(feature = "by_ref_proposal", test))]
+mod tests {
+ use alloc::{boxed::Box, vec, vec::Vec};
+
+ use super::test_utils::{make_proposal_cache, pass_through_rules, CommitReceiver};
+ use super::{CachedProposal, ProposalCache};
+ use crate::client::MlsError;
+ use crate::group::message_processor::ProvisionalState;
+ use crate::group::mls_rules::{CommitDirection, CommitSource, EncryptionOptions};
+ use crate::group::proposal_filter::{ProposalBundle, ProposalInfo, ProposalSource};
+ use crate::group::proposal_ref::test_utils::auth_content_from_proposal;
+ use crate::group::proposal_ref::ProposalRef;
+ use crate::group::{
+ AddProposal, AuthenticatedContent, Content, ExternalInit, Proposal, ProposalOrRef,
+ ReInitProposal, RemoveProposal, Roster, Sender, UpdateProposal,
+ };
+ use crate::key_package::test_utils::test_key_package_with_signer;
+ use crate::signer::Signable;
+ use crate::tree_kem::leaf_node::LeafNode;
+ use crate::tree_kem::node::LeafIndex;
+ use crate::tree_kem::TreeKemPublic;
+ use crate::{
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ crypto::{self, test_utils::test_cipher_suite_provider},
+ extension::test_utils::TestExtension,
+ group::{
+ message_processor::path_update_required,
+ proposal_filter::proposer_can_propose,
+ test_utils::{get_test_group_context, random_bytes, test_group, TEST_GROUP},
+ },
+ identity::basic::BasicIdentityProvider,
+ identity::test_utils::{get_test_signing_identity, BasicWithCustomProvider},
+ key_package::{test_utils::test_key_package, KeyPackageGenerator},
+ mls_rules::{CommitOptions, DefaultMlsRules},
+ psk::AlwaysFoundPskStorage,
+ tree_kem::{
+ leaf_node::{
+ test_utils::{
+ default_properties, get_basic_test_node, get_basic_test_node_capabilities,
+ get_basic_test_node_sig_key, get_test_capabilities,
+ },
+ ConfigProperties, LeafNodeSigningContext, LeafNodeSource,
+ },
+ Lifetime,
+ },
+ };
+ use crate::{KeyPackage, MlsRules};
+
+ use crate::extension::RequiredCapabilitiesExt;
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::{
+ extension::ExternalSendersExt,
+ tree_kem::leaf_node_validator::test_utils::FailureIdentityProvider,
+ };
+
+ #[cfg(feature = "psk")]
+ use crate::{
+ group::proposal::PreSharedKeyProposal,
+ psk::{
+ ExternalPskId, JustPreSharedKeyID, PreSharedKeyID, PskGroupId, PskNonce,
+ ResumptionPSKUsage, ResumptionPsk,
+ },
+ };
+
+ #[cfg(feature = "custom_proposal")]
+ use crate::group::proposal::CustomProposal;
+
+ use assert_matches::assert_matches;
+ use core::convert::Infallible;
+ use itertools::Itertools;
+ use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};
+ use mls_rs_core::extension::ExtensionList;
+ use mls_rs_core::group::{Capabilities, ProposalType};
+ use mls_rs_core::identity::IdentityProvider;
+ use mls_rs_core::protocol_version::ProtocolVersion;
+ use mls_rs_core::psk::{PreSharedKey, PreSharedKeyStorage};
+ use mls_rs_core::{
+ extension::MlsExtension,
+ identity::{Credential, CredentialType, CustomCredential},
+ };
+
+    // Fixed leaf index 1; presumably the default sender for tests elsewhere in this module.
+    fn test_sender() -> u32 {
+        const SENDER_LEAF: u32 = 1;
+        SENDER_LEAF
+    }
+
+    /// Derives a single-leaf tree for `name` whose leaf advertises exactly `proposal_types`
+    /// in its capabilities, and returns the leaf's index alongside the public tree.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn new_tree_custom_proposals(
+        name: &str,
+        proposal_types: Vec<ProposalType>,
+    ) -> (LeafIndex, TreeKemPublic) {
+        let capabilities = Capabilities {
+            proposals: proposal_types,
+            ..get_test_capabilities()
+        };
+
+        let (leaf, secret, _) =
+            get_basic_test_node_capabilities(TEST_CIPHER_SUITE, name, capabilities).await;
+
+        let (public_tree, private_tree) =
+            TreeKemPublic::derive(leaf, secret, &BasicIdentityProvider, &Default::default())
+                .await
+                .unwrap();
+
+        (private_tree.self_index, public_tree)
+    }
+
+    /// Derives a single-leaf tree for `name` whose leaf lists no proposal types in its
+    /// capabilities.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn new_tree(name: &str) -> (LeafIndex, TreeKemPublic) {
+        let no_proposal_types: Vec<ProposalType> = Vec::new();
+        new_tree_custom_proposals(name, no_proposal_types).await
+    }
+
+    /// Adds a basic test leaf for `name` to `tree` and returns the index it was placed at.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn add_member(tree: &mut TreeKemPublic, name: &str) -> LeafIndex {
+        let leaf = get_basic_test_node(TEST_CIPHER_SUITE, name).await;
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let added = tree
+            .add_leaves(vec![leaf], &BasicIdentityProvider, &cipher_suite_provider)
+            .await
+            .unwrap();
+
+        added[0]
+    }
+
+    /// Returns a fresh leaf for `name`, updated in place for `TEST_GROUP` at `leaf_index`
+    /// with default properties and signed with the leaf's own key.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn update_leaf_node(name: &str, leaf_index: u32) -> LeafNode {
+        let (mut leaf, _, signer) = get_basic_test_node_sig_key(TEST_CIPHER_SUITE, name).await;
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        leaf.update(
+            &cipher_suite_provider,
+            TEST_GROUP,
+            leaf_index,
+            default_properties(),
+            None,
+            &signer,
+        )
+        .await
+        .unwrap();
+
+        leaf
+    }
+
+ /// Fixture returned by `test_proposals`.
+ struct TestProposals {
+ /// Leaf index of the member that commits the proposals.
+ test_sender: u32,
+ /// The proposals as authenticated contents, ready to be cached.
+ test_proposals: Vec<AuthenticatedContent>,
+ /// Provisional state expected after committing every cached proposal.
+ expected_effects: ProvisionalState,
+ /// Tree before any of the proposals are applied.
+ tree: TreeKemPublic,
+ }
+
+    /// Builds the shared fixture for the commit tests.
+    ///
+    /// The tree contains "alice" (leaf 0), "carol" and "charlie"; charlie's index is
+    /// `test_sender`. Three proposals are authored by leaf 0: add "dave", remove "carol",
+    /// and empty group-context extensions. The expected state records them by reference as
+    /// coming from `test_sender`, applies them to a copy of the tree, and expects dave at
+    /// `LeafIndex(1)`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn test_proposals(
+        protocol_version: ProtocolVersion,
+        cipher_suite: CipherSuite,
+    ) -> TestProposals {
+        let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+        let (sender_leaf, sender_leaf_secret, _) =
+            get_basic_test_node_sig_key(cipher_suite, "alice").await;
+
+        let sender = LeafIndex(0);
+
+        let (mut tree, _) = TreeKemPublic::derive(
+            sender_leaf,
+            sender_leaf_secret,
+            &BasicIdentityProvider,
+            &Default::default(),
+        )
+        .await
+        .unwrap();
+
+        let add_package = test_key_package(protocol_version, cipher_suite, "dave").await;
+
+        let remove_leaf_index = add_member(&mut tree, "carol").await;
+
+        let add = Proposal::Add(Box::new(AddProposal {
+            key_package: add_package,
+        }));
+
+        let remove = Proposal::Remove(RemoveProposal {
+            to_remove: remove_leaf_index,
+        });
+
+        let extensions = Proposal::GroupContextExtensions(ExtensionList::new());
+
+        let proposals = vec![add, remove, extensions];
+
+        let test_node = get_basic_test_node(cipher_suite, "charlie").await;
+
+        let test_sender = *tree
+            .add_leaves(
+                vec![test_node],
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+            )
+            .await
+            .unwrap()[0];
+
+        let mut expected_tree = tree.clone();
+
+        let plaintext = proposals
+            .iter()
+            .cloned()
+            .map(|p| auth_content_from_proposal(p, sender))
+            .collect_vec();
+
+        let mut bundle = ProposalBundle::default();
+
+        // Each proposal is referenced by the hash of its own authenticated content.
+        for (proposal, content) in proposals.into_iter().zip(&plaintext) {
+            let pref = ProposalRef::from_content(&cipher_suite_provider, content)
+                .await
+                .unwrap();
+
+            bundle.add(
+                proposal,
+                Sender::Member(test_sender),
+                ProposalSource::ByReference(pref),
+            );
+        }
+
+        expected_tree
+            .batch_edit(
+                &mut bundle,
+                &Default::default(),
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                true,
+            )
+            .await
+            .unwrap();
+
+        let expected_effects = ProvisionalState {
+            public_tree: expected_tree,
+            group_context: get_test_group_context(1, cipher_suite).await,
+            external_init_index: None,
+            indexes_of_added_kpkgs: vec![LeafIndex(1)],
+            // `ProvisionalState::unused_proposals` exists whenever `by_ref_proposal` is
+            // enabled, which this test module requires. Gating it on `state_update` failed
+            // to compile when that feature was off.
+            unused_proposals: vec![],
+            applied_proposals: bundle,
+        };
+
+        TestProposals {
+            test_sender,
+            test_proposals: plaintext,
+            expected_effects,
+            tree,
+        }
+    }
+
+    /// Keeps only proposal messages from `proposals`.
+    ///
+    /// Each one is paired with the reference computed from its full authenticated content
+    /// and stored with its original sender.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn filter_proposals(
+        cipher_suite: CipherSuite,
+        proposals: Vec<AuthenticatedContent>,
+    ) -> Vec<(ProposalRef, CachedProposal)> {
+        let mut cached = Vec::new();
+
+        for content in proposals {
+            let Content::Proposal(proposal) = &content.content.content else {
+                continue;
+            };
+
+            let reference =
+                ProposalRef::from_content(&test_cipher_suite_provider(cipher_suite), &content)
+                    .await
+                    .unwrap();
+
+            let entry = CachedProposal::new(proposal.as_ref().clone(), content.content.sender);
+            cached.push((reference, entry));
+        }
+
+        cached
+    }
+
+    /// Computes the reference of `p` as if it had been sent as a proposal by `sender`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn make_proposal_ref<S>(p: &Proposal, sender: S) -> ProposalRef
+    where
+        S: Into<Sender>,
+    {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let content = auth_content_from_proposal(p.clone(), sender);
+
+        ProposalRef::from_content(&cipher_suite_provider, &content)
+            .await
+            .unwrap()
+    }
+
+    /// Builds the `ProposalInfo` for `p` sent by `sender` and included by reference.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn make_proposal_info<S>(p: &Proposal, sender: S) -> ProposalInfo<Proposal>
+    where
+        S: Into<Sender> + Clone,
+    {
+        let proposal = p.clone();
+        let info_sender: Sender = sender.clone().into();
+        let reference = make_proposal_ref(p, sender).await;
+
+        ProposalInfo {
+            proposal,
+            sender: info_sender,
+            source: ProposalSource::ByReference(reference),
+        }
+    }
+
+    /// Returns a test cache pre-filled with the proposal messages among `proposals`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn test_proposal_cache_setup(proposals: Vec<AuthenticatedContent>) -> ProposalCache {
+        let mut cache = make_proposal_cache();
+        let entries = filter_proposals(TEST_CIPHER_SUITE, proposals).await;
+        cache.extend(entries);
+        cache
+    }
+
+    /// Asserts that `state` matches `expected_state` in the fields these tests care about.
+    ///
+    /// Applied proposals are compared as an unordered collection: same length, no
+    /// duplicates in `state`, and every expected entry present. The group-context epoch is
+    /// ignored.
+    fn assert_matches(mut expected_state: ProvisionalState, state: ProvisionalState) {
+        let expected_proposals = expected_state.applied_proposals.into_proposals_or_refs();
+        let proposals = state.applied_proposals.into_proposals_or_refs();
+
+        assert_eq!(proposals.len(), expected_proposals.len());
+
+        // Determine there are no duplicates in the proposals returned
+        assert!(!proposals.iter().enumerate().any(|(i, p1)| proposals
+            .iter()
+            .enumerate()
+            .any(|(j, p2)| p1 == p2 && i != j)));
+
+        // Proposal order may change so we just compare the length and contents are the same
+        expected_proposals
+            .iter()
+            .for_each(|p| assert!(proposals.contains(p)));
+
+        assert_eq!(
+            expected_state.external_init_index,
+            state.external_init_index
+        );
+
+        // We don't compare the epoch in this test.
+        expected_state.group_context.epoch = state.group_context.epoch;
+        assert_eq!(expected_state.group_context, state.group_context);
+
+        assert_eq!(
+            expected_state.indexes_of_added_kpkgs,
+            state.indexes_of_added_kpkgs
+        );
+
+        assert_eq!(expected_state.public_tree, state.public_tree);
+
+        // `unused_proposals` exists whenever `by_ref_proposal` is enabled, which this test
+        // module requires. Always compare it instead of only under the unrelated
+        // `state_update` feature.
+        assert_eq!(expected_state.unused_proposals, state.unused_proposals);
+    }
+
+    // Committing with no additional proposals must apply exactly the cached ones.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_cache_commit_all_cached() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let fixture = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let cache = test_proposal_cache_setup(fixture.test_proposals).await;
+        let context = get_test_group_context(0, TEST_CIPHER_SUITE).await;
+
+        let provisional_state = cache
+            .prepare_commit_default(
+                Sender::Member(fixture.test_sender),
+                vec![],
+                &context,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &fixture.tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert_matches(fixture.expected_effects, provisional_state)
+    }
+
+    /// Committing the cached test proposals together with one additional by-value add
+    /// for "frank" must apply both; frank's leaf is expected at index 3.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_cache_commit_additional() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let TestProposals {
+            test_sender,
+            test_proposals,
+            mut expected_effects,
+            tree,
+            ..
+        } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let additional_key_package =
+            test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "frank").await;
+
+        // Cloned because the key package's leaf node is needed again below to build
+        // the expected tree.
+        let additional = AddProposal {
+            key_package: additional_key_package.clone(),
+        };
+
+        // `test_proposals` is not used again, so it is moved instead of cloned.
+        let cache = test_proposal_cache_setup(test_proposals).await;
+
+        let provisional_state = cache
+            .prepare_commit_default(
+                Sender::Member(test_sender),
+                vec![Proposal::Add(Box::new(additional.clone()))],
+                &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        // Last use of `additional`, so it is moved into the expected proposal.
+        expected_effects.applied_proposals.add(
+            Proposal::Add(Box::new(additional)),
+            Sender::Member(test_sender),
+            ProposalSource::ByValue,
+        );
+
+        let leaf = vec![additional_key_package.leaf_node.clone()];
+
+        expected_effects
+            .public_tree
+            .add_leaves(leaf, &BasicIdentityProvider, &cipher_suite_provider)
+            .await
+            .unwrap();
+
+        expected_effects.indexes_of_added_kpkgs.push(LeafIndex(3));
+
+        assert_matches(expected_effects, provisional_state);
+    }
+
+    /// An update proposal passed by value as an additional proposal must be rejected
+    /// with `InvalidProposalTypeForSender`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_cache_update_filter() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        // `test_sender` is deliberately not destructured here: the `test_sender()`
+        // helper function is called below instead.
+        let TestProposals {
+            test_proposals,
+            tree,
+            ..
+        } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let by_value_update = Proposal::Update(make_update_proposal("foo").await);
+        let cache = test_proposal_cache_setup(test_proposals).await;
+        let group_context = get_test_group_context(0, TEST_CIPHER_SUITE).await;
+
+        let outcome = cache
+            .prepare_commit_default(
+                Sender::Member(test_sender()),
+                vec![by_value_update],
+                &group_context,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::InvalidProposalTypeForSender));
+    }
+
+    /// When the committed proposals remove member 1, member 1's cached update must not
+    /// appear among the committed proposals.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_cache_removal_override_update() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let TestProposals {
+            test_sender,
+            test_proposals,
+            tree,
+            ..
+        } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let update = Proposal::Update(make_update_proposal("foo").await);
+        let update_ref = make_proposal_ref(&update, LeafIndex(1)).await;
+
+        let mut cache = test_proposal_cache_setup(test_proposals).await;
+        cache.insert(update_ref.clone(), update, Sender::Member(1));
+
+        let group_context = get_test_group_context(0, TEST_CIPHER_SUITE).await;
+
+        let state = cache
+            .prepare_commit_default(
+                Sender::Member(test_sender),
+                vec![],
+                &group_context,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        let removes_member_1 = state
+            .applied_proposals
+            .removals
+            .iter()
+            .any(|p| *p.proposal.to_remove == 1);
+        assert!(removes_member_1);
+
+        let committed = state.applied_proposals.into_proposals_or_refs();
+        assert!(!committed.contains(&ProposalOrRef::Reference(update_ref)))
+    }
+
+    /// Inserting the same proposals into the cache twice must not change the commit:
+    /// the result still matches `expected_effects`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_cache_filter_duplicates_insert() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let TestProposals {
+            test_sender,
+            test_proposals,
+            expected_effects,
+            tree,
+            ..
+        } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let mut cache = test_proposal_cache_setup(test_proposals.clone()).await;
+        // Last use of `test_proposals`, so it is moved rather than cloned a second time.
+        cache.extend(filter_proposals(TEST_CIPHER_SUITE, test_proposals).await);
+
+        let provisional_state = cache
+            .prepare_commit_default(
+                Sender::Member(test_sender),
+                vec![],
+                &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert_matches(expected_effects, provisional_state)
+    }
+
+    /// Re-inserting the cached non-update proposals must not produce duplicates in a
+    /// commit by member 2; the result still matches `expected_effects`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_cache_filter_duplicates_additional() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let TestProposals {
+            test_proposals,
+            expected_effects,
+            tree,
+            ..
+        } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let mut cache = test_proposal_cache_setup(test_proposals.clone()).await;
+
+        // Updates from different senders will be allowed so we test duplicates for add / remove.
+        // This is the last use of `test_proposals`, so it is consumed without cloning.
+        let additional = test_proposals
+            .into_iter()
+            .filter(|content| match &content.content.content {
+                Content::Proposal(p) => p.proposal_type() != ProposalType::UPDATE,
+                _ => false,
+            })
+            .collect::<Vec<_>>();
+
+        cache.extend(filter_proposals(TEST_CIPHER_SUITE, additional).await);
+
+        let provisional_state = cache
+            .prepare_commit_default(
+                Sender::Member(2),
+                Vec::new(),
+                &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert_matches(expected_effects, provisional_state)
+    }
+
+    /// A new cache is empty, and inserting a single proposal makes it non-empty.
+    #[cfg(feature = "private_message")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_cache_is_empty() {
+        let proposer = test_sender();
+        let mut cache = make_proposal_cache();
+
+        assert!(cache.is_empty());
+
+        let remove = Proposal::Remove(RemoveProposal {
+            to_remove: LeafIndex(proposer),
+        });
+        let remove_ref = make_proposal_ref(&remove, LeafIndex(proposer)).await;
+        cache.insert(remove_ref, remove, Sender::Member(proposer));
+
+        assert!(!cache.is_empty())
+    }
+
+    /// Resolving the proposals that `prepare_commit_default` produced must yield the
+    /// same provisional state as preparing the commit did.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_proposal_cache_resolve() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let TestProposals {
+            test_sender,
+            test_proposals,
+            tree,
+            ..
+        } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let cache = test_proposal_cache_setup(test_proposals).await;
+
+        let frank = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "frank").await;
+        let add_frank = Proposal::Add(Box::new(AddProposal { key_package: frank }));
+        let group_context = get_test_group_context(0, TEST_CIPHER_SUITE).await;
+
+        let prepared = cache
+            .prepare_commit_default(
+                Sender::Member(test_sender),
+                vec![add_frank],
+                &group_context,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        // The prepared state is still needed for the comparison, so clone the proposals.
+        let committed = prepared.applied_proposals.clone().into_proposals_or_refs();
+        let no_extensions = ExtensionList::new();
+
+        let resolved = cache
+            .resolve_for_commit_default(
+                Sender::Member(test_sender),
+                committed,
+                None,
+                &no_extensions,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert_matches(prepared, resolved);
+    }
+
+    /// Committing the same external PSK proposal twice by value is rejected with
+    /// `DuplicatePskIds`.
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn proposal_cache_filters_duplicate_psk_ids() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let (alice, tree) = new_tree("alice").await;
+        let cache = make_proposal_cache();
+
+        // Reuse the provider created above instead of constructing a second one.
+        let proposal = Proposal::Psk(make_external_psk(
+            b"ted",
+            crate::psk::PskNonce::random(&cipher_suite_provider).unwrap(),
+        ));
+
+        let res = cache
+            .prepare_commit_default(
+                Sender::Member(*alice),
+                vec![proposal.clone(), proposal],
+                &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(res, Err(MlsError::DuplicatePskIds));
+    }
+
+    /// Returns a basic test leaf node for "foo" on which `commit` has been called for
+    /// `TEST_GROUP` at leaf index 0 with the default properties.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn test_node() -> LeafNode {
+        let (mut node, _, signing_key) =
+            get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "foo").await;
+
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        node.commit(
+            &cipher_suite_provider,
+            TEST_GROUP,
+            0,
+            default_properties(),
+            None,
+            &signing_key,
+        )
+        .await
+        .unwrap();
+
+        node
+    }
+
+    /// An external commit (`Sender::NewMemberCommit`) that passes `None` as the new
+    /// leaf is rejected with `ExternalCommitMustHaveNewLeaf`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn external_commit_must_have_new_leaf() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let cache = make_proposal_cache();
+
+        let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
+        let external_init = Proposal::ExternalInit(ExternalInit { kem_output });
+
+        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+        let res = cache
+            .resolve_for_commit_default(
+                Sender::NewMemberCommit,
+                vec![ProposalOrRef::Proposal(Box::new(external_init))],
+                None,
+                &group.group.context().extensions,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &group.group.state.public_tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(res, Err(MlsError::ExternalCommitMustHaveNewLeaf));
+    }
+
+    /// An external commit from a new member that references a cached proposal is
+    /// rejected with `OnlyMembersCanCommitProposalsByRef`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn proposal_cache_rejects_proposals_by_ref_for_new_member() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let mut cache = make_proposal_cache();
+
+        let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
+        let external_init = Proposal::ExternalInit(ExternalInit { kem_output });
+        let external_init_ref = make_proposal_ref(&external_init, test_sender()).await;
+
+        // Cache the proposal as if an existing member had sent it.
+        cache.insert(
+            external_init_ref.clone(),
+            external_init,
+            Sender::Member(test_sender()),
+        );
+
+        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let new_leaf = test_node().await;
+
+        let res = cache
+            .resolve_for_commit_default(
+                Sender::NewMemberCommit,
+                vec![ProposalOrRef::Reference(external_init_ref)],
+                Some(&new_leaf),
+                &group.group.context().extensions,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &group.group.state.public_tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(res, Err(MlsError::OnlyMembersCanCommitProposalsByRef));
+    }
+
+    /// An external commit carrying two external init proposals is rejected with
+    /// `ExternalCommitMustHaveExactlyOneExternalInit`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn proposal_cache_rejects_multiple_external_init_proposals_in_commit() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let cache = make_proposal_cache();
+
+        let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
+        let external_init = Proposal::ExternalInit(ExternalInit { kem_output });
+
+        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let new_leaf = test_node().await;
+
+        let two_external_inits = vec![
+            ProposalOrRef::Proposal(Box::new(external_init.clone())),
+            ProposalOrRef::Proposal(Box::new(external_init)),
+        ];
+
+        let res = cache
+            .resolve_for_commit_default(
+                Sender::NewMemberCommit,
+                two_external_inits,
+                Some(&new_leaf),
+                &group.group.context().extensions,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &group.group.state.public_tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(
+            res,
+            Err(MlsError::ExternalCommitMustHaveExactlyOneExternalInit)
+        );
+    }
+
+    /// Resolves an external commit by a new member (leaf from `test_node`) that
+    /// carries one external init proposal followed by `proposal`, all by value.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn new_member_commits_proposal(proposal: Proposal) -> Result<ProvisionalState, MlsError> {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let cache = make_proposal_cache();
+
+        let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
+        let external_init = Proposal::ExternalInit(ExternalInit { kem_output });
+
+        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let new_leaf = test_node().await;
+
+        let by_value = vec![
+            ProposalOrRef::Proposal(Box::new(external_init)),
+            ProposalOrRef::Proposal(Box::new(proposal)),
+        ];
+
+        cache
+            .resolve_for_commit_default(
+                Sender::NewMemberCommit,
+                by_value,
+                Some(&new_leaf),
+                &group.group.context().extensions,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &group.group.state.public_tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+    }
+
+    /// An add proposal is not allowed in an external commit.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_cannot_commit_add_proposal() {
+        let frank = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "frank").await;
+        let add = Proposal::Add(Box::new(AddProposal { key_package: frank }));
+
+        let outcome = new_member_commits_proposal(add).await;
+
+        assert_matches!(
+            outcome,
+            Err(MlsError::InvalidProposalTypeInExternalCommit(ProposalType::ADD))
+        );
+    }
+
+    /// An external commit carrying two remove proposals is rejected with
+    /// `ExternalCommitWithMoreThanOneRemove`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_cannot_commit_more_than_one_remove_proposal() {
+        let cache = make_proposal_cache();
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
+        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        // Clone the extensions first: `public_tree` is moved out of `group` next.
+        let group_extensions = group.group.context().extensions.clone();
+        let mut public_tree = group.group.state.public_tree;
+
+        let foo = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
+
+        let bar = get_basic_test_node(TEST_CIPHER_SUITE, "bar").await;
+
+        let test_leaf_nodes = vec![foo, bar];
+
+        // Add two extra members so that there are two leaves to remove.
+        let test_leaf_node_indexes = public_tree
+            .add_leaves(
+                test_leaf_nodes,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+            )
+            .await
+            .unwrap();
+
+        let proposals = vec![
+            Proposal::ExternalInit(ExternalInit { kem_output }),
+            Proposal::Remove(RemoveProposal {
+                to_remove: test_leaf_node_indexes[0],
+            }),
+            Proposal::Remove(RemoveProposal {
+                to_remove: test_leaf_node_indexes[1],
+            }),
+        ];
+
+        let res = cache
+            .resolve_for_commit_default(
+                Sender::NewMemberCommit,
+                proposals
+                    .into_iter()
+                    .map(|p| ProposalOrRef::Proposal(Box::new(p)))
+                    .collect(),
+                Some(&test_node().await),
+                &group_extensions,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &public_tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(res, Err(MlsError::ExternalCommitWithMoreThanOneRemove));
+    }
+
+    /// An external commit whose new leaf is built for "foo" (`test_node`) may not
+    /// remove the leaf of "bar"; this is rejected with
+    /// `ExternalCommitRemovesOtherIdentity`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_remove_proposal_invalid_credential() {
+        let cache = make_proposal_cache();
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
+        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        // Clone the extensions first: `public_tree` is moved out of `group` next.
+        let group_extensions = group.group.context().extensions.clone();
+        let mut public_tree = group.group.state.public_tree;
+
+        // "bar" differs from the "foo" identity that `test_node` uses for the committer.
+        let node = get_basic_test_node(TEST_CIPHER_SUITE, "bar").await;
+
+        let test_leaf_nodes = vec![node];
+
+        let test_leaf_node_indexes = public_tree
+            .add_leaves(
+                test_leaf_nodes,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+            )
+            .await
+            .unwrap();
+
+        let proposals = vec![
+            Proposal::ExternalInit(ExternalInit { kem_output }),
+            Proposal::Remove(RemoveProposal {
+                to_remove: test_leaf_node_indexes[0],
+            }),
+        ];
+
+        let res = cache
+            .resolve_for_commit_default(
+                Sender::NewMemberCommit,
+                proposals
+                    .into_iter()
+                    .map(|p| ProposalOrRef::Proposal(Box::new(p)))
+                    .collect(),
+                Some(&test_node().await),
+                &group_extensions,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &public_tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(res, Err(MlsError::ExternalCommitRemovesOtherIdentity));
+    }
+
+    /// An external commit whose new leaf is built for "foo" (`test_node`) may remove an
+    /// existing leaf that also belongs to "foo"; resolution succeeds.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_remove_proposal_valid_credential() {
+        let cache = make_proposal_cache();
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
+        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        // Clone the extensions first: `public_tree` is moved out of `group` next.
+        let group_extensions = group.group.context().extensions.clone();
+        let mut public_tree = group.group.state.public_tree;
+
+        // Same "foo" identity that `test_node` uses for the committer.
+        let node = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
+
+        let test_leaf_nodes = vec![node];
+
+        let test_leaf_node_indexes = public_tree
+            .add_leaves(
+                test_leaf_nodes,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+            )
+            .await
+            .unwrap();
+
+        let proposals = vec![
+            Proposal::ExternalInit(ExternalInit { kem_output }),
+            Proposal::Remove(RemoveProposal {
+                to_remove: test_leaf_node_indexes[0],
+            }),
+        ];
+
+        let res = cache
+            .resolve_for_commit_default(
+                Sender::NewMemberCommit,
+                proposals
+                    .into_iter()
+                    .map(|p| ProposalOrRef::Proposal(Box::new(p)))
+                    .collect(),
+                Some(&test_node().await),
+                &group_extensions,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &public_tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(res, Ok(_));
+    }
+
+    /// An update proposal is not allowed in an external commit.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_cannot_commit_update_proposal() {
+        let leaf_node = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
+        let update = Proposal::Update(UpdateProposal { leaf_node });
+
+        let outcome = new_member_commits_proposal(update).await;
+
+        assert_matches!(
+            outcome,
+            Err(MlsError::InvalidProposalTypeInExternalCommit(ProposalType::UPDATE))
+        );
+    }
+
+    /// A group context extensions proposal is not allowed in an external commit.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_cannot_commit_group_extensions_proposal() {
+        let extensions_proposal = Proposal::GroupContextExtensions(ExtensionList::new());
+
+        let outcome = new_member_commits_proposal(extensions_proposal).await;
+
+        assert_matches!(
+            outcome,
+            Err(MlsError::InvalidProposalTypeInExternalCommit(
+                ProposalType::GROUP_CONTEXT_EXTENSIONS
+            ))
+        );
+    }
+
+    /// A re-init proposal is not allowed in an external commit.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_cannot_commit_reinit_proposal() {
+        let reinit = ReInitProposal {
+            group_id: b"foo".to_vec(),
+            version: TEST_PROTOCOL_VERSION,
+            cipher_suite: TEST_CIPHER_SUITE,
+            extensions: ExtensionList::new(),
+        };
+
+        let outcome = new_member_commits_proposal(Proposal::ReInit(reinit)).await;
+
+        assert_matches!(
+            outcome,
+            Err(MlsError::InvalidProposalTypeInExternalCommit(ProposalType::RE_INIT))
+        );
+    }
+
+    /// An external commit with no proposals at all (hence no external init) is
+    /// rejected with `ExternalCommitMustHaveExactlyOneExternalInit`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn new_member_commit_must_contain_an_external_init_proposal() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let cache = make_proposal_cache();
+        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+        let new_leaf = test_node().await;
+
+        let outcome = cache
+            .resolve_for_commit_default(
+                Sender::NewMemberCommit,
+                Vec::new(),
+                Some(&new_leaf),
+                &group.group.context().extensions,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &group.group.state.public_tree,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await;
+
+        assert_matches!(
+            outcome,
+            Err(MlsError::ExternalCommitMustHaveExactlyOneExternalInit)
+        );
+    }
+
+    /// A commit with an empty cache and no additional proposals still requires a path
+    /// update.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_path_update_required_empty() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let cache = make_proposal_cache();
+
+        let mut tree = TreeKemPublic::new();
+
+        for name in ["alice", "bob"] {
+            add_member(&mut tree, name).await;
+        }
+
+        let group_context = get_test_group_context(1, TEST_CIPHER_SUITE).await;
+
+        let state = cache
+            .prepare_commit_default(
+                Sender::Member(test_sender()),
+                vec![],
+                &group_context,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert!(path_update_required(&state.applied_proposals))
+    }
+
+    /// A commit prepared while an update proposal from member 2 is cached requires a
+    /// path update.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_path_update_required_updates() {
+        let mut cache = make_proposal_cache();
+
+        let update = Proposal::Update(make_update_proposal("bar").await);
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let update_ref = make_proposal_ref(&update, LeafIndex(2)).await;
+        cache.insert(update_ref, update, Sender::Member(2));
+
+        let mut tree = TreeKemPublic::new();
+
+        for name in ["alice", "bob", "carol"] {
+            add_member(&mut tree, name).await;
+        }
+
+        let group_context = get_test_group_context(1, TEST_CIPHER_SUITE).await;
+
+        let state = cache
+            .prepare_commit_default(
+                Sender::Member(test_sender()),
+                Vec::new(),
+                &group_context,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert!(path_update_required(&state.applied_proposals))
+    }
+
+    /// A commit in which alice removes bob requires a path update.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_path_update_required_removes() {
+        let cache = make_proposal_cache();
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let (alice_leaf, alice_secret, _) =
+            get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "alice").await;
+        // Alice is the first leaf of the tree derived from her leaf node below.
+        let alice = 0;
+
+        let (mut tree, _) = TreeKemPublic::derive(
+            alice_leaf,
+            alice_secret,
+            &BasicIdentityProvider,
+            &Default::default(),
+        )
+        .await
+        .unwrap();
+
+        let bob_node = get_basic_test_node(TEST_CIPHER_SUITE, "bob").await;
+
+        // `add_leaves` returns the indexes of the added leaves; bob is the only one.
+        let bob = tree
+            .add_leaves(
+                vec![bob_node],
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+            )
+            .await
+            .unwrap()[0];
+
+        let remove = Proposal::Remove(RemoveProposal { to_remove: bob });
+
+        let effects = cache
+            .prepare_commit_default(
+                Sender::Member(alice),
+                vec![remove],
+                &get_test_group_context(1, TEST_CIPHER_SUITE).await,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert!(path_update_required(&effects.applied_proposals))
+    }
+
+    /// A commit carrying only a PSK proposal and an add proposal does not require a
+    /// path update.
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_path_update_not_required() {
+        let (alice, tree) = new_tree("alice").await;
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let cache = make_proposal_cache();
+
+        // Reuse the provider created above instead of constructing a second one.
+        let psk = Proposal::Psk(PreSharedKeyProposal {
+            psk: PreSharedKeyID::new(
+                JustPreSharedKeyID::External(ExternalPskId::new(vec![])),
+                &cipher_suite_provider,
+            )
+            .unwrap(),
+        });
+
+        let add = Proposal::Add(Box::new(AddProposal {
+            key_package: test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await,
+        }));
+
+        let effects = cache
+            .prepare_commit_default(
+                Sender::Member(*alice),
+                vec![psk, add],
+                &get_test_group_context(1, TEST_CIPHER_SUITE).await,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert!(!path_update_required(&effects.applied_proposals))
+    }
+
+    /// A commit carrying only a re-init proposal does not require a path update.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn path_update_is_not_required_for_re_init() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+        let (alice, tree) = new_tree("alice").await;
+        let cache = make_proposal_cache();
+
+        let reinit = ReInitProposal {
+            group_id: vec![],
+            version: TEST_PROTOCOL_VERSION,
+            cipher_suite: TEST_CIPHER_SUITE,
+            extensions: Default::default(),
+        };
+
+        let group_context = get_test_group_context(1, TEST_CIPHER_SUITE).await;
+
+        let state = cache
+            .prepare_commit_default(
+                Sender::Member(*alice),
+                vec![Proposal::ReInit(reinit)],
+                &group_context,
+                &BasicIdentityProvider,
+                &cipher_suite_provider,
+                &tree,
+                None,
+                &AlwaysFoundPskStorage,
+                pass_through_rules(),
+            )
+            .await
+            .unwrap();
+
+        assert!(!path_update_required(&state.applied_proposals))
+    }
+
+    /// Test harness that prepares a commit from the point of view of member `sender`
+    /// against `tree`, using the configured providers, rules and PSK storage.
+    #[derive(Debug)]
+    struct CommitSender<'a, C, F, P, CSP> {
+        cipher_suite_provider: CSP,
+        tree: &'a TreeKemPublic,
+        /// Leaf index of the committing member.
+        sender: LeafIndex,
+        /// Proposals cached by reference before the commit is prepared.
+        cache: ProposalCache,
+        /// Proposals passed by value to the commit.
+        additional_proposals: Vec<Proposal>,
+        identity_provider: C,
+        user_rules: F,
+        psk_storage: P,
+    }
+
+    impl<'a, CSP>
+        CommitSender<'a, BasicWithCustomProvider, DefaultMlsRules, AlwaysFoundPskStorage, CSP>
+    {
+        /// Creates a sender for leaf `sender` in `tree` with an empty proposal cache, no
+        /// additional proposals, a `BasicWithCustomProvider` identity provider,
+        /// `pass_through_rules()` and `AlwaysFoundPskStorage`.
+        fn new(tree: &'a TreeKemPublic, sender: LeafIndex, cipher_suite_provider: CSP) -> Self {
+            let identity_provider = BasicWithCustomProvider::new(BasicIdentityProvider::new());
+
+            Self {
+                cipher_suite_provider,
+                tree,
+                sender,
+                cache: make_proposal_cache(),
+                additional_proposals: Vec::new(),
+                identity_provider,
+                user_rules: pass_through_rules(),
+                psk_storage: AlwaysFoundPskStorage,
+            }
+        }
+    }
+
+    impl<'a, C, F, P, CSP> CommitSender<'a, C, F, P, CSP>
+    where
+        C: IdentityProvider,
+        F: MlsRules,
+        P: PreSharedKeyStorage,
+        CSP: CipherSuiteProvider,
+    {
+        /// Replaces the identity provider; every other setting is carried over.
+        #[cfg(feature = "by_ref_proposal")]
+        fn with_identity_provider<V>(self, identity_provider: V) -> CommitSender<'a, V, F, P, CSP>
+        where
+            V: IdentityProvider,
+        {
+            CommitSender {
+                identity_provider,
+                cipher_suite_provider: self.cipher_suite_provider,
+                tree: self.tree,
+                sender: self.sender,
+                cache: self.cache,
+                additional_proposals: self.additional_proposals,
+                user_rules: self.user_rules,
+                psk_storage: self.psk_storage,
+            }
+        }
+
+        /// Inserts proposal `p` under reference `r`, as sent by `proposer`, into the cache.
+        fn cache<S>(mut self, r: ProposalRef, p: Proposal, proposer: S) -> Self
+        where
+            S: Into<Sender>,
+        {
+            self.cache.insert(r, p, proposer.into());
+            self
+        }
+
+        /// Appends `proposals` to the proposals passed by value to the commit.
+        fn with_additional<I>(mut self, proposals: I) -> Self
+        where
+            I: IntoIterator<Item = Proposal>,
+        {
+            self.additional_proposals.extend(proposals);
+            self
+        }
+
+        /// Replaces the MLS rules; every other setting is carried over.
+        fn with_user_rules<G>(self, f: G) -> CommitSender<'a, C, G, P, CSP>
+        where
+            G: MlsRules,
+        {
+            CommitSender {
+                tree: self.tree,
+                sender: self.sender,
+                cache: self.cache,
+                additional_proposals: self.additional_proposals,
+                identity_provider: self.identity_provider,
+                user_rules: f,
+                psk_storage: self.psk_storage,
+                cipher_suite_provider: self.cipher_suite_provider,
+            }
+        }
+
+        /// Replaces the PSK storage; every other setting is carried over.
+        fn with_psk_storage<V>(self, v: V) -> CommitSender<'a, C, F, V, CSP>
+        where
+            V: PreSharedKeyStorage,
+        {
+            CommitSender {
+                tree: self.tree,
+                sender: self.sender,
+                cache: self.cache,
+                additional_proposals: self.additional_proposals,
+                identity_provider: self.identity_provider,
+                user_rules: self.user_rules,
+                psk_storage: v,
+                cipher_suite_provider: self.cipher_suite_provider,
+            }
+        }
+
+        /// Prepares a commit from `Sender::Member(sender)` using the test group context
+        /// at epoch 1, and returns the committed proposals (as `ProposalOrRef`) together
+        /// with the full provisional state.
+        ///
+        /// # Errors
+        /// Propagates any `MlsError` returned by `prepare_commit_default`.
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        async fn send(&self) -> Result<(Vec<ProposalOrRef>, ProvisionalState), MlsError> {
+            let state = self
+                .cache
+                .prepare_commit_default(
+                    Sender::Member(*self.sender),
+                    self.additional_proposals.clone(),
+                    &get_test_group_context(1, TEST_CIPHER_SUITE).await,
+                    &self.identity_provider,
+                    &self.cipher_suite_provider,
+                    self.tree,
+                    None,
+                    &self.psk_storage,
+                    &self.user_rules,
+                )
+                .await?;
+
+            // The state is returned as well, so clone the applied proposals before
+            // converting them.
+            let proposals = state.applied_proposals.clone().into_proposals_or_refs();
+
+            Ok((proposals, state))
+        }
+    }
+
+    /// Returns a test key package for "mallory" whose signature bytes have been
+    /// cleared.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn key_package_with_invalid_signature() -> KeyPackage {
+        let mut key_package =
+            test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "mallory").await;
+
+        key_package.signature.clear();
+
+        key_package
+    }
+
+    /// Returns a key package for "test" whose leaf node carries the HPKE public key
+    /// `key`. After the key is replaced, the leaf node and then the key package are
+    /// signed again.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn key_package_with_public_key(key: crypto::HpkePublicKey) -> KeyPackage {
+        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let (mut kp, signing_key) =
+            test_key_package_with_signer(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "test").await;
+
+        kp.leaf_node.public_key = key;
+
+        let leaf_context = LeafNodeSigningContext {
+            group_id: None,
+            leaf_index: None,
+        };
+
+        kp.leaf_node
+            .sign(&cs, &signing_key, &leaf_context)
+            .await
+            .unwrap();
+
+        kp.sign(&cs, &signing_key, &()).await.unwrap();
+
+        kp
+    }
+
+    /// Receiving a commit whose add proposal has a badly signed key package fails with
+    /// `InvalidSignature`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_add_with_invalid_key_package_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let bad_add = Proposal::Add(Box::new(AddProposal {
+            key_package: key_package_with_invalid_signature().await,
+        }));
+
+        let receiver = CommitReceiver::new(
+            &tree,
+            alice,
+            alice,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        );
+
+        let outcome = receiver.receive([bad_add]).await;
+
+        assert_matches!(outcome, Err(MlsError::InvalidSignature));
+    }
+
+    /// Passing an add with a badly signed key package by value makes the commit fail
+    /// with `InvalidSignature`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_add_with_invalid_key_package_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let bad_add = Proposal::Add(Box::new(AddProposal {
+            key_package: key_package_with_invalid_signature().await,
+        }));
+
+        let outcome = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+            .with_additional([bad_add])
+            .send()
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::InvalidSignature));
+    }
+
+    /// A cached add whose key package is badly signed is left out of the commit; with
+    /// `state_update` it is reported as an unused proposal instead.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_add_with_invalid_key_package_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let bad_add = Proposal::Add(Box::new(AddProposal {
+            key_package: key_package_with_invalid_signature().await,
+        }));
+
+        let info = make_proposal_info(&bad_add, alice).await;
+        let bad_add_ref = info.proposal_ref().unwrap().clone();
+
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(bad_add_ref, bad_add, alice)
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![info]);
+    }
+
+    /// Passing by value an add whose key package reuses alice's HPKE public key makes
+    /// the commit fail with `DuplicateLeafData`.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_add_with_hpke_key_of_another_member_fails() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let alice_key = tree.get_leaf_node(alice).unwrap().public_key.clone();
+
+        let duplicate_key_add = Proposal::Add(Box::new(AddProposal {
+            key_package: key_package_with_public_key(alice_key).await,
+        }));
+
+        let outcome = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+            .with_additional([duplicate_key_add])
+            .send()
+            .await;
+
+        assert_matches!(outcome, Err(MlsError::DuplicateLeafData(_)));
+    }
+
+    /// A cached add whose key package reuses alice's HPKE public key is left out of the
+    /// commit; with `state_update` it is reported as an unused proposal instead.
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_add_with_hpke_key_of_another_member_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let alice_key = tree.get_leaf_node(alice).unwrap().public_key.clone();
+
+        let duplicate_key_add = Proposal::Add(Box::new(AddProposal {
+            key_package: key_package_with_public_key(alice_key).await,
+        }));
+
+        let info = make_proposal_info(&duplicate_key_add, alice).await;
+        let add_ref = info.proposal_ref().unwrap().clone();
+
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(add_ref, duplicate_key_add, alice)
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![info]);
+    }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_update_with_invalid_leaf_node_fails() { // A received commit referencing Bob's Update built from a basic test leaf node is rejected.
+ let (alice, mut tree) = new_tree("alice").await;
+ let bob = add_member(&mut tree, "bob").await;
+
+ let proposal = Proposal::Update(UpdateProposal {
+ leaf_node: get_basic_test_node(TEST_CIPHER_SUITE, "alice").await, // presumably not an Update-sourced leaf node, per the expected error — TODO confirm
+ });
+
+ let proposal_ref = make_proposal_ref(&proposal, bob).await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ bob,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .cache(proposal_ref.clone(), proposal, bob) // Bob's Update is cached and then referenced by the commit
+ .receive([proposal_ref])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidLeafNodeSource));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_update_with_invalid_leaf_node_filters_it_out() { // Alice's commit drops Bob's cached Update that carries a basic test leaf node.
+ let (alice, mut tree) = new_tree("alice").await;
+ let bob = add_member(&mut tree, "bob").await;
+
+ let proposal = Proposal::Update(UpdateProposal {
+ leaf_node: get_basic_test_node(TEST_CIPHER_SUITE, "alice").await,
+ });
+
+ let proposal_info = make_proposal_info(&proposal, bob).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(proposal_info.proposal_ref().unwrap().clone(), proposal, bob)
+ .send()
+ .await
+ .unwrap(); // expected to succeed: the invalid Update is filtered
+
+ assert_eq!(processed_proposals.0, Vec::new()); // nothing is left to commit
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); // reported as unused even though Bob proposed it
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_remove_with_invalid_index_fails() { // A received Remove for leaf 10, which is not in Alice's newly created tree, is rejected.
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(10),
+ })])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidNodeIndex(20))); // leaf 10 is reported as node 20 (presumably node = 2 * leaf) — TODO confirm
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_remove_with_invalid_index_fails() { // A Remove for a non-existent leaf passed via with_additional is not filtered: the send fails.
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(10),
+ })])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidNodeIndex(20))); // same node index as the receiving-side test for LeafIndex(10)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_remove_with_invalid_index_filters_it_out() { // A cached Remove for a non-existent leaf is dropped instead of failing the commit.
+ let (alice, tree) = new_tree("alice").await;
+
+ let proposal = Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(10),
+ });
+
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap(); // expected to succeed: the invalid Remove is filtered
+
+ assert_eq!(processed_proposals.0, Vec::new()); // nothing is left to commit
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); // the dropped Remove is reported as unused
+ }
+
+ #[cfg(feature = "psk")]
+ fn make_external_psk(id: &[u8], nonce: PskNonce) -> PreSharedKeyProposal { // Builds a PSK proposal for external PSK `id` using the caller-supplied `nonce` as-is.
+ PreSharedKeyProposal {
+ psk: PreSharedKeyID {
+ key_id: JustPreSharedKeyID::External(ExternalPskId::new(id.to_vec())),
+ psk_nonce: nonce,
+ },
+ }
+ }
+
+ #[cfg(feature = "psk")]
+ fn new_external_psk(id: &[u8]) -> PreSharedKeyProposal { // Like make_external_psk, but with a nonce from PskNonce::random for TEST_CIPHER_SUITE.
+ make_external_psk(
+ id,
+ PskNonce::random(&test_cipher_suite_provider(TEST_CIPHER_SUITE)).unwrap(), // panics if nonce generation fails (test-only)
+ )
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_psk_with_invalid_nonce_fails() { // A received PSK proposal whose nonce has the wrong length is rejected.
+ let invalid_nonce = PskNonce(vec![0, 1, 2]); // 3 bytes; presumably shorter than the required length for TEST_CIPHER_SUITE — TODO confirm
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([Proposal::Psk(make_external_psk(
+ b"foo",
+ invalid_nonce, // moved: the nonce is used only here, so no clone is needed
+ ))])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidPskNonceLength));
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_psk_with_invalid_nonce_fails() { // A wrong-length-nonce PSK passed via with_additional is not filtered: the send fails.
+ let invalid_nonce = PskNonce(vec![0, 1, 2]); // 3 bytes; presumably shorter than the required length for TEST_CIPHER_SUITE — TODO confirm
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::Psk(make_external_psk(
+ b"foo",
+ invalid_nonce, // moved: the nonce is used only here, so no clone is needed
+ ))])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidPskNonceLength));
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_psk_with_invalid_nonce_filters_it_out() { // A cached PSK proposal with a wrong-length nonce is dropped instead of failing the commit.
+ let invalid_nonce = PskNonce(vec![0, 1, 2]);
+ let (alice, tree) = new_tree("alice").await;
+ let proposal = Proposal::Psk(make_external_psk(b"foo", invalid_nonce));
+
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap(); // expected to succeed: the invalid PSK proposal is filtered
+
+ assert_eq!(processed_proposals.0, Vec::new()); // nothing is left to commit
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); // the dropped PSK proposal is reported as unused
+ }
+
+ #[cfg(feature = "psk")]
+ fn make_resumption_psk(usage: ResumptionPSKUsage) -> PreSharedKeyProposal { // Builds a resumption PSK proposal with `usage` for TEST_GROUP at epoch 1 and a random nonce.
+ PreSharedKeyProposal {
+ psk: PreSharedKeyID {
+ key_id: JustPreSharedKeyID::Resumption(ResumptionPsk {
+ usage,
+ psk_group_id: PskGroupId(TEST_GROUP.to_vec()),
+ psk_epoch: 1,
+ }),
+ psk_nonce: PskNonce::random(&test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .unwrap(),
+ },
+ }
+ }
+
+ #[cfg(feature = "psk")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] // compiled as a sync fn when mls_build_async is off
+ async fn receiving_resumption_psk_with_bad_usage_fails(usage: ResumptionPSKUsage) { // Shared body: a received resumption PSK with `usage` must be rejected (driven with Reinit and Branch).
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([Proposal::Psk(make_resumption_psk(usage))])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidTypeOrUsageInPreSharedKeyProposal));
+ }
+
+ #[cfg(feature = "psk")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] // compiled as a sync fn when mls_build_async is off
+ async fn sending_additional_resumption_psk_with_bad_usage_fails(usage: ResumptionPSKUsage) { // Shared body: a resumption PSK with `usage` passed via with_additional must fail the send.
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::Psk(make_resumption_psk(usage))])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidTypeOrUsageInPreSharedKeyProposal));
+ }
+
+ #[cfg(feature = "psk")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] // compiled as a sync fn when mls_build_async is off
+ async fn sending_resumption_psk_with_bad_usage_filters_it_out(usage: ResumptionPSKUsage) { // Shared body: a cached resumption PSK with `usage` is dropped from Alice's commit.
+ let (alice, tree) = new_tree("alice").await;
+ let proposal = Proposal::Psk(make_resumption_psk(usage));
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap(); // expected to succeed: the disallowed PSK proposal is filtered
+
+ assert_eq!(processed_proposals.0, Vec::new()); // nothing is left to commit
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); // the dropped PSK proposal is reported as unused
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_resumption_psk_with_reinit_usage_fails() { // A received resumption PSK proposal with Reinit usage is rejected.
+ receiving_resumption_psk_with_bad_usage_fails(ResumptionPSKUsage::Reinit).await;
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_resumption_psk_with_reinit_usage_fails() { // A Reinit-usage resumption PSK passed via with_additional fails the send.
+ sending_additional_resumption_psk_with_bad_usage_fails(ResumptionPSKUsage::Reinit).await;
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_resumption_psk_with_reinit_usage_filters_it_out() { // A cached Reinit-usage resumption PSK is dropped from the commit.
+ sending_resumption_psk_with_bad_usage_filters_it_out(ResumptionPSKUsage::Reinit).await;
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_resumption_psk_with_branch_usage_fails() { // A received resumption PSK proposal with Branch usage is rejected.
+ receiving_resumption_psk_with_bad_usage_fails(ResumptionPSKUsage::Branch).await;
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_resumption_psk_with_branch_usage_fails() { // A Branch-usage resumption PSK passed via with_additional fails the send.
+ sending_additional_resumption_psk_with_bad_usage_fails(ResumptionPSKUsage::Branch).await;
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_resumption_psk_with_branch_usage_filters_it_out() { // A cached Branch-usage resumption PSK is dropped from the commit.
+ sending_resumption_psk_with_bad_usage_filters_it_out(ResumptionPSKUsage::Branch).await;
+ }
+
+ fn make_reinit(version: ProtocolVersion) -> ReInitProposal { // Builds a ReInit proposal for TEST_GROUP / TEST_CIPHER_SUITE with no extensions and the given `version`.
+ ReInitProposal {
+ group_id: TEST_GROUP.to_vec(),
+ version,
+ cipher_suite: TEST_CIPHER_SUITE,
+ extensions: ExtensionList::new(),
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_reinit_downgrading_version_fails() { // A received ReInit to protocol version 0 is rejected as a downgrade.
+ let smaller_protocol_version = ProtocolVersion::from(0); // presumably below the group's current version — TODO confirm
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([Proposal::ReInit(make_reinit(smaller_protocol_version))])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidProtocolVersionInReInit));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_reinit_downgrading_version_fails() { // A version-0 ReInit passed via with_additional is not filtered: the send fails.
+ let smaller_protocol_version = ProtocolVersion::from(0);
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::ReInit(make_reinit(smaller_protocol_version))])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidProtocolVersionInReInit));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_reinit_downgrading_version_filters_it_out() { // A cached version-0 ReInit is dropped instead of failing the commit.
+ let smaller_protocol_version = ProtocolVersion::from(0);
+ let (alice, tree) = new_tree("alice").await;
+ let proposal = Proposal::ReInit(make_reinit(smaller_protocol_version));
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap(); // expected to succeed: the downgrading ReInit is filtered
+
+ assert_eq!(processed_proposals.0, Vec::new()); // nothing is left to commit
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); // the dropped ReInit is reported as unused
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_update_for_committer_fails() { // A received commit by Alice that references Alice's own Update proposal is rejected.
+ let (alice, tree) = new_tree("alice").await;
+ let update = Proposal::Update(make_update_proposal("alice").await);
+ let update_ref = make_proposal_ref(&update, alice).await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .cache(update_ref.clone(), update, alice) // the Update is proposed by the committer itself
+ .receive([update_ref])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidCommitSelfUpdate));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_update_for_committer_fails() { // The committer cannot pass its own Update via with_additional.
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::Update(make_update_proposal("alice").await)])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender)); // note: differs from the receiving path's InvalidCommitSelfUpdate
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_update_for_committer_filters_it_out() { // An Update cached as proposed by Alice is dropped when Alice herself commits.
+ let (alice, tree) = new_tree("alice").await;
+ let proposal = Proposal::Update(make_update_proposal("alice").await);
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap(); // expected to succeed: the self-Update is filtered
+
+ assert_eq!(processed_proposals.0, Vec::new()); // nothing is left to commit
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); // the dropped Update is reported as unused
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_remove_for_committer_fails() { // A received commit in which Alice (the committer) removes herself is rejected.
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([Proposal::Remove(RemoveProposal { to_remove: alice })])
+ .await;
+
+ assert_matches!(res, Err(MlsError::CommitterSelfRemoval));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_remove_for_committer_fails() { // A self-Remove passed via with_additional is not filtered: the send fails.
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::Remove(RemoveProposal { to_remove: alice })])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::CommitterSelfRemoval));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_remove_for_committer_filters_it_out() { // A cached Remove of the committer is dropped instead of failing the commit.
+ let (alice, tree) = new_tree("alice").await;
+ let proposal = Proposal::Remove(RemoveProposal { to_remove: alice });
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap(); // expected to succeed: the self-Remove is filtered
+
+ assert_eq!(processed_proposals.0, Vec::new()); // nothing is left to commit
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); // the dropped Remove is reported as unused
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_update_and_remove_for_same_leaf_fails() { // A received commit referencing both Bob's Update and a Remove of Bob is rejected.
+ let (alice, mut tree) = new_tree("alice").await;
+ let bob = add_member(&mut tree, "bob").await;
+
+ let update = Proposal::Update(make_update_proposal("bob").await);
+ let update_ref = make_proposal_ref(&update, bob).await;
+
+ let remove = Proposal::Remove(RemoveProposal { to_remove: bob });
+ let remove_ref = make_proposal_ref(&remove, bob).await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .cache(update_ref.clone(), update, bob)
+ .cache(remove_ref.clone(), remove, bob)
+ .receive([update_ref, remove_ref])
+ .await;
+
+ assert_matches!(res, Err(MlsError::UpdatingNonExistingMember)); // presumably the Remove is applied first, leaving the Update without a member — TODO confirm
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_update_and_remove_for_same_leaf_filters_update_out() { // With a cached Update for Bob's leaf and a cached Remove of Bob, only the Remove is committed.
+ let (alice, mut tree) = new_tree("alice").await;
+ let bob = add_member(&mut tree, "bob").await;
+
+ let update = Proposal::Update(make_update_proposal("bob").await);
+ let update_info = make_proposal_info(&update, alice).await; // NOTE(review): attributed to Alice (the committer), so it may be dropped as a self-Update (cf. sending_update_for_committer_filters_it_out) rather than because of the Remove — confirm intent
+
+ let remove = Proposal::Remove(RemoveProposal { to_remove: bob });
+ let remove_ref = make_proposal_ref(&remove, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ update_info.proposal_ref().unwrap().clone(),
+ update.clone(),
+ alice,
+ )
+ .cache(remove_ref.clone(), remove, alice)
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, vec![remove_ref.into()]); // only the Remove survives, by reference
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![update_info]); // the dropped Update is reported as unused
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] // compiled as a sync fn when mls_build_async is off
+ async fn make_add_proposal() -> Box<AddProposal> { // Builds an Add for a test key package of client "frank"; duplicate-add tests call it twice to get two Adds for one client.
+ Box::new(AddProposal {
+ key_package: test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "frank").await,
+ })
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_add_proposals_for_same_client_fails() { // A received commit with two Adds for the same client ("frank") is rejected.
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([
+ Proposal::Add(make_add_proposal().await),
+ Proposal::Add(make_add_proposal().await),
+ ])
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(1))); // the payload 1 presumably identifies the conflicting leaf — TODO confirm
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_add_proposals_for_same_client_fails() { // Two same-client Adds passed via with_additional are not deduplicated: the send fails.
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([
+ Proposal::Add(make_add_proposal().await),
+ Proposal::Add(make_add_proposal().await),
+ ])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(1))); // same error as the receiving-side test
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_add_proposals_for_same_client_keeps_only_one() { // With two cached Adds for one client, exactly one is committed and the other is reported unused.
+ let (alice, tree) = new_tree("alice").await;
+
+ let add_one = Proposal::Add(make_add_proposal().await);
+ let add_two = Proposal::Add(make_add_proposal().await);
+ let add_ref_one = make_proposal_ref(&add_one, alice).await;
+ let add_ref_two = make_proposal_ref(&add_two, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(add_ref_one.clone(), add_one.clone(), alice)
+ .cache(add_ref_two.clone(), add_two.clone(), alice)
+ .send()
+ .await
+ .unwrap();
+
+ let committed_add_ref = match &*processed_proposals.0 { // the committed list must hold exactly one by-reference proposal
+ [ProposalOrRef::Reference(add_ref)] => add_ref,
+ _ => panic!("committed proposals list does not contain exactly one reference"),
+ };
+
+ let add_refs = [add_ref_one, add_ref_two];
+ assert!(add_refs.contains(committed_add_ref)); // either Add may win; the test does not fix which
+
+ #[cfg(feature = "state_update")]
+ assert_matches!(
+ &*processed_proposals.1.unused_proposals,
+ [rejected_add_info] if committed_add_ref != rejected_add_info.proposal_ref().unwrap() && add_refs.contains(rejected_add_info.proposal_ref().unwrap())
+ ); // the single unused proposal is the other Add
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_update_for_different_identity_fails() { // Bob's Update carrying a leaf for a different identity ("carol") is rejected on receipt.
+ let (alice, mut tree) = new_tree("alice").await;
+ let bob = add_member(&mut tree, "bob").await;
+
+ let update = Proposal::Update(make_update_proposal_custom("carol", 1).await);
+ let update_ref = make_proposal_ref(&update, bob).await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .cache(update_ref.clone(), update, bob)
+ .receive([update_ref])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidSuccessor)); // the new identity is not accepted as a successor of Bob's
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_update_for_different_identity_filters_it_out() { // Alice's commit drops Bob's cached Update that switches his leaf to identity "carol".
+ let (alice, mut tree) = new_tree("alice").await;
+ let bob = add_member(&mut tree, "bob").await;
+
+ let update = Proposal::Update(make_update_proposal("carol").await);
+ let update_info = make_proposal_info(&update, bob).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(update_info.proposal_ref().unwrap().clone(), update, bob)
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ // Bob proposed the update, not Alice, yet it is still reported in unused_proposals
+ // when Alice's commit drops it: filtered proposals are reported whoever proposed them.
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![update_info]);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_add_for_same_client_as_existing_member_fails() { // Receiving an Add for a client that an earlier commit already added is rejected.
+ let (alice, public_tree) = new_tree("alice").await;
+ let add = Proposal::Add(make_add_proposal().await);
+
+ let ProvisionalState { public_tree, .. } = CommitReceiver::new( // apply the Add once; shadows `public_tree` with the resulting tree
+ &public_tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([add.clone()])
+ .await
+ .unwrap();
+
+ let res = CommitReceiver::new( // the same Add again, against the tree that already contains that client
+ &public_tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([add])
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(1))); // presumably leaf 1, where the first Add placed the client — TODO confirm
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_add_for_same_client_as_existing_member_fails() { // Passing an Add via with_additional for a client that is already a member fails the send.
+ let (alice, public_tree) = new_tree("alice").await;
+ let add = Proposal::Add(make_add_proposal().await);
+
+ let ProvisionalState { public_tree, .. } = CommitReceiver::new( // apply the Add once; shadows `public_tree` with the resulting tree
+ &public_tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([add.clone()])
+ .await
+ .unwrap();
+
+ let res = CommitSender::new(
+ &public_tree,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .with_additional([add])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(1)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_add_for_same_client_as_existing_member_filters_it_out() { // A cached Add for a client that is already a member is dropped instead of failing the commit.
+ let (alice, public_tree) = new_tree("alice").await;
+ let add = Proposal::Add(make_add_proposal().await);
+
+ let ProvisionalState { public_tree, .. } = CommitReceiver::new( // apply the Add once; shadows `public_tree` with the resulting tree
+ &public_tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([add.clone()])
+ .await
+ .unwrap();
+
+ let proposal_info = make_proposal_info(&add, alice).await;
+
+ let processed_proposals = CommitSender::new(
+ &public_tree,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ add.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap(); // expected to succeed: the duplicate Add is filtered
+
+ assert_eq!(processed_proposals.0, Vec::new()); // nothing is left to commit
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); // the dropped Add is reported as unused
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_psk_proposals_with_same_psk_id_fails() { // A received commit containing the same PSK proposal twice is rejected.
+ let (alice, tree) = new_tree("alice").await;
+ let psk_proposal = Proposal::Psk(new_external_psk(b"foo"));
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([psk_proposal.clone(), psk_proposal]) // identical copies, so identical PSK ids
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicatePskIds));
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_psk_proposals_with_same_psk_id_fails() { // Duplicate PSK proposals passed via with_additional are not deduplicated: the send fails.
+ let (alice, tree) = new_tree("alice").await;
+ let psk_proposal = Proposal::Psk(new_external_psk(b"foo"));
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([psk_proposal.clone(), psk_proposal])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicatePskIds));
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_psk_proposals_with_same_psk_id_keeps_only_one() { // One PSK proposal cached from both Alice and Bob (two distinct refs) is committed only once.
+ let (alice, mut tree) = new_tree("alice").await;
+ let bob = add_member(&mut tree, "bob").await;
+
+ let proposal = Proposal::Psk(new_external_psk(b"foo"));
+
+ let proposal_info = [ // same proposal, different senders
+ make_proposal_info(&proposal, alice).await,
+ make_proposal_info(&proposal, bob).await,
+ ];
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info[0].proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .cache(
+ proposal_info[1].proposal_ref().unwrap().clone(),
+ proposal,
+ bob,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ let committed_info = match processed_proposals // applied_proposals is not gated on state_update, so the winner is read from it
+ .1
+ .applied_proposals
+ .clone()
+ .into_proposals()
+ .collect_vec()
+ .as_slice()
+ {
+ [r] => r.clone(),
+ _ => panic!("Expected single proposal reference in {processed_proposals:?}"),
+ };
+
+ assert!(proposal_info.contains(&committed_info)); // either copy may win
+
+ #[cfg(feature = "state_update")]
+ match &*processed_proposals.1.unused_proposals { // the other copy must be the single unused proposal
+ [r] => {
+ assert_ne!(*r, committed_info);
+ assert!(proposal_info.contains(r));
+ }
+ _ => panic!(
+ "Expected one proposal reference in {:?}",
+ processed_proposals.1.unused_proposals
+ ),
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_multiple_group_context_extensions_fails() { // A received commit with two GroupContextExtensions proposals is rejected.
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([
+ Proposal::GroupContextExtensions(ExtensionList::new()),
+ Proposal::GroupContextExtensions(ExtensionList::new()),
+ ])
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::MoreThanOneGroupContextExtensionsProposal)
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_multiple_additional_group_context_extensions_fails() { // Two GroupContextExtensions proposals passed via with_additional fail the send.
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([
+ Proposal::GroupContextExtensions(ExtensionList::new()),
+ Proposal::GroupContextExtensions(ExtensionList::new()),
+ ])
+ .send()
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::MoreThanOneGroupContextExtensionsProposal)
+ );
+ }
+
+ fn make_extension_list(foo: u8) -> ExtensionList { // Builds an extension list holding a single TestExtension with the given `foo` value.
+ vec![TestExtension { foo }.into_extension().unwrap()].into()
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_multiple_group_context_extensions_keeps_only_one() { // Of two cached GroupContextExtensions proposals from Alice, exactly one is committed and the other reported unused.
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let (alice, tree) = { // one-member tree built by hand (not new_tree) so Alice's leaf capabilities can be customised
+ let (signing_identity, signature_key) =
+ get_test_signing_identity(TEST_CIPHER_SUITE, b"alice").await;
+
+ let properties = ConfigProperties {
+ capabilities: Capabilities {
+ extensions: vec![42.into()], // presumably TestExtension's extension type, so the GCE proposals are supported — TODO confirm
+ ..Capabilities::default()
+ },
+ extensions: Default::default(),
+ };
+
+ let (leaf, secret) = LeafNode::generate(
+ &cipher_suite_provider,
+ properties,
+ signing_identity,
+ &signature_key,
+ Lifetime::years(1).unwrap(),
+ )
+ .await
+ .unwrap();
+
+ let (pub_tree, priv_tree) =
+ TreeKemPublic::derive(leaf, secret, &BasicIdentityProvider, &Default::default())
+ .await
+ .unwrap();
+
+ (priv_tree.self_index, pub_tree)
+ };
+
+ let proposals = [
+ Proposal::GroupContextExtensions(make_extension_list(0)),
+ Proposal::GroupContextExtensions(make_extension_list(1)),
+ ];
+
+ let gce_info = [
+ make_proposal_info(&proposals[0], alice).await,
+ make_proposal_info(&proposals[1], alice).await,
+ ];
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ gce_info[0].proposal_ref().unwrap().clone(),
+ proposals[0].clone(),
+ alice,
+ )
+ .cache(
+ gce_info[1].proposal_ref().unwrap().clone(),
+ proposals[1].clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ let committed_gce_info = match processed_proposals // applied_proposals is not gated on state_update, so the winner is read from it
+ .1
+ .applied_proposals
+ .clone()
+ .into_proposals()
+ .collect_vec()
+ .as_slice()
+ {
+ [gce_info] => gce_info.clone(),
+ _ => panic!("committed proposals list does not contain exactly one reference"),
+ };
+
+ assert!(gce_info.contains(&committed_gce_info)); // either proposal may win
+
+ #[cfg(feature = "state_update")]
+ assert_matches!(
+ &*processed_proposals.1.unused_proposals,
+ [rejected_gce_info] if committed_gce_info != *rejected_gce_info && gce_info.contains(rejected_gce_info)
+ ); // the single unused proposal is the other GCE
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] // compiled as a sync fn when mls_build_async is off
+ async fn make_external_senders_extension() -> ExtensionList { // Builds an extension list whose ExternalSendersExt lists Alice's test signing identity.
+ let identity = get_test_signing_identity(TEST_CIPHER_SUITE, b"alice")
+ .await
+ .0; // keep only the signing identity; the signature key is unused
+
+ vec![ExternalSendersExt::new(vec![identity])
+ .into_extension()
+ .unwrap()]
+ .into()
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_invalid_external_senders_extension_fails() { // A received GCE setting external senders fails when the identity provider rejects the sender identity.
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .with_identity_provider(FailureIdentityProvider::new()) // presumably fails every identity check — TODO confirm
+ .receive([Proposal::GroupContextExtensions(
+ make_external_senders_extension().await,
+ )])
+ .await;
+
+ assert_matches!(res, Err(MlsError::IdentityProviderError(_)));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_invalid_external_senders_extension_fails() { // A GCE with a rejected external sender passed via with_additional fails the send.
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_identity_provider(FailureIdentityProvider::new())
+ .with_additional([Proposal::GroupContextExtensions(
+ make_external_senders_extension().await,
+ )])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::IdentityProviderError(_)));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_invalid_external_senders_extension_filters_it_out() { // A cached GCE whose external sender is rejected by the identity provider is dropped from the commit.
+ let (alice, tree) = new_tree("alice").await;
+
+ let proposal = Proposal::GroupContextExtensions(make_external_senders_extension().await);
+
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_identity_provider(FailureIdentityProvider::new())
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap(); // expected to succeed: the invalid GCE is filtered
+
+ assert_eq!(processed_proposals.0, Vec::new()); // nothing is left to commit
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); // the dropped GCE is reported as unused
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_reinit_with_other_proposals_fails() { // A received commit combining a ReInit with any other proposal (here an Add) is rejected.
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([
+ Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION)),
+ Proposal::Add(make_add_proposal().await),
+ ])
+ .await;
+
+ assert_matches!(res, Err(MlsError::OtherProposalWithReInit));
+ }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_reinit_with_other_proposals_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        // Same rule on the sending side when the proposals are passed by value.
+        let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+            .with_additional([
+                Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION)),
+                Proposal::Add(make_add_proposal().await),
+            ])
+            .send()
+            .await;
+
+        assert_matches!(res, Err(MlsError::OtherProposalWithReInit));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_reinit_with_other_proposals_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+        let reinit = Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION));
+        let reinit_info = make_proposal_info(&reinit, alice).await;
+        let add = Proposal::Add(make_add_proposal().await);
+        let add_ref = make_proposal_ref(&add, alice).await;
+        // Both are cached by reference; the sender keeps the Add and drops the conflicting ReInit.
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(
+                    reinit_info.proposal_ref().unwrap().clone(),
+                    reinit.clone(),
+                    alice,
+                )
+                .cache(add_ref.clone(), add, alice)
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, vec![add_ref.into()]);
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![reinit_info]);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_multiple_reinits_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        // Two ReInits in one received commit are rejected with the same error as ReInit + other.
+        let res = CommitReceiver::new(
+            &tree,
+            alice,
+            alice,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        )
+        .receive([
+            Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION)),
+            Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION)),
+        ])
+        .await;
+
+        assert_matches!(res, Err(MlsError::OtherProposalWithReInit));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_multiple_reinits_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        // Same rule on the sending side when both ReInits are passed by value.
+        let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+            .with_additional([
+                Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION)),
+                Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION)),
+            ])
+            .send()
+            .await;
+
+        assert_matches!(res, Err(MlsError::OtherProposalWithReInit));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_multiple_reinits_keeps_only_one() {
+        let (alice, tree) = new_tree("alice").await;
+        let reinit = Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION));
+        let reinit_ref = make_proposal_ref(&reinit, alice).await;
+        let other_reinit = Proposal::ReInit(ReInitProposal {
+            group_id: b"other_group".to_vec(),
+            ..make_reinit(TEST_PROTOCOL_VERSION)
+        });
+        let other_reinit_ref = make_proposal_ref(&other_reinit, alice).await;
+        // Two distinct cached ReInits: which one survives is unspecified, but exactly one must.
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(reinit_ref.clone(), reinit.clone(), alice)
+                .cache(other_reinit_ref.clone(), other_reinit.clone(), alice)
+                .send()
+                .await
+                .unwrap();
+
+        let processed_ref = match &*processed_proposals.0 {
+            [ProposalOrRef::Reference(r)] => r,
+            p => panic!("Expected single proposal reference but found {p:?}"),
+        };
+        // The kept reference must be one of the two cached ReInits.
+        assert!(*processed_ref == reinit_ref || *processed_ref == other_reinit_ref);
+
+        #[cfg(feature = "state_update")]
+        {
+            let (rejected_ref, unused_proposal) = match &*processed_proposals.1.unused_proposals {
+                [r] => (r.proposal_ref().unwrap().clone(), r.proposal.clone()),
+                p => panic!("Expected single proposal but found {p:?}"),
+            };
+            // The other ReInit is the one reported as unused.
+            assert_ne!(rejected_ref, *processed_ref);
+            assert!(rejected_ref == reinit_ref || rejected_ref == other_reinit_ref);
+            assert!(unused_proposal == reinit || unused_proposal == other_reinit);
+        }
+    }
+
+    fn make_external_init() -> ExternalInit { // Placeholder ExternalInit: kem_output is the byte 33 repeated kdf_extract_size() times.
+        ExternalInit {
+            kem_output: vec![33; test_cipher_suite_provider(TEST_CIPHER_SUITE).kdf_extract_size()],
+        }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_external_init_from_member_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        // A commit from a member (alice) must not contain an ExternalInit proposal.
+        let res = CommitReceiver::new(
+            &tree,
+            alice,
+            alice,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        )
+        .receive([Proposal::ExternalInit(make_external_init())])
+        .await;
+
+        assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_external_init_from_member_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        // Same rule on the sending side for a by-value ExternalInit.
+        let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+            .with_additional([Proposal::ExternalInit(make_external_init())])
+            .send()
+            .await;
+
+        assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_external_init_from_member_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+        let external_init = Proposal::ExternalInit(make_external_init());
+        let external_init_info = make_proposal_info(&external_init, alice).await;
+        // Cached by reference, the ExternalInit is filtered out instead of failing the commit.
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(
+                    external_init_info.proposal_ref().unwrap().clone(),
+                    external_init.clone(),
+                    alice,
+                )
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(
+            processed_proposals.1.unused_proposals,
+            vec![external_init_info]
+        );
+    }
+
+    fn required_capabilities_proposal(extension: u16) -> Proposal { // GroupContextExtensions proposal whose only extension is RequiredCapabilities listing `extension`.
+        let required_capabilities = RequiredCapabilitiesExt {
+            extensions: vec![extension.into()],
+            ..Default::default()
+        };
+
+        let ext = vec![required_capabilities.into_extension().unwrap()];
+
+        Proposal::GroupContextExtensions(ext.into())
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_required_capabilities_not_supported_by_member_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        // Requiring extension type 33, which the member's leaf does not advertise, must be rejected.
+        let res = CommitReceiver::new(
+            &tree,
+            alice,
+            alice,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        )
+        .receive([required_capabilities_proposal(33)])
+        .await;
+
+        assert_matches!(
+            res,
+            Err(MlsError::RequiredExtensionNotFound(v)) if v == 33.into()
+        );
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_required_capabilities_not_supported_by_member_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        // By-value variant. NOTE(review): sibling tests name this case `sending_additional_*_fails`.
+        let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+            .with_additional([required_capabilities_proposal(33)])
+            .send()
+            .await;
+
+        assert_matches!(
+            res,
+            Err(MlsError::RequiredExtensionNotFound(v)) if v == 33.into()
+        );
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_required_capabilities_not_supported_by_member_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+        // By-reference variant. NOTE(review): despite the `sending_additional_` prefix, nothing is passed via with_additional.
+        let proposal = required_capabilities_proposal(33);
+        let proposal_info = make_proposal_info(&proposal, alice).await;
+
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(
+                    proposal_info.proposal_ref().unwrap().clone(),
+                    proposal.clone(),
+                    alice,
+                )
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn committing_update_from_pk1_to_pk2_and_update_from_pk2_to_pk3_works() {
+        let (alice_leaf, alice_secret, alice_signer) =
+            get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "alice").await;
+
+        let (mut tree, priv_tree) = TreeKemPublic::derive(
+            alice_leaf.clone(),
+            alice_secret,
+            &BasicIdentityProvider,
+            &Default::default(),
+        )
+        .await
+        .unwrap();
+
+        let alice = priv_tree.self_index;
+
+        let bob = add_member(&mut tree, "bob").await;
+        let carol = add_member(&mut tree, "carol").await;
+
+        let bob_current_leaf = tree.get_leaf_node(bob).unwrap();
+        // pk1 -> pk2: alice's new leaf reuses bob's current public key (pk2).
+        let mut alice_new_leaf = LeafNode {
+            public_key: bob_current_leaf.public_key.clone(),
+            leaf_node_source: LeafNodeSource::Update,
+            ..alice_leaf
+        };
+
+        alice_new_leaf
+            .sign(
+                &test_cipher_suite_provider(TEST_CIPHER_SUITE),
+                &alice_signer,
+                &(TEST_GROUP, 0).into(),
+            )
+            .await
+            .unwrap();
+        // pk2 -> pk3: bob moves to a new leaf (presumably with a fresh key) in the same commit.
+        let bob_new_leaf = update_leaf_node("bob", 1).await;
+
+        let pk1_to_pk2 = Proposal::Update(UpdateProposal {
+            leaf_node: alice_new_leaf.clone(),
+        });
+
+        let pk1_to_pk2_ref = make_proposal_ref(&pk1_to_pk2, alice).await;
+
+        let pk2_to_pk3 = Proposal::Update(UpdateProposal {
+            leaf_node: bob_new_leaf.clone(),
+        });
+
+        let pk2_to_pk3_ref = make_proposal_ref(&pk2_to_pk3, bob).await;
+
+        let effects = CommitReceiver::new(
+            &tree,
+            carol,
+            carol,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        )
+        .cache(pk1_to_pk2_ref.clone(), pk1_to_pk2, alice)
+        .cache(pk2_to_pk3_ref.clone(), pk2_to_pk3, bob)
+        .receive([pk1_to_pk2_ref, pk2_to_pk3_ref])
+        .await
+        .unwrap();
+        // Both updates apply even though alice takes over a key bob gives up in the same commit.
+        assert_eq!(effects.applied_proposals.update_senders, vec![alice, bob]);
+
+        assert_eq!(
+            effects
+                .applied_proposals
+                .updates
+                .into_iter()
+                .map(|p| p.proposal.leaf_node)
+                .collect_vec(),
+            vec![alice_new_leaf, bob_new_leaf]
+        );
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn committing_update_from_pk1_to_pk2_and_removal_of_pk2_works() {
+        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+        let (alice_leaf, alice_secret, alice_signer) =
+            get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "alice").await;
+
+        let (mut tree, priv_tree) = TreeKemPublic::derive(
+            alice_leaf.clone(),
+            alice_secret,
+            &BasicIdentityProvider,
+            &Default::default(),
+        )
+        .await
+        .unwrap();
+
+        let alice = priv_tree.self_index;
+
+        let bob = add_member(&mut tree, "bob").await;
+        let carol = add_member(&mut tree, "carol").await;
+
+        let bob_current_leaf = tree.get_leaf_node(bob).unwrap();
+        // pk1 -> pk2: alice's new leaf reuses bob's current public key (pk2).
+        let mut alice_new_leaf = LeafNode {
+            public_key: bob_current_leaf.public_key.clone(),
+            leaf_node_source: LeafNodeSource::Update,
+            ..alice_leaf
+        };
+
+        alice_new_leaf
+            .sign(
+                &cipher_suite_provider,
+                &alice_signer,
+                &(TEST_GROUP, 0).into(),
+            )
+            .await
+            .unwrap();
+
+        let pk1_to_pk2 = Proposal::Update(UpdateProposal {
+            leaf_node: alice_new_leaf.clone(),
+        });
+
+        let pk1_to_pk2_ref = make_proposal_ref(&pk1_to_pk2, alice).await;
+        // bob, the current holder of pk2, is removed in the same commit.
+        let remove_pk2 = Proposal::Remove(RemoveProposal { to_remove: bob });
+
+        let remove_pk2_ref = make_proposal_ref(&remove_pk2, bob).await;
+
+        let effects = CommitReceiver::new(
+            &tree,
+            carol,
+            carol,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        )
+        .cache(pk1_to_pk2_ref.clone(), pk1_to_pk2, alice)
+        .cache(remove_pk2_ref.clone(), remove_pk2, bob)
+        .receive([pk1_to_pk2_ref, remove_pk2_ref])
+        .await
+        .unwrap();
+        // Both apply: alice's update is kept and bob is removed.
+        assert_eq!(effects.applied_proposals.update_senders, vec![alice]);
+
+        assert_eq!(
+            effects
+                .applied_proposals
+                .updates
+                .into_iter()
+                .map(|p| p.proposal.leaf_node)
+                .collect_vec(),
+            vec![alice_new_leaf]
+        );
+
+        assert_eq!(
+            effects
+                .applied_proposals
+                .removals
+                .into_iter()
+                .map(|p| p.proposal.to_remove)
+                .collect_vec(),
+            vec![bob]
+        );
+    }
+
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn unsupported_credential_key_package(name: &str) -> KeyPackage { // KeyPackage with a custom credential whose capabilities list only credential type 42.
+        let (mut signing_identity, secret_key) =
+            get_test_signing_identity(TEST_CIPHER_SUITE, name.as_bytes()).await;
+        // Replace the test credential with a custom one (CUSTOM_CREDENTIAL_TYPE, 32 random bytes).
+        signing_identity.credential = Credential::Custom(CustomCredential::new(
+            CredentialType::new(BasicWithCustomProvider::CUSTOM_CREDENTIAL_TYPE),
+            random_bytes(32),
+        ));
+
+        let generator = KeyPackageGenerator {
+            protocol_version: TEST_PROTOCOL_VERSION,
+            cipher_suite_provider: &test_cipher_suite_provider(TEST_CIPHER_SUITE),
+            signing_identity: &signing_identity,
+            signing_key: &secret_key,
+            identity_provider: &BasicWithCustomProvider::new(BasicIdentityProvider::new()),
+        };
+        // Only credential type 42 is advertised. NOTE(review): presumably not the type existing leaves use — confirm.
+        generator
+            .generate(
+                Lifetime::years(1).unwrap(),
+                Capabilities {
+                    credentials: vec![42.into()],
+                    ..Default::default()
+                },
+                Default::default(),
+                Default::default(),
+            )
+            .await
+            .unwrap()
+            .key_package
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_add_with_leaf_not_supporting_credential_type_of_other_leaf_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        // bob's leaf does not advertise the credential type alice's leaf uses, so the Add is rejected.
+        let res = CommitReceiver::new(
+            &tree,
+            alice,
+            alice,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        )
+        .receive([Proposal::Add(Box::new(AddProposal {
+            key_package: unsupported_credential_key_package("bob").await,
+        }))])
+        .await;
+
+        assert_matches!(res, Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_add_with_leaf_not_supporting_credential_type_of_other_leaf_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        // Same rule on the sending side for a by-value Add.
+        let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+            .with_additional([Proposal::Add(Box::new(AddProposal {
+                key_package: unsupported_credential_key_package("bob").await,
+            }))])
+            .send()
+            .await;
+
+        assert_matches!(res, Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_add_with_leaf_not_supporting_credential_type_of_other_leaf_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let add = Proposal::Add(Box::new(AddProposal {
+            key_package: unsupported_credential_key_package("bob").await,
+        }));
+
+        let add_info = make_proposal_info(&add, alice).await;
+        // Cached by reference, the incompatible Add is filtered out instead of failing the commit.
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(add_info.proposal_ref().unwrap().clone(), add.clone(), alice)
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![add_info]);
+    }
+
+    #[cfg(feature = "custom_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_custom_proposal_with_member_not_supporting_proposal_type_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        // The member does not support proposal type 42, so committing it by value must fail.
+        let custom_proposal = Proposal::Custom(CustomProposal::new(ProposalType::new(42), vec![]));
+
+        let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+            .with_additional([custom_proposal.clone()])
+            .send()
+            .await;
+
+        assert_matches!(
+            res,
+            Err(
+                MlsError::UnsupportedCustomProposal(c)
+            ) if c == custom_proposal.proposal_type()
+        );
+    }
+
+    #[cfg(feature = "custom_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_custom_proposal_with_member_not_supporting_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let custom_proposal = Proposal::Custom(CustomProposal::new(ProposalType::new(42), vec![]));
+        // Cached by reference, the unsupported custom proposal is filtered out instead.
+        let custom_info = make_proposal_info(&custom_proposal, alice).await;
+
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(
+                    custom_info.proposal_ref().unwrap().clone(),
+                    custom_proposal.clone(),
+                    alice,
+                )
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![custom_info]);
+    }
+
+    #[cfg(feature = "custom_proposal")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_custom_proposal_with_member_not_supporting_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        // A received commit containing an unsupported custom proposal type is rejected.
+        let custom_proposal = Proposal::Custom(CustomProposal::new(ProposalType::new(42), vec![]));
+
+        let res = CommitReceiver::new(
+            &tree,
+            alice,
+            alice,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        )
+        .receive([custom_proposal.clone()])
+        .await;
+
+        assert_matches!(
+            res,
+            Err(MlsError::UnsupportedCustomProposal(c)) if c == custom_proposal.proposal_type()
+        );
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_group_extension_unsupported_by_leaf_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        // NOTE(review): assumes make_extension_list(0) holds extension type 42, unsupported by the leaf — confirm.
+        let res = CommitReceiver::new(
+            &tree,
+            alice,
+            alice,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        )
+        .receive([Proposal::GroupContextExtensions(make_extension_list(0))])
+        .await;
+
+        assert_matches!(
+            res,
+            Err(
+                MlsError::UnsupportedGroupExtension(v)
+            ) if v == 42.into()
+        );
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_group_extension_unsupported_by_leaf_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        // Same rule on the sending side for a by-value GroupContextExtensions proposal.
+        let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+            .with_additional([Proposal::GroupContextExtensions(make_extension_list(0))])
+            .send()
+            .await;
+
+        assert_matches!(
+            res,
+            Err(
+                MlsError::UnsupportedGroupExtension(v)
+            ) if v == 42.into()
+        );
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_group_extension_unsupported_by_leaf_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let proposal = Proposal::GroupContextExtensions(make_extension_list(0));
+        let proposal_info = make_proposal_info(&proposal, alice).await;
+        // Cached by reference, the unsupported group extension is filtered out instead.
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .cache(
+                    proposal_info.proposal_ref().unwrap().clone(),
+                    proposal.clone(),
+                    alice,
+                )
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+    }
+
+    #[cfg(feature = "psk")]
+    #[derive(Debug)]
+    struct AlwaysNotFoundPskStorage; // PSK storage stub: every lookup returns Ok(None).
+
+    #[cfg(feature = "psk")]
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+    impl PreSharedKeyStorage for AlwaysNotFoundPskStorage {
+        type Error = Infallible;
+        // Reports every external PSK id as absent, never as an error.
+        async fn get(&self, _: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error> {
+            Ok(None)
+        }
+    }
+
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn receiving_external_psk_with_unknown_id_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        // An external PSK that the storage cannot find makes the received commit fail.
+        let res = CommitReceiver::new(
+            &tree,
+            alice,
+            alice,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        )
+        .with_psk_storage(AlwaysNotFoundPskStorage)
+        .receive([Proposal::Psk(new_external_psk(b"abc"))])
+        .await;
+
+        assert_matches!(res, Err(MlsError::MissingRequiredPsk));
+    }
+
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_additional_external_psk_with_unknown_id_fails() {
+        let (alice, tree) = new_tree("alice").await;
+        // Same rule on the sending side for a by-value PSK proposal.
+        let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+            .with_psk_storage(AlwaysNotFoundPskStorage)
+            .with_additional([Proposal::Psk(new_external_psk(b"abc"))])
+            .send()
+            .await;
+
+        assert_matches!(res, Err(MlsError::MissingRequiredPsk));
+    }
+
+    #[cfg(feature = "psk")]
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn sending_external_psk_with_unknown_id_filters_it_out() {
+        let (alice, tree) = new_tree("alice").await;
+        let proposal = Proposal::Psk(new_external_psk(b"abc"));
+        let proposal_info = make_proposal_info(&proposal, alice).await;
+        // Cached by reference, the PSK with an unknown id is filtered out instead.
+        let processed_proposals =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .with_psk_storage(AlwaysNotFoundPskStorage)
+                .cache(
+                    proposal_info.proposal_ref().unwrap().clone(),
+                    proposal.clone(),
+                    alice,
+                )
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(processed_proposals.0, Vec::new());
+
+        #[cfg(feature = "state_update")]
+        assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn user_defined_filter_can_remove_proposals() {
+        struct RemoveGroupContextExtensions; // User rules that drop every GroupContextExtensions proposal.
+
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+        impl MlsRules for RemoveGroupContextExtensions {
+            type Error = Infallible;
+
+            async fn filter_proposals(
+                &self,
+                _: CommitDirection,
+                _: CommitSource,
+                _: &Roster,
+                _: &ExtensionList,
+                mut proposals: ProposalBundle,
+            ) -> Result<ProposalBundle, Self::Error> {
+                proposals.group_context_extensions.clear();
+                Ok(proposals)
+            }
+
+            #[cfg_attr(coverage_nightly, coverage(off))]
+            fn commit_options(
+                &self,
+                _: &Roster,
+                _: &ExtensionList,
+                _: &ProposalBundle,
+            ) -> Result<CommitOptions, Self::Error> {
+                Ok(Default::default())
+            }
+
+            #[cfg_attr(coverage_nightly, coverage(off))]
+            fn encryption_options(
+                &self,
+                _: &Roster,
+                _: &ExtensionList,
+            ) -> Result<EncryptionOptions, Self::Error> {
+                Ok(Default::default())
+            }
+        }
+
+        let (alice, tree) = new_tree("alice").await;
+        // The extra GroupContextExtensions proposal is stripped by the user rules, so nothing is committed.
+        let (committed, _) =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .with_additional([Proposal::GroupContextExtensions(Default::default())])
+                .with_user_rules(RemoveGroupContextExtensions)
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(committed, Vec::new());
+    }
+
+    struct FailureMlsRules; // User rules whose filter_proposals always fails.
+
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+    impl MlsRules for FailureMlsRules {
+        type Error = MlsError;
+        // Always rejects the proposal bundle.
+        async fn filter_proposals(
+            &self,
+            _: CommitDirection,
+            _: CommitSource,
+            _: &Roster,
+            _: &ExtensionList,
+            _: ProposalBundle,
+        ) -> Result<ProposalBundle, Self::Error> {
+            Err(MlsError::InvalidSignature)
+        }
+        // The remaining hooks return defaults and are excluded from coverage.
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        fn commit_options(
+            &self,
+            _: &Roster,
+            _: &ExtensionList,
+            _: &ProposalBundle,
+        ) -> Result<CommitOptions, Self::Error> {
+            Ok(Default::default())
+        }
+
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        fn encryption_options(
+            &self,
+            _: &Roster,
+            _: &ExtensionList,
+        ) -> Result<EncryptionOptions, Self::Error> {
+            Ok(Default::default())
+        }
+    }
+
+    struct InjectMlsRules { // User rules that add `to_inject` to every bundle they filter.
+        to_inject: Proposal, // proposal appended by filter_proposals
+        source: ProposalSource, // source recorded for the injected proposal
+    }
+
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+    impl MlsRules for InjectMlsRules {
+        type Error = MlsError;
+        // Adds the configured proposal, attributed to leaf 0, to the incoming bundle.
+        async fn filter_proposals(
+            &self,
+            _: CommitDirection,
+            _: CommitSource,
+            _: &Roster,
+            _: &ExtensionList,
+            mut proposals: ProposalBundle,
+        ) -> Result<ProposalBundle, Self::Error> {
+            proposals.add(
+                self.to_inject.clone(),
+                Sender::Member(0), // NOTE(review): assumes leaf 0 is alice in these tests — confirm
+                self.source.clone(),
+            );
+            Ok(proposals)
+        }
+
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        fn commit_options(
+            &self,
+            _: &Roster,
+            _: &ExtensionList,
+            _: &ProposalBundle,
+        ) -> Result<CommitOptions, Self::Error> {
+            Ok(Default::default())
+        }
+
+        #[cfg_attr(coverage_nightly, coverage(off))]
+        fn encryption_options(
+            &self,
+            _: &Roster,
+            _: &ExtensionList,
+        ) -> Result<EncryptionOptions, Self::Error> {
+            Ok(Default::default())
+        }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn user_defined_filter_can_inject_proposals() {
+        let (alice, tree) = new_tree("alice").await;
+        // Injected as ByValue, the proposal must appear inline in the commit.
+        let test_proposal = Proposal::GroupContextExtensions(Default::default());
+
+        let (committed, _) =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .with_user_rules(InjectMlsRules {
+                    to_inject: test_proposal.clone(),
+                    source: ProposalSource::ByValue,
+                })
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(
+            committed,
+            vec![ProposalOrRef::Proposal(test_proposal.into())]
+        );
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn user_defined_filter_can_inject_local_only_proposals() {
+        let (alice, tree) = new_tree("alice").await;
+        // Injected as Local, the proposal is accepted but not listed among the committed proposals.
+        let test_proposal = Proposal::GroupContextExtensions(Default::default());
+
+        let (committed, _) =
+            CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+                .with_user_rules(InjectMlsRules {
+                    to_inject: test_proposal.clone(),
+                    source: ProposalSource::Local,
+                })
+                .send()
+                .await
+                .unwrap();
+
+        assert_eq!(committed, vec![]);
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn user_defined_filter_cant_break_base_rules() {
+        let (alice, tree) = new_tree("alice").await;
+        // A by-value Update injected by user rules still violates the base rules and must be rejected.
+        let test_proposal = Proposal::Update(UpdateProposal {
+            leaf_node: get_basic_test_node(TEST_CIPHER_SUITE, "leaf").await,
+        });
+
+        let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+            .with_user_rules(InjectMlsRules {
+                to_inject: test_proposal,
+                source: ProposalSource::ByValue,
+            })
+            .send()
+            .await;
+
+        assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn user_defined_filter_can_refuse_to_send_commit() {
+        let (alice, tree) = new_tree("alice").await;
+        // An error from the user rules aborts sending and is surfaced as MlsRulesError.
+        let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+            .with_additional([Proposal::GroupContextExtensions(Default::default())])
+            .with_user_rules(FailureMlsRules)
+            .send()
+            .await;
+
+        assert_matches!(res, Err(MlsError::MlsRulesError(_)));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn user_defined_filter_can_reject_incoming_commit() {
+        let (alice, tree) = new_tree("alice").await;
+        // Likewise, an error from the user rules rejects a received commit as MlsRulesError.
+        let res = CommitReceiver::new(
+            &tree,
+            alice,
+            alice,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        )
+        .with_user_rules(FailureMlsRules)
+        .receive([Proposal::GroupContextExtensions(Default::default())])
+        .await;
+
+        assert_matches!(res, Err(MlsError::MlsRulesError(_)));
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn proposers_are_verified() {
+        let (alice, mut tree) = new_tree("alice").await;
+        let bob = add_member(&mut tree, "bob").await;
+        // With by_ref_proposal, carol is registered as the only external sender (index 0).
+        #[cfg(feature = "by_ref_proposal")]
+        let identity = get_test_signing_identity(TEST_CIPHER_SUITE, b"carol")
+            .await
+            .0;
+
+        #[cfg(feature = "by_ref_proposal")]
+        let external_senders = ExternalSendersExt::new(vec![identity]);
+        // One proposal of each type, each checked against every proposer kind below.
+        let proposals: &[Proposal] = &[
+            Proposal::Add(make_add_proposal().await),
+            Proposal::Update(make_update_proposal("alice").await),
+            Proposal::Remove(RemoveProposal { to_remove: bob }),
+            #[cfg(feature = "psk")]
+            Proposal::Psk(make_external_psk(
+                b"ted",
+                PskNonce::random(&test_cipher_suite_provider(TEST_CIPHER_SUITE)).unwrap(),
+            )),
+            Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION)),
+            Proposal::ExternalInit(make_external_init()),
+            Proposal::GroupContextExtensions(Default::default()),
+        ];
+
+        let proposers = [
+            Sender::Member(*alice),
+            #[cfg(feature = "by_ref_proposal")]
+            Sender::External(0),
+            Sender::NewMemberCommit,
+            Sender::NewMemberProposal,
+        ];
+        // NOTE(review): only by_ref = true is exercised, so the by-value `else` branch below never runs.
+        for ((proposer, proposal), by_ref) in proposers
+            .into_iter()
+            .cartesian_product(proposals)
+            .cartesian_product([true])
+        {
+            let committer = Sender::Member(*alice);
+
+            let receiver = CommitReceiver::new(
+                &tree,
+                committer,
+                alice,
+                test_cipher_suite_provider(TEST_CIPHER_SUITE),
+            );
+
+            #[cfg(feature = "by_ref_proposal")]
+            let extensions: ExtensionList =
+                vec![external_senders.clone().into_extension().unwrap()].into();
+
+            #[cfg(feature = "by_ref_proposal")]
+            let receiver = receiver.with_extensions(extensions);
+
+            let (receiver, proposals, proposer) = if by_ref {
+                let proposal_ref = make_proposal_ref(proposal, proposer).await;
+                let receiver = receiver.cache(proposal_ref.clone(), proposal.clone(), proposer);
+                (receiver, vec![ProposalOrRef::from(proposal_ref)], proposer)
+            } else {
+                (receiver, vec![proposal.clone().into()], committer)
+            };
+
+            let res = receiver.receive(proposals).await;
+            // Disallowed proposer/type pairs must fail; allowed ones must succeed, except a member's by-ref Update, which is left unchecked.
+            if proposer_can_propose(proposer, proposal.proposal_type(), by_ref).is_err() {
+                assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender));
+            } else {
+                let is_self_update = proposal.proposal_type() == ProposalType::UPDATE
+                    && by_ref
+                    && matches!(proposer, Sender::Member(_));
+
+                if !is_self_update {
+                    res.unwrap();
+                }
+            }
+        }
+    }
+
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn make_update_proposal(name: &str) -> UpdateProposal { // Update proposal whose leaf node is update_leaf_node(name, 1).
+        UpdateProposal {
+            leaf_node: update_leaf_node(name, 1).await,
+        }
+    }
+
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn make_update_proposal_custom(name: &str, leaf_index: u32) -> UpdateProposal { // Update proposal whose leaf node is update_leaf_node(name, leaf_index).
+        UpdateProposal {
+            leaf_node: update_leaf_node(name, leaf_index).await,
+        }
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn when_receiving_commit_unused_proposals_are_proposals_in_cache_but_not_in_commit() {
+        let (alice, tree) = new_tree("alice").await;
+
+        let proposal = Proposal::GroupContextExtensions(Default::default());
+        let proposal_ref = make_proposal_ref(&proposal, alice).await;
+        // The GroupContextExtensions proposal is only cached; the commit itself carries a different proposal (an Add).
+        let state = CommitReceiver::new(
+            &tree,
+            alice,
+            alice,
+            test_cipher_suite_provider(TEST_CIPHER_SUITE),
+        )
+        .cache(proposal_ref.clone(), proposal, alice)
+        .receive([Proposal::Add(Box::new(AddProposal {
+            key_package: test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await,
+        }))])
+        .await
+        .unwrap();
+        // The cached-but-uncommitted proposal must be the sole unused proposal.
+        let [p] = &state.unused_proposals[..] else {
+            panic!(
+                "Expected single unused proposal but got {:?}",
+                state.unused_proposals
+            );
+        };
+
+        assert_eq!(p.proposal_ref(), Some(&proposal_ref));
+    }
+}
diff --git a/src/group/proposal_filter.rs b/src/group/proposal_filter.rs
new file mode 100644
index 0000000..5ef6b20
--- /dev/null
+++ b/src/group/proposal_filter.rs
@@ -0,0 +1,23 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+mod bundle;
+mod filtering_common;
+
+#[cfg(feature = "by_ref_proposal")]
+mod filtering;
+#[cfg(not(feature = "by_ref_proposal"))]
+pub mod filtering_lite;
+#[cfg(all(feature = "custom_proposal", not(feature = "by_ref_proposal")))]
+use filtering_lite as filtering;
+
+pub use bundle::{ProposalBundle, ProposalInfo, ProposalSource};
+
+#[cfg(feature = "by_ref_proposal")]
+pub(crate) use filtering::FilterStrategy;
+
+pub(crate) use filtering_common::ProposalApplier;
+
+#[cfg(all(feature = "by_ref_proposal", test))]
+pub(crate) use filtering::proposer_can_propose;
diff --git a/src/group/proposal_filter/bundle.rs b/src/group/proposal_filter/bundle.rs
new file mode 100644
index 0000000..f18a75b
--- /dev/null
+++ b/src/group/proposal_filter/bundle.rs
@@ -0,0 +1,633 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::boxed::Box;
+use alloc::vec::Vec;
+
+#[cfg(feature = "custom_proposal")]
+use itertools::Itertools;
+
+use crate::{
+ group::{
+ AddProposal, BorrowedProposal, Proposal, ProposalOrRef, ProposalType, ReInitProposal,
+ RemoveProposal, Sender,
+ },
+ ExtensionList,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::group::{proposal_cache::CachedProposal, LeafIndex, ProposalRef, UpdateProposal};
+
+#[cfg(feature = "psk")]
+use crate::group::PreSharedKeyProposal;
+
+#[cfg(feature = "custom_proposal")]
+use crate::group::proposal::CustomProposal;
+
+use crate::group::ExternalInit;
+
+use core::iter::empty;
+
+#[derive(Clone, Debug, Default)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A collection of proposals.
+///
+/// Proposals are grouped into one list per proposal type. Every entry is a
+/// [`ProposalInfo`], which also records the proposal's sender and source.
+pub struct ProposalBundle {
+    /// Add proposals.
+    pub(crate) additions: Vec<ProposalInfo<AddProposal>>,
+    /// Update proposals.
+    #[cfg(feature = "by_ref_proposal")]
+    pub(crate) updates: Vec<ProposalInfo<UpdateProposal>>,
+    /// Leaf index of the sender of each entry of `updates`, index-aligned with
+    /// it once populated. `ProposalBundle::add` does not fill this list; the
+    /// filtering code computes it separately.
+    #[cfg(feature = "by_ref_proposal")]
+    pub(crate) update_senders: Vec<LeafIndex>,
+    /// Remove proposals.
+    pub(crate) removals: Vec<ProposalInfo<RemoveProposal>>,
+    /// Pre-shared key proposals.
+    #[cfg(feature = "psk")]
+    pub(crate) psks: Vec<ProposalInfo<PreSharedKeyProposal>>,
+    /// ReInit proposals.
+    pub(crate) reinitializations: Vec<ProposalInfo<ReInitProposal>>,
+    /// ExternalInit proposals.
+    pub(crate) external_initializations: Vec<ProposalInfo<ExternalInit>>,
+    /// GroupContextExtensions proposals.
+    pub(crate) group_context_extensions: Vec<ProposalInfo<ExtensionList>>,
+    /// Custom (non-standard) proposals.
+    #[cfg(feature = "custom_proposal")]
+    pub(crate) custom_proposals: Vec<ProposalInfo<CustomProposal>>,
+}
+
+impl ProposalBundle {
+    /// Add `proposal` to the list matching its variant, recording who sent it
+    /// (`sender`) and how it was delivered (`source`).
+    ///
+    /// `update_senders` is not touched here; the filtering code fills that
+    /// list in separately.
+    pub fn add(&mut self, proposal: Proposal, sender: Sender, source: ProposalSource) {
+        match proposal {
+            Proposal::Add(proposal) => self.additions.push(ProposalInfo {
+                proposal: *proposal,
+                sender,
+                source,
+            }),
+            #[cfg(feature = "by_ref_proposal")]
+            Proposal::Update(proposal) => self.updates.push(ProposalInfo {
+                proposal,
+                sender,
+                source,
+            }),
+            Proposal::Remove(proposal) => self.removals.push(ProposalInfo {
+                proposal,
+                sender,
+                source,
+            }),
+            #[cfg(feature = "psk")]
+            Proposal::Psk(proposal) => self.psks.push(ProposalInfo {
+                proposal,
+                sender,
+                source,
+            }),
+            Proposal::ReInit(proposal) => self.reinitializations.push(ProposalInfo {
+                proposal,
+                sender,
+                source,
+            }),
+            Proposal::ExternalInit(proposal) => self.external_initializations.push(ProposalInfo {
+                proposal,
+                sender,
+                source,
+            }),
+            Proposal::GroupContextExtensions(proposal) => {
+                self.group_context_extensions.push(ProposalInfo {
+                    proposal,
+                    sender,
+                    source,
+                })
+            }
+            #[cfg(feature = "custom_proposal")]
+            Proposal::Custom(proposal) => self.custom_proposals.push(ProposalInfo {
+                proposal,
+                sender,
+                source,
+            }),
+        }
+    }
+
+    /// Remove the proposal of type `T` at `index`.
+    ///
+    /// Type `T` can be any of the standard MLS proposal types defined in the
+    /// [`proposal`](crate::group::proposal) module.
+    ///
+    /// `index` is consistent with the index returned by any of the proposal
+    /// type specific functions in this module. With the built-in
+    /// implementations an out-of-range `index` is ignored, and removing an
+    /// update proposal does not remove the matching `update_senders` entry.
+    pub fn remove<T: Proposable>(&mut self, index: usize) {
+        T::remove(self, index);
+    }
+
+    /// Iterate over proposals, filtered by type.
+    ///
+    /// Type `T` can be any of the standard MLS proposal types defined in the
+    /// [`proposal`](crate::group::proposal) module.
+    pub fn by_type<'a, T: Proposable + 'a>(&'a self) -> impl Iterator<Item = &'a ProposalInfo<T>> {
+        T::filter(self).iter()
+    }
+
+    /// Retain proposals, filtered by type.
+    ///
+    /// Type `T` can be any of the standard MLS proposal types defined in the
+    /// [`proposal`](crate::group::proposal) module.
+    ///
+    /// A proposal for which `f` returns `Err` is removed. Every proposal is
+    /// still visited, and the first error encountered is returned.
+    pub fn retain_by_type<T, F, E>(&mut self, mut f: F) -> Result<(), E>
+    where
+        T: Proposable,
+        F: FnMut(&ProposalInfo<T>) -> Result<bool, E>,
+    {
+        let mut res = Ok(());
+
+        T::retain(self, |p| match f(p) {
+            Ok(keep) => keep,
+            Err(e) => {
+                // Only the first error is reported; later ones are dropped.
+                if res.is_ok() {
+                    res = Err(e);
+                }
+                false
+            }
+        });
+
+        res
+    }
+
+    /// Retain custom proposals in the bundle.
+    ///
+    /// Same error semantics as `retain_by_type`: failing proposals are removed
+    /// and the first error is returned.
+    #[cfg(feature = "custom_proposal")]
+    pub fn retain_custom<F, E>(&mut self, mut f: F) -> Result<(), E>
+    where
+        F: FnMut(&ProposalInfo<CustomProposal>) -> Result<bool, E>,
+    {
+        let mut res = Ok(());
+
+        self.custom_proposals.retain(|p| match f(p) {
+            Ok(keep) => keep,
+            Err(e) => {
+                if res.is_ok() {
+                    res = Err(e);
+                }
+                false
+            }
+        });
+
+        res
+    }
+
+    /// Retain MLS standard proposals in the bundle.
+    ///
+    /// Proposal types are processed one after another. If `f` fails for any
+    /// proposal of a type, that error is returned and later types are not
+    /// visited. Custom proposals are never visited (see `retain_custom`).
+    pub fn retain<F, E>(&mut self, mut f: F) -> Result<(), E>
+    where
+        F: FnMut(&ProposalInfo<BorrowedProposal<'_>>) -> Result<bool, E>,
+    {
+        self.retain_by_type::<AddProposal, _, _>(|proposal| {
+            f(&proposal.as_ref().map(BorrowedProposal::from))
+        })?;
+
+        #[cfg(feature = "by_ref_proposal")]
+        self.retain_by_type::<UpdateProposal, _, _>(|proposal| {
+            f(&proposal.as_ref().map(BorrowedProposal::from))
+        })?;
+
+        self.retain_by_type::<RemoveProposal, _, _>(|proposal| {
+            f(&proposal.as_ref().map(BorrowedProposal::from))
+        })?;
+
+        #[cfg(feature = "psk")]
+        self.retain_by_type::<PreSharedKeyProposal, _, _>(|proposal| {
+            f(&proposal.as_ref().map(BorrowedProposal::from))
+        })?;
+
+        self.retain_by_type::<ReInitProposal, _, _>(|proposal| {
+            f(&proposal.as_ref().map(BorrowedProposal::from))
+        })?;
+
+        self.retain_by_type::<ExternalInit, _, _>(|proposal| {
+            f(&proposal.as_ref().map(BorrowedProposal::from))
+        })?;
+
+        self.retain_by_type::<ExtensionList, _, _>(|proposal| {
+            f(&proposal.as_ref().map(BorrowedProposal::from))
+        })?;
+
+        Ok(())
+    }
+
+    /// The number of proposals in the bundle, custom proposals included.
+    ///
+    /// `update_senders` is bookkeeping and is not counted.
+    pub fn length(&self) -> usize {
+        let len = 0;
+
+        #[cfg(feature = "psk")]
+        let len = len + self.psks.len();
+
+        let len = len + self.external_initializations.len();
+
+        #[cfg(feature = "custom_proposal")]
+        let len = len + self.custom_proposals.len();
+
+        #[cfg(feature = "by_ref_proposal")]
+        let len = len + self.updates.len();
+
+        len + self.additions.len()
+            + self.removals.len()
+            + self.reinitializations.len()
+            + self.group_context_extensions.len()
+    }
+
+    /// Iterate over all proposals inside the bundle.
+    ///
+    /// Order: Add, Remove, ReInit, Update, PSK, ExternalInit,
+    /// GroupContextExtensions, then custom proposals (types whose feature is
+    /// disabled are skipped).
+    pub fn iter_proposals(&self) -> impl Iterator<Item = ProposalInfo<BorrowedProposal<'_>>> {
+        let res = self
+            .additions
+            .iter()
+            .map(|p| p.as_ref().map(BorrowedProposal::Add))
+            .chain(
+                self.removals
+                    .iter()
+                    .map(|p| p.as_ref().map(BorrowedProposal::Remove)),
+            )
+            .chain(
+                self.reinitializations
+                    .iter()
+                    .map(|p| p.as_ref().map(BorrowedProposal::ReInit)),
+            );
+
+        // Each feature-gated stage shadows `res` because the chained iterator
+        // type changes with every `chain`.
+        #[cfg(feature = "by_ref_proposal")]
+        let res = res.chain(
+            self.updates
+                .iter()
+                .map(|p| p.as_ref().map(BorrowedProposal::Update)),
+        );
+
+        #[cfg(feature = "psk")]
+        let res = res.chain(
+            self.psks
+                .iter()
+                .map(|p| p.as_ref().map(BorrowedProposal::Psk)),
+        );
+
+        let res = res.chain(
+            self.external_initializations
+                .iter()
+                .map(|p| p.as_ref().map(BorrowedProposal::ExternalInit)),
+        );
+
+        let res = res.chain(
+            self.group_context_extensions
+                .iter()
+                .map(|p| p.as_ref().map(BorrowedProposal::GroupContextExtensions)),
+        );
+
+        #[cfg(feature = "custom_proposal")]
+        let res = res.chain(
+            self.custom_proposals
+                .iter()
+                .map(|p| p.as_ref().map(BorrowedProposal::Custom)),
+        );
+
+        res
+    }
+
+    /// Iterate over proposal in the bundle, consuming the bundle.
+    ///
+    /// Order: custom proposals, ExternalInit, PSK, Update, Add, Remove, ReInit,
+    /// GroupContextExtensions. This differs from `iter_proposals`, and it is
+    /// also the order used by `into_proposals_or_refs`.
+    pub fn into_proposals(self) -> impl Iterator<Item = ProposalInfo<Proposal>> {
+        let res = empty();
+
+        #[cfg(feature = "custom_proposal")]
+        let res = res.chain(
+            self.custom_proposals
+                .into_iter()
+                .map(|p| p.map(Proposal::Custom)),
+        );
+
+        let res = res.chain(
+            self.external_initializations
+                .into_iter()
+                .map(|p| p.map(Proposal::ExternalInit)),
+        );
+
+        #[cfg(feature = "psk")]
+        let res = res.chain(self.psks.into_iter().map(|p| p.map(Proposal::Psk)));
+
+        #[cfg(feature = "by_ref_proposal")]
+        let res = res.chain(self.updates.into_iter().map(|p| p.map(Proposal::Update)));
+
+        res.chain(
+            self.additions
+                .into_iter()
+                .map(|p| p.map(|p| Proposal::Add(alloc::boxed::Box::new(p)))),
+        )
+        .chain(self.removals.into_iter().map(|p| p.map(Proposal::Remove)))
+        .chain(
+            self.reinitializations
+                .into_iter()
+                .map(|p| p.map(Proposal::ReInit)),
+        )
+        .chain(
+            self.group_context_extensions
+                .into_iter()
+                .map(|p| p.map(Proposal::GroupContextExtensions)),
+        )
+    }
+
+    /// Convert the bundle into the proposals and references carried by a
+    /// commit, in `into_proposals` order.
+    ///
+    /// By-value proposals are boxed, by-reference proposals become their
+    /// reference, and [`ProposalSource::Local`] proposals are dropped because
+    /// they are not transmitted.
+    pub(crate) fn into_proposals_or_refs(self) -> Vec<ProposalOrRef> {
+        self.into_proposals()
+            .filter_map(|p| match p.source {
+                ProposalSource::ByValue => Some(ProposalOrRef::Proposal(Box::new(p.proposal))),
+                #[cfg(feature = "by_ref_proposal")]
+                ProposalSource::ByReference(reference) => Some(ProposalOrRef::Reference(reference)),
+                // Listed explicitly rather than with `_` so that adding a new
+                // source variant forces this match to be revisited.
+                ProposalSource::Local => None,
+            })
+            .collect()
+    }
+
+    /// Add proposals in the bundle.
+    pub fn add_proposals(&self) -> &[ProposalInfo<AddProposal>] {
+        &self.additions
+    }
+
+    /// Update proposals in the bundle.
+    #[cfg(feature = "by_ref_proposal")]
+    pub fn update_proposals(&self) -> &[ProposalInfo<UpdateProposal>] {
+        &self.updates
+    }
+
+    /// Senders of update proposals in the bundle.
+    ///
+    /// Empty until the filtering code has populated it.
+    #[cfg(feature = "by_ref_proposal")]
+    pub fn update_proposal_senders(&self) -> &[LeafIndex] {
+        &self.update_senders
+    }
+
+    /// Remove proposals in the bundle.
+    pub fn remove_proposals(&self) -> &[ProposalInfo<RemoveProposal>] {
+        &self.removals
+    }
+
+    /// Pre-shared key proposals in the bundle.
+    #[cfg(feature = "psk")]
+    pub fn psk_proposals(&self) -> &[ProposalInfo<PreSharedKeyProposal>] {
+        &self.psks
+    }
+
+    /// Reinit proposals in the bundle.
+    pub fn reinit_proposals(&self) -> &[ProposalInfo<ReInitProposal>] {
+        &self.reinitializations
+    }
+
+    /// External init proposals in the bundle.
+    pub fn external_init_proposals(&self) -> &[ProposalInfo<ExternalInit>] {
+        &self.external_initializations
+    }
+
+    /// Group context extension proposals in the bundle.
+    pub fn group_context_ext_proposals(&self) -> &[ProposalInfo<ExtensionList>] {
+        &self.group_context_extensions
+    }
+
+    /// Custom proposals in the bundle.
+    #[cfg(feature = "custom_proposal")]
+    pub fn custom_proposals(&self) -> &[ProposalInfo<CustomProposal>] {
+        &self.custom_proposals
+    }
+
+    /// The first GroupContextExtensions proposal in the bundle, if any.
+    pub(crate) fn group_context_extensions_proposal(&self) -> Option<&ProposalInfo<ExtensionList>> {
+        self.group_context_extensions.first()
+    }
+
+    /// Custom proposal types that are in use within this bundle, each reported once.
+    ///
+    /// With `std` the types are yielded in first-seen order (`Itertools::unique`).
+    /// Without `std` they are deduplicated through a `BTreeSet` and so come
+    /// out sorted.
+    #[cfg(feature = "custom_proposal")]
+    pub fn custom_proposal_types(&self) -> impl Iterator<Item = ProposalType> + '_ {
+        #[cfg(feature = "std")]
+        let res = self
+            .custom_proposals
+            .iter()
+            .map(|v| v.proposal.proposal_type())
+            .unique();
+
+        #[cfg(not(feature = "std"))]
+        let res = self
+            .custom_proposals
+            .iter()
+            .map(|v| v.proposal.proposal_type())
+            .collect::<alloc::collections::BTreeSet<_>>()
+            .into_iter();
+
+        res
+    }
+
+    /// Standard proposal types that are in use within this bundle, each
+    /// reported once.
+    ///
+    /// When `custom_proposal` is enabled, the custom types in use (see
+    /// `custom_proposal_types`) are appended.
+    pub fn proposal_types(&self) -> impl Iterator<Item = ProposalType> + '_ {
+        let res = (!self.additions.is_empty())
+            .then_some(ProposalType::ADD)
+            .into_iter()
+            .chain((!self.removals.is_empty()).then_some(ProposalType::REMOVE))
+            .chain((!self.reinitializations.is_empty()).then_some(ProposalType::RE_INIT));
+
+        #[cfg(feature = "by_ref_proposal")]
+        let res = res.chain((!self.updates.is_empty()).then_some(ProposalType::UPDATE));
+
+        #[cfg(feature = "psk")]
+        let res = res.chain((!self.psks.is_empty()).then_some(ProposalType::PSK));
+
+        let res = res.chain(
+            (!self.external_initializations.is_empty()).then_some(ProposalType::EXTERNAL_INIT),
+        );
+
+        // Two returns because the concrete iterator type depends on whether
+        // custom types are chained on.
+        #[cfg(not(feature = "custom_proposal"))]
+        return res.chain(
+            (!self.group_context_extensions.is_empty())
+                .then_some(ProposalType::GROUP_CONTEXT_EXTENSIONS),
+        );
+
+        #[cfg(feature = "custom_proposal")]
+        return res
+            .chain(
+                (!self.group_context_extensions.is_empty())
+                    .then_some(ProposalType::GROUP_CONTEXT_EXTENSIONS),
+            )
+            .chain(self.custom_proposal_types());
+    }
+}
+
+impl FromIterator<(Proposal, Sender, ProposalSource)> for ProposalBundle {
+    /// Build a bundle by calling `ProposalBundle::add` on every
+    /// `(proposal, sender, source)` triple, in iteration order.
+    fn from_iter<I>(iter: I) -> Self
+    where
+        I: IntoIterator<Item = (Proposal, Sender, ProposalSource)>,
+    {
+        iter.into_iter()
+            .fold(Self::default(), |mut bundle, (proposal, sender, source)| {
+                bundle.add(proposal, sender, source);
+                bundle
+            })
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl<'a> FromIterator<(&'a ProposalRef, &'a CachedProposal)> for ProposalBundle {
+    /// Build a bundle from cached proposals. Each proposal is cloned and
+    /// tagged as [`ProposalSource::ByReference`] with its reference.
+    fn from_iter<I>(iter: I) -> Self
+    where
+        I: IntoIterator<Item = (&'a ProposalRef, &'a CachedProposal)>,
+    {
+        let mut bundle = ProposalBundle::default();
+
+        for (proposal_ref, cached) in iter {
+            let proposal = cached.proposal.clone();
+            let source = ProposalSource::ByReference(proposal_ref.clone());
+            bundle.add(proposal, cached.sender, source);
+        }
+
+        bundle
+    }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl<'a> FromIterator<&'a (ProposalRef, CachedProposal)> for ProposalBundle {
+    /// Same as collecting `(&ProposalRef, &CachedProposal)` pairs, for callers
+    /// that store each pair as a tuple.
+    fn from_iter<I>(iter: I) -> Self
+    where
+        I: IntoIterator<Item = &'a (ProposalRef, CachedProposal)>,
+    {
+        iter.into_iter()
+            .map(|(proposal_ref, cached)| (proposal_ref, cached))
+            .collect()
+    }
+}
+
+#[cfg_attr(
+    all(feature = "ffi", not(test)),
+    safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[derive(Clone, Debug, PartialEq)]
+/// Where a proposal in a [`ProposalBundle`] came from, which also decides how
+/// it appears in a commit.
+pub enum ProposalSource {
+    /// The commit carries the full proposal.
+    ByValue,
+    /// The commit carries only the proposal's reference.
+    #[cfg(feature = "by_ref_proposal")]
+    ByReference(ProposalRef),
+    /// The proposal is injected locally into the commit resolution and is not
+    /// transmitted with the commit.
+    Local,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))]
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[non_exhaustive]
+/// Proposal description used as input to a
+/// [`MlsRules`](crate::MlsRules).
+///
+/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build
+/// it with a struct literal and should use [`ProposalInfo::new`] instead.
+pub struct ProposalInfo<T> {
+    /// The underlying proposal value.
+    pub proposal: T,
+    /// The sender of this proposal.
+    pub sender: Sender,
+    /// The source of the proposal (see [`ProposalSource`]).
+    pub source: ProposalSource,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl<T> ProposalInfo<T> {
+    /// Create a new ProposalInfo.
+    ///
+    /// If `can_transmit` is `true`, the source is [`ProposalSource::ByValue`]
+    /// and the proposal is transmitted with the commit. Otherwise the source is
+    /// [`ProposalSource::Local`] and the proposal is only injected into the
+    /// local commit resolution.
+    ///
+    /// This is mainly useful when implementing a custom
+    /// [`MlsRules`](crate::MlsRules).
+    #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn new(proposal: T, sender: Sender, can_transmit: bool) -> Self {
+        ProposalInfo {
+            source: if can_transmit {
+                ProposalSource::ByValue
+            } else {
+                ProposalSource::Local
+            },
+            proposal,
+            sender,
+        }
+    }
+
+    /// FFI accessor for the sender.
+    #[cfg(all(feature = "ffi", not(test)))]
+    pub fn sender(&self) -> &Sender {
+        &self.sender
+    }
+
+    /// FFI accessor for the source.
+    #[cfg(all(feature = "ffi", not(test)))]
+    pub fn source(&self) -> &ProposalSource {
+        &self.source
+    }
+
+    /// Transform the proposal with `f`, keeping sender and source.
+    #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn map<U, F>(self, f: F) -> ProposalInfo<U>
+    where
+        F: FnOnce(T) -> U,
+    {
+        let ProposalInfo {
+            proposal,
+            sender,
+            source,
+        } = self;
+
+        ProposalInfo {
+            proposal: f(proposal),
+            sender,
+            source,
+        }
+    }
+
+    /// Borrow the proposal, copying the sender and cloning the source.
+    #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn as_ref(&self) -> ProposalInfo<&T> {
+        let source = self.source.clone();
+
+        ProposalInfo {
+            proposal: &self.proposal,
+            sender: self.sender,
+            source,
+        }
+    }
+
+    /// `true` only when the source is [`ProposalSource::ByValue`].
+    #[inline(always)]
+    pub fn is_by_value(&self) -> bool {
+        matches!(self.source, ProposalSource::ByValue)
+    }
+
+    /// `true` for every source other than [`ProposalSource::ByValue`].
+    ///
+    /// [`ProposalSource::Local`] proposals therefore also count as by-reference.
+    #[inline(always)]
+    pub fn is_by_reference(&self) -> bool {
+        !matches!(self.source, ProposalSource::ByValue)
+    }
+
+    /// The [`ProposalRef`] of this proposal if its source is [`ProposalSource::ByReference`]
+    #[cfg(feature = "by_ref_proposal")]
+    pub fn proposal_ref(&self) -> Option<&ProposalRef> {
+        if let ProposalSource::ByReference(reference) = &self.source {
+            Some(reference)
+        } else {
+            None
+        }
+    }
+}
+
+// FFI builds only. NOTE(review): presumably this makes safer_ffi_gen emit
+// bindings for the concrete instantiation `ProposalInfo<Proposal>` under the
+// name `ProposalInfoFfi`, since generic types cannot be exported directly —
+// confirm against the safer_ffi_gen docs.
+#[cfg(all(feature = "ffi", not(test)))]
+safer_ffi_gen::specialize!(ProposalInfoFfi = ProposalInfo<Proposal>);
+
+/// A standard MLS proposal type that has its own list inside a
+/// [`ProposalBundle`].
+pub trait Proposable: Sized {
+    /// The MLS proposal type corresponding to `Self`.
+    const TYPE: ProposalType;
+
+    /// The proposals of this type held by `bundle`.
+    fn filter(bundle: &ProposalBundle) -> &[ProposalInfo<Self>];
+
+    /// Remove the proposal of this type at `index` from `bundle`.
+    fn remove(bundle: &mut ProposalBundle, index: usize);
+
+    /// Keep only the proposals of this type for which `keep` returns `true`.
+    fn retain<F: FnMut(&ProposalInfo<Self>) -> bool>(bundle: &mut ProposalBundle, keep: F);
+}
+
+/// Implements [`Proposable`] for `$ty`, backed by the `ProposalBundle` list
+/// named `$field` and tagged with `ProposalType::$proposal_type`.
+macro_rules! impl_proposable {
+    ($ty:ty, $proposal_type:ident, $field:ident) => {
+        impl Proposable for $ty {
+            const TYPE: ProposalType = ProposalType::$proposal_type;
+
+            fn filter(bundle: &ProposalBundle) -> &[ProposalInfo<Self>] {
+                bundle.$field.as_slice()
+            }
+
+            fn remove(bundle: &mut ProposalBundle, index: usize) {
+                let list = &mut bundle.$field;
+
+                // Out-of-range indices are ignored rather than panicking.
+                if index >= list.len() {
+                    return;
+                }
+
+                list.remove(index);
+            }
+
+            fn retain<F>(bundle: &mut ProposalBundle, keep: F)
+            where
+                F: FnMut(&ProposalInfo<Self>) -> bool,
+            {
+                let list = &mut bundle.$field;
+                list.retain(keep);
+            }
+        }
+    };
+}
+
+// One `Proposable` impl per standard proposal type that has its own list in
+// `ProposalBundle`. Custom proposals are deliberately not `Proposable`; the
+// bundle handles them through `retain_custom` / `custom_proposals`.
+impl_proposable!(AddProposal, ADD, additions);
+#[cfg(feature = "by_ref_proposal")]
+impl_proposable!(UpdateProposal, UPDATE, updates);
+impl_proposable!(RemoveProposal, REMOVE, removals);
+#[cfg(feature = "psk")]
+impl_proposable!(PreSharedKeyProposal, PSK, psks);
+impl_proposable!(ReInitProposal, RE_INIT, reinitializations);
+impl_proposable!(ExternalInit, EXTERNAL_INIT, external_initializations);
+impl_proposable!(
+    ExtensionList,
+    GROUP_CONTEXT_EXTENSIONS,
+    group_context_extensions
+);
diff --git a/src/group/proposal_filter/filtering.rs b/src/group/proposal_filter/filtering.rs
new file mode 100644
index 0000000..8e67ff5
--- /dev/null
+++ b/src/group/proposal_filter/filtering.rs
@@ -0,0 +1,580 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::MlsError,
+ group::{
+ proposal::ReInitProposal,
+ proposal_filter::{ProposalBundle, ProposalInfo},
+ AddProposal, ProposalType, RemoveProposal, Sender, UpdateProposal,
+ },
+ iter::wrap_iter,
+ protocol_version::ProtocolVersion,
+ time::MlsTime,
+ tree_kem::{
+ leaf_node_validator::{LeafNodeValidator, ValidationContext},
+ node::LeafIndex,
+ TreeKemPublic,
+ },
+ CipherSuiteProvider, ExtensionList,
+};
+
+use super::filtering_common::{filter_out_invalid_psks, ApplyProposalsOutput, ProposalApplier};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::extension::ExternalSendersExt;
+
+use alloc::vec::Vec;
+use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider, psk::PreSharedKeyStorage};
+
+#[cfg(any(
+ feature = "custom_proposal",
+ not(any(mls_build_async, feature = "rayon"))
+))]
+use itertools::Itertools;
+
+use crate::group::ExternalInit;
+
+#[cfg(feature = "psk")]
+use crate::group::proposal::PreSharedKeyProposal;
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
+
+#[cfg(mls_build_async)]
+use futures::{StreamExt, TryStreamExt};
+
+// Proposal filtering and application used when the `by_ref_proposal` feature
+// is enabled (this module replaces `filtering_lite` in that configuration).
+impl<'a, C, P, CSP> ProposalApplier<'a, C, P, CSP>
+where
+    C: IdentityProvider,
+    P: PreSharedKeyStorage,
+    CSP: CipherSuiteProvider,
+{
+    /// Filter and apply the proposals of a commit sent by the member at
+    /// `commit_sender`.
+    ///
+    /// Each filter either drops an invalid proposal or returns an error,
+    /// according to `strategy`. After filtering, the remaining proposals are
+    /// applied via `apply_proposal_changes`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(super) async fn apply_proposals_from_member(
+        &self,
+        strategy: FilterStrategy,
+        commit_sender: LeafIndex,
+        proposals: ProposalBundle,
+        commit_time: Option<MlsTime>,
+    ) -> Result<ApplyProposalsOutput, MlsError> {
+        let proposals = filter_out_invalid_proposers(strategy, proposals)?;
+
+        let mut proposals: ProposalBundle =
+            filter_out_update_for_committer(strategy, commit_sender, proposals)?;
+
+        // We ignore the strategy here because the check above ensures all updates are from members
+        // (`proposer_can_propose` only allows UPDATE for `Sender::Member`).
+        // This is where `update_senders` gets (re)computed, index-aligned
+        // with `updates`.
+        proposals.update_senders = proposals
+            .updates
+            .iter()
+            .map(leaf_index_of_update_sender)
+            .collect::<Result<_, _>>()?;
+
+        let mut proposals = filter_out_removal_of_committer(strategy, commit_sender, proposals)?;
+
+        filter_out_invalid_psks(
+            strategy,
+            self.cipher_suite_provider,
+            &mut proposals,
+            self.psk_storage,
+        )
+        .await?;
+
+        #[cfg(feature = "by_ref_proposal")]
+        let proposals = filter_out_invalid_group_extensions(
+            strategy,
+            proposals,
+            self.identity_provider,
+            commit_time,
+        )
+        .await?;
+
+        let proposals = filter_out_extra_group_context_extensions(strategy, proposals)?;
+        let proposals = filter_out_invalid_reinit(strategy, proposals, self.protocol_version)?;
+        let proposals = filter_out_reinit_if_other_proposals(strategy.is_ignore(), proposals)?;
+
+        let proposals = filter_out_external_init(strategy, proposals)?;
+
+        self.apply_proposal_changes(strategy, proposals, commit_time)
+            .await
+    }
+
+    /// Apply already-filtered `proposals`.
+    ///
+    /// If the bundle has a GroupContextExtensions proposal, its first such
+    /// proposal is handed to `apply_proposals_with_new_capabilities` (defined
+    /// elsewhere). Otherwise the tree changes are validated against the
+    /// group's original extensions.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(super) async fn apply_proposal_changes(
+        &self,
+        strategy: FilterStrategy,
+        proposals: ProposalBundle,
+        commit_time: Option<MlsTime>,
+    ) -> Result<ApplyProposalsOutput, MlsError> {
+        match proposals.group_context_extensions_proposal().cloned() {
+            Some(p) => {
+                self.apply_proposals_with_new_capabilities(strategy, proposals, p, commit_time)
+                    .await
+            }
+            None => {
+                self.apply_tree_changes(
+                    strategy,
+                    proposals,
+                    self.original_group_extensions,
+                    commit_time,
+                )
+                .await
+            }
+        }
+    }
+
+    /// Validate new leaf nodes, then apply the surviving proposals to a clone
+    /// of the original tree with `batch_edit`.
+    ///
+    /// `external_init_index` is always `None` on this path. The new context
+    /// extensions are the first remaining GroupContextExtensions proposal, if
+    /// any.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(super) async fn apply_tree_changes(
+        &self,
+        strategy: FilterStrategy,
+        proposals: ProposalBundle,
+        group_extensions_in_use: &ExtensionList,
+        commit_time: Option<MlsTime>,
+    ) -> Result<ApplyProposalsOutput, MlsError> {
+        let mut applied_proposals = self
+            .validate_new_nodes(strategy, proposals, group_extensions_in_use, commit_time)
+            .await?;
+
+        // The original tree is never mutated; edits go into a copy.
+        let mut new_tree = self.original_tree.clone();
+
+        let added = new_tree
+            .batch_edit(
+                &mut applied_proposals,
+                group_extensions_in_use,
+                self.identity_provider,
+                self.cipher_suite_provider,
+                strategy.is_ignore(),
+            )
+            .await?;
+
+        let new_context_extensions = applied_proposals
+            .group_context_extensions_proposal()
+            .map(|gce| gce.proposal.clone());
+
+        Ok(ApplyProposalsOutput {
+            applied_proposals,
+            new_tree,
+            indexes_of_added_kpkgs: added,
+            external_init_index: None,
+            new_context_extensions,
+        })
+    }
+
+    /// Validate the leaf nodes introduced by Update and Add proposals, and drop
+    /// or reject failing ones according to `strategy`.
+    ///
+    /// An Update is checked with the leaf node validator and with the identity
+    /// provider's `valid_successor` against the sender's current leaf.
+    /// NOTE(review): this assumes `update_senders` is index-aligned with
+    /// `updates`, as set up in `apply_proposals_from_member`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn validate_new_nodes(
+        &self,
+        strategy: FilterStrategy,
+        mut proposals: ProposalBundle,
+        group_extensions_in_use: &ExtensionList,
+        commit_time: Option<MlsTime>,
+    ) -> Result<ProposalBundle, MlsError> {
+        let leaf_node_validator = &LeafNodeValidator::new(
+            self.cipher_suite_provider,
+            self.identity_provider,
+            Some(group_extensions_in_use),
+        );
+
+        // Indices (into `updates`) of Update proposals to drop. `wrap_iter`'s
+        // concrete iterator kind depends on the build configuration (see the
+        // cfg'd rayon / futures imports).
+        let bad_indices: Vec<_> = wrap_iter(proposals.update_proposals())
+            .zip(wrap_iter(proposals.update_proposal_senders()))
+            .enumerate()
+            .filter_map(|(i, (p, &sender_index))| async move {
+                let res = {
+                    let leaf = &p.proposal.leaf_node;
+
+                    let res = leaf_node_validator
+                        .check_if_valid(
+                            leaf,
+                            ValidationContext::Update((self.group_id, *sender_index, commit_time)),
+                        )
+                        .await;
+
+                    // Failing to find the sender's current leaf is a hard
+                    // error, independent of `strategy`.
+                    let old_leaf = match self.original_tree.get_leaf_node(sender_index) {
+                        Ok(leaf) => leaf,
+                        Err(e) => return Some(Err(e)),
+                    };
+
+                    let valid_successor = self
+                        .identity_provider
+                        .valid_successor(
+                            &old_leaf.signing_identity,
+                            &leaf.signing_identity,
+                            group_extensions_in_use,
+                        )
+                        .await
+                        .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))
+                        .and_then(|valid| valid.then_some(()).ok_or(MlsError::InvalidSuccessor));
+
+                    // The first failure (validation, then successor) wins.
+                    res.and(valid_successor)
+                };
+
+                // Ok(false) from apply_strategy means "drop" -> Some(Ok(i));
+                // Ok(true) means "keep" -> None; Err propagates.
+                apply_strategy(strategy, p.is_by_reference(), res)
+                    .map(|b| (!b).then_some(i))
+                    .transpose()
+            })
+            .try_collect()
+            .await?;
+
+        // Remove back-to-front (assumes ascending indices) so earlier indices
+        // stay valid, keeping `update_senders` aligned with `updates`.
+        bad_indices.into_iter().rev().for_each(|i| {
+            proposals.remove::<UpdateProposal>(i);
+            proposals.update_senders.remove(i);
+        });
+
+        // Indices (into `additions`) of Add proposals whose key package fails
+        // validation and should be dropped.
+        let bad_indices: Vec<_> = wrap_iter(proposals.add_proposals())
+            .enumerate()
+            .filter_map(|(i, p)| async move {
+                let res = self
+                    .validate_new_node(leaf_node_validator, &p.proposal.key_package, commit_time)
+                    .await;
+
+                apply_strategy(strategy, p.is_by_reference(), res)
+                    .map(|b| (!b).then_some(i))
+                    .transpose()
+            })
+            .try_collect()
+            .await?;
+
+        bad_indices
+            .into_iter()
+            .rev()
+            .for_each(|i| proposals.remove::<AddProposal>(i));
+
+        Ok(proposals)
+    }
+}
+
+/// How a failed validation check on a proposal is handled while filtering.
+#[derive(Clone, Copy, Debug)]
+pub enum FilterStrategy {
+    /// A failure on a proposal for which [`ProposalInfo::is_by_reference`] is
+    /// `true` silently drops that proposal. Failures on by-value proposals are
+    /// returned as errors.
+    IgnoreByRef,
+    /// Every failure is returned as an error.
+    IgnoreNone,
+}
+
+impl FilterStrategy {
+    /// Whether a failed check on a proposal should be swallowed (the proposal
+    /// dropped) instead of reported, given whether it arrived by reference.
+    pub(super) fn ignore(self, by_ref: bool) -> bool {
+        by_ref && matches!(self, FilterStrategy::IgnoreByRef)
+    }
+
+    /// `true` for strategies that may drop proposals instead of failing.
+    fn is_ignore(self) -> bool {
+        matches!(self, FilterStrategy::IgnoreByRef)
+    }
+}
+
+/// Turn the outcome `r` of a validation check into a keep/drop decision.
+///
+/// Returns `Ok(true)` (keep) when the check passed, and `Ok(false)` (drop) when
+/// it failed but `strategy` ignores failures for this kind of proposal.
+/// Otherwise the check's error is propagated.
+pub(crate) fn apply_strategy(
+    strategy: FilterStrategy,
+    by_ref: bool,
+    r: Result<(), MlsError>,
+) -> Result<bool, MlsError> {
+    match r {
+        Ok(()) => Ok(true),
+        Err(_) if strategy.ignore(by_ref) => Ok(false),
+        Err(error) => Err(error),
+    }
+}
+
+/// Drop, or reject according to `strategy`, Update proposals sent by the
+/// committer itself (`MlsError::InvalidCommitSelfUpdate`).
+fn filter_out_update_for_committer(
+    strategy: FilterStrategy,
+    commit_sender: LeafIndex,
+    mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+    let committer = Sender::Member(*commit_sender);
+
+    proposals.retain_by_type::<UpdateProposal, _, _>(|update| {
+        let check = if update.sender == committer {
+            Err(MlsError::InvalidCommitSelfUpdate)
+        } else {
+            Ok(())
+        };
+
+        apply_strategy(strategy, update.is_by_reference(), check)
+    })?;
+
+    Ok(proposals)
+}
+
+/// Drop, or reject according to `strategy`, Remove proposals that target the
+/// committer's own leaf (`MlsError::CommitterSelfRemoval`).
+fn filter_out_removal_of_committer(
+    strategy: FilterStrategy,
+    commit_sender: LeafIndex,
+    mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+    proposals.retain_by_type::<RemoveProposal, _, _>(|removal| {
+        let check = if removal.proposal.to_remove == commit_sender {
+            Err(MlsError::CommitterSelfRemoval)
+        } else {
+            Ok(())
+        };
+
+        apply_strategy(strategy, removal.is_by_reference(), check)
+    })?;
+
+    Ok(proposals)
+}
+
+/// Drop, or reject according to `strategy`, GroupContextExtensions proposals
+/// whose `ExternalSendersExt` fails to decode or fails `verify_all` with the
+/// identity provider. Proposals without that extension pass.
+#[cfg(feature = "by_ref_proposal")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn filter_out_invalid_group_extensions<C>(
+    strategy: FilterStrategy,
+    mut proposals: ProposalBundle,
+    identity_provider: &C,
+    commit_time: Option<MlsTime>,
+) -> Result<ProposalBundle, MlsError>
+where
+    C: IdentityProvider,
+{
+    let mut rejected = Vec::new();
+
+    for (index, info) in proposals.by_type::<ExtensionList>().enumerate() {
+        let verification = match info.proposal.get_as::<ExternalSendersExt>() {
+            Err(e) => Err(MlsError::from(e)),
+            Ok(None) => Ok(()),
+            Ok(Some(external_senders)) => external_senders
+                .verify_all(identity_provider, commit_time, &info.proposal)
+                .await
+                .map_err(|e| MlsError::IdentityProviderError(e.into_any_error())),
+        };
+
+        let keep = apply_strategy(strategy, info.is_by_reference(), verification)?;
+
+        if !keep {
+            rejected.push(index);
+        }
+    }
+
+    // Remove back-to-front so the remaining indices stay valid.
+    for index in rejected.into_iter().rev() {
+        proposals.remove::<ExtensionList>(index);
+    }
+
+    Ok(proposals)
+}
+
+/// Keep only the first GroupContextExtensions proposal. Every later one is
+/// dropped or rejected according to `strategy`
+/// (`MlsError::MoreThanOneGroupContextExtensionsProposal`).
+fn filter_out_extra_group_context_extensions(
+    strategy: FilterStrategy,
+    mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+    let mut seen_one = false;
+
+    proposals.retain_by_type::<ExtensionList, _, _>(|p| {
+        let is_extra = seen_one;
+        seen_one = true;
+
+        let check = if is_extra {
+            Err(MlsError::MoreThanOneGroupContextExtensionsProposal)
+        } else {
+            Ok(())
+        };
+
+        apply_strategy(strategy, p.is_by_reference(), check)
+    })?;
+
+    Ok(proposals)
+}
+
+/// Drop, or reject according to `strategy`, ReInit proposals that request a
+/// protocol version lower than the group's current `protocol_version`.
+fn filter_out_invalid_reinit(
+    strategy: FilterStrategy,
+    mut proposals: ProposalBundle,
+    protocol_version: ProtocolVersion,
+) -> Result<ProposalBundle, MlsError> {
+    proposals.retain_by_type::<ReInitProposal, _, _>(|reinit| {
+        let version_ok = reinit.proposal.version >= protocol_version;
+
+        let check = if version_ok {
+            Ok(())
+        } else {
+            Err(MlsError::InvalidProtocolVersionInReInit)
+        };
+
+        apply_strategy(strategy, reinit.is_by_reference(), check)
+    })?;
+
+    Ok(proposals)
+}
+
+/// Enforce that a ReInit proposal is not combined with any other proposal.
+///
+/// When a ReInit coexists with other proposals:
+/// - if any ReInit is by value, or `filter` is `false`, this fails with
+///   `MlsError::OtherProposalWithReInit`;
+/// - otherwise, if proposals of other types exist, all ReInits are dropped;
+/// - otherwise (only ReInits), just the first one is kept.
+fn filter_out_reinit_if_other_proposals(
+    filter: bool,
+    mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+    let total = proposals.length();
+    let reinit_count = proposals.reinit_proposals().len();
+
+    // Nothing to resolve unless a ReInit shares the bundle with something else.
+    if reinit_count == 0 || total == 1 {
+        return Ok(proposals);
+    }
+
+    let has_by_value_reinit = proposals
+        .reinit_proposals()
+        .iter()
+        .any(|reinit| reinit.is_by_value());
+
+    if !filter || has_by_value_reinit {
+        return Err(MlsError::OtherProposalWithReInit);
+    }
+
+    if total > reinit_count {
+        proposals.reinitializations.clear();
+    } else {
+        proposals.reinitializations.truncate(1);
+    }
+
+    Ok(proposals)
+}
+
+/// Treat every ExternalInit proposal as invalid
+/// (`MlsError::InvalidProposalTypeForSender`). Whether that drops the proposal
+/// or fails the whole bundle is decided by `strategy`.
+fn filter_out_external_init(
+    strategy: FilterStrategy,
+    mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+    proposals.retain_by_type::<ExternalInit, _, _>(|external_init| {
+        let by_ref = external_init.is_by_reference();
+        apply_strategy(strategy, by_ref, Err(MlsError::InvalidProposalTypeForSender))
+    })?;
+
+    Ok(proposals)
+}
+
+/// Check whether `proposer` may send a proposal of `proposal_type`, either by
+/// reference (`by_ref == true`) or by value.
+///
+/// Returns `MlsError::InvalidProposalTypeForSender` when not allowed.
+pub(crate) fn proposer_can_propose(
+    proposer: Sender,
+    proposal_type: ProposalType,
+    by_ref: bool,
+) -> Result<(), MlsError> {
+    let allowed = match proposer {
+        // Members may send these either way; Update only by reference.
+        Sender::Member(_) => {
+            matches!(
+                proposal_type,
+                ProposalType::ADD
+                    | ProposalType::REMOVE
+                    | ProposalType::PSK
+                    | ProposalType::RE_INIT
+                    | ProposalType::GROUP_CONTEXT_EXTENSIONS
+            ) || (by_ref && matches!(proposal_type, ProposalType::UPDATE))
+        }
+        // External senders are only accepted by reference.
+        #[cfg(feature = "by_ref_proposal")]
+        Sender::External(_) => {
+            by_ref
+                && matches!(
+                    proposal_type,
+                    ProposalType::ADD
+                        | ProposalType::REMOVE
+                        | ProposalType::RE_INIT
+                        | ProposalType::PSK
+                        | ProposalType::GROUP_CONTEXT_EXTENSIONS
+                )
+        }
+        // A new member's commit carries its proposals by value only.
+        Sender::NewMemberCommit => {
+            !by_ref
+                && matches!(
+                    proposal_type,
+                    ProposalType::REMOVE | ProposalType::PSK | ProposalType::EXTERNAL_INIT
+                )
+        }
+        // NewMemberProposal senders may only send Add, by reference.
+        Sender::NewMemberProposal => by_ref && matches!(proposal_type, ProposalType::ADD),
+    };
+
+    if allowed {
+        Ok(())
+    } else {
+        Err(MlsError::InvalidProposalTypeForSender)
+    }
+}
+
+/// Drop, or reject according to `strategy`, every proposal whose sender is not
+/// allowed to send that proposal type in that form (see `proposer_can_propose`).
+///
+/// Each list is walked back-to-front so that removing entry `i` never shifts
+/// the entries that are still to be checked.
+pub(crate) fn filter_out_invalid_proposers(
+    strategy: FilterStrategy,
+    mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+    for i in (0..proposals.add_proposals().len()).rev() {
+        let p = &proposals.add_proposals()[i];
+        let res = proposer_can_propose(p.sender, ProposalType::ADD, p.is_by_reference());
+
+        if !apply_strategy(strategy, p.is_by_reference(), res)? {
+            proposals.remove::<AddProposal>(i);
+        }
+    }
+
+    for i in (0..proposals.update_proposals().len()).rev() {
+        let p = &proposals.update_proposals()[i];
+        let res = proposer_can_propose(p.sender, ProposalType::UPDATE, p.is_by_reference());
+
+        if !apply_strategy(strategy, p.is_by_reference(), res)? {
+            proposals.remove::<UpdateProposal>(i);
+
+            // `update_senders` is only populated later in the pipeline (see
+            // `apply_proposals_from_member`), and `ProposalBundle::add` never
+            // fills it, so it is usually still empty here. An unconditional
+            // `Vec::remove` would panic on the out-of-bounds index.
+            if i < proposals.update_senders.len() {
+                proposals.update_senders.remove(i);
+            }
+        }
+    }
+
+    for i in (0..proposals.remove_proposals().len()).rev() {
+        let p = &proposals.remove_proposals()[i];
+        let res = proposer_can_propose(p.sender, ProposalType::REMOVE, p.is_by_reference());
+
+        if !apply_strategy(strategy, p.is_by_reference(), res)? {
+            proposals.remove::<RemoveProposal>(i);
+        }
+    }
+
+    #[cfg(feature = "psk")]
+    for i in (0..proposals.psk_proposals().len()).rev() {
+        let p = &proposals.psk_proposals()[i];
+        let res = proposer_can_propose(p.sender, ProposalType::PSK, p.is_by_reference());
+
+        if !apply_strategy(strategy, p.is_by_reference(), res)? {
+            proposals.remove::<PreSharedKeyProposal>(i);
+        }
+    }
+
+    for i in (0..proposals.reinit_proposals().len()).rev() {
+        let p = &proposals.reinit_proposals()[i];
+        let res = proposer_can_propose(p.sender, ProposalType::RE_INIT, p.is_by_reference());
+
+        if !apply_strategy(strategy, p.is_by_reference(), res)? {
+            proposals.remove::<ReInitProposal>(i);
+        }
+    }
+
+    for i in (0..proposals.external_init_proposals().len()).rev() {
+        let p = &proposals.external_init_proposals()[i];
+        let res = proposer_can_propose(p.sender, ProposalType::EXTERNAL_INIT, p.is_by_reference());
+
+        if !apply_strategy(strategy, p.is_by_reference(), res)? {
+            proposals.remove::<ExternalInit>(i);
+        }
+    }
+
+    for i in (0..proposals.group_context_ext_proposals().len()).rev() {
+        let p = &proposals.group_context_ext_proposals()[i];
+        let gce_type = ProposalType::GROUP_CONTEXT_EXTENSIONS;
+        let res = proposer_can_propose(p.sender, gce_type, p.is_by_reference());
+
+        if !apply_strategy(strategy, p.is_by_reference(), res)? {
+            proposals.remove::<ExtensionList>(i);
+        }
+    }
+
+    Ok(proposals)
+}
+
+/// Leaf index of the member that sent an Update proposal.
+///
+/// Non-member senders yield `MlsError::InvalidProposalTypeForSender`.
+fn leaf_index_of_update_sender(p: &ProposalInfo<UpdateProposal>) -> Result<LeafIndex, MlsError> {
+    if let Sender::Member(index) = p.sender {
+        Ok(LeafIndex(index))
+    } else {
+        Err(MlsError::InvalidProposalTypeForSender)
+    }
+}
+
+/// Drop, or reject according to `strategy`, custom proposals whose type the
+/// ratchet tree cannot support (`MlsError::UnsupportedCustomProposal`).
+#[cfg(feature = "custom_proposal")]
+pub(super) fn filter_out_unsupported_custom_proposals(
+    proposals: &mut ProposalBundle,
+    tree: &TreeKemPublic,
+    strategy: FilterStrategy,
+) -> Result<(), MlsError> {
+    // Ask the tree once per distinct custom type, not once per proposal.
+    let supported = proposals
+        .custom_proposal_types()
+        .filter(|proposal_type| tree.can_support_proposal(*proposal_type))
+        .collect_vec();
+
+    proposals.retain_custom(|custom| {
+        let proposal_type = custom.proposal.proposal_type();
+
+        let check = if supported.contains(&proposal_type) {
+            Ok(())
+        } else {
+            Err(MlsError::UnsupportedCustomProposal(proposal_type))
+        };
+
+        apply_strategy(strategy, custom.is_by_reference(), check)
+    })
+}
diff --git a/src/group/proposal_filter/filtering_common.rs b/src/group/proposal_filter/filtering_common.rs
new file mode 100644
index 0000000..278c0de
--- /dev/null
+++ b/src/group/proposal_filter/filtering_common.rs
@@ -0,0 +1,579 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::MlsError,
+ group::{proposal_filter::ProposalBundle, Sender},
+ key_package::{validate_key_package_properties, KeyPackage},
+ protocol_version::ProtocolVersion,
+ time::MlsTime,
+ tree_kem::{
+ leaf_node_validator::{LeafNodeValidator, ValidationContext},
+ node::LeafIndex,
+ TreeKemPublic,
+ },
+ CipherSuiteProvider, ExtensionList,
+};
+
+use crate::tree_kem::leaf_node::LeafNode;
+
+use super::ProposalInfo;
+
+use crate::extension::{MlsExtension, RequiredCapabilitiesExt};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::extension::ExternalSendersExt;
+
+use mls_rs_core::error::IntoAnyError;
+
+use alloc::vec::Vec;
+use mls_rs_core::{identity::IdentityProvider, psk::PreSharedKeyStorage};
+
+use crate::group::{ExternalInit, ProposalType, RemoveProposal};
+
+#[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+use crate::group::proposal::PreSharedKeyProposal;
+
+#[cfg(feature = "psk")]
+use crate::group::{JustPreSharedKeyID, ResumptionPSKUsage, ResumptionPsk};
+
+#[cfg(all(feature = "std", feature = "psk"))]
+use std::collections::HashSet;
+
+#[cfg(feature = "by_ref_proposal")]
+use super::filtering::{apply_strategy, filter_out_invalid_proposers, FilterStrategy};
+
+#[cfg(feature = "custom_proposal")]
+use super::filtering::filter_out_unsupported_custom_proposals;
+
+/// Inputs needed to apply a `ProposalBundle` on top of the current group state.
+///
+/// Apart from `protocol_version`, every field is a shared borrow (or an optional one), so the
+/// applier reads the original tree and extensions but does not modify them.
+#[derive(Debug)]
+pub(crate) struct ProposalApplier<'a, C, P, CSP> {
+ /// Tree of the current epoch, before any proposal is applied.
+ pub original_tree: &'a TreeKemPublic,
+ /// Group protocol version; used when validating new key packages and `ReInit` proposals.
+ pub protocol_version: ProtocolVersion,
+ /// Provider for the group's cipher suite.
+ pub cipher_suite_provider: &'a CSP,
+ /// Group context extensions in effect before this commit.
+ pub original_group_extensions: &'a ExtensionList,
+ /// Leaf supplied by an external committer; required for `NewMemberCommit` senders.
+ pub external_leaf: Option<&'a LeafNode>,
+ /// Identity provider used for leaf validation and successor checks.
+ pub identity_provider: &'a C,
+ /// Storage consulted to check that referenced external PSKs exist.
+ pub psk_storage: &'a P,
+ /// Identifier of the group (only present with `by_ref_proposal`).
+ #[cfg(feature = "by_ref_proposal")]
+ pub group_id: &'a [u8],
+}
+
+/// Result of applying a proposal bundle.
+#[derive(Debug)]
+pub(crate) struct ApplyProposalsOutput {
+ /// Tree after all applied proposals have been processed.
+ pub(crate) new_tree: TreeKemPublic,
+ /// Leaf indexes reported by the tree edit for key packages added by `Add` proposals.
+ pub(crate) indexes_of_added_kpkgs: Vec<LeafIndex>,
+ /// Leaf index of an external committer's new leaf; `None` unless the commit came from a
+ /// `NewMemberCommit` sender.
+ pub(crate) external_init_index: Option<LeafIndex>,
+ /// Proposals that survived filtering (only tracked with `by_ref_proposal`).
+ #[cfg(feature = "by_ref_proposal")]
+ pub(crate) applied_proposals: ProposalBundle,
+ /// Extensions taken from the applied group context extensions proposal, if any.
+ pub(crate) new_context_extensions: Option<ExtensionList>,
+}
+
+impl<'a, C, P, CSP> ProposalApplier<'a, C, P, CSP>
+where
+ C: IdentityProvider,
+ P: PreSharedKeyStorage,
+ CSP: CipherSuiteProvider,
+{
+ /// Creates an applier from borrowed group state; no validation happens here.
+ #[allow(clippy::too_many_arguments)]
+ pub(crate) fn new(
+ original_tree: &'a TreeKemPublic,
+ protocol_version: ProtocolVersion,
+ cipher_suite_provider: &'a CSP,
+ original_group_extensions: &'a ExtensionList,
+ external_leaf: Option<&'a LeafNode>,
+ identity_provider: &'a C,
+ psk_storage: &'a P,
+ #[cfg(feature = "by_ref_proposal")] group_id: &'a [u8],
+ ) -> Self {
+ Self {
+ original_tree,
+ protocol_version,
+ cipher_suite_provider,
+ original_group_extensions,
+ external_leaf,
+ identity_provider,
+ psk_storage,
+ #[cfg(feature = "by_ref_proposal")]
+ group_id,
+ }
+ }
+
+ /// Validates and applies `proposals` committed by `commit_sender`.
+ ///
+ /// `Sender::Member` and `Sender::NewMemberCommit` take different validation paths; the
+ /// external/new-member proposal senders are rejected with `ExternalSenderCannotCommit`.
+ /// With `custom_proposal`, custom proposals whose type the resulting tree cannot support
+ /// are then handled per `strategy` (by-ref builds) or rejected with an error.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn apply_proposals(
+ &self,
+ #[cfg(feature = "by_ref_proposal")] strategy: FilterStrategy,
+ commit_sender: &Sender,
+ #[cfg(not(feature = "by_ref_proposal"))] proposals: &ProposalBundle,
+ #[cfg(feature = "by_ref_proposal")] proposals: ProposalBundle,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ let output = match commit_sender {
+ Sender::Member(sender) => {
+ self.apply_proposals_from_member(
+ #[cfg(feature = "by_ref_proposal")]
+ strategy,
+ LeafIndex(*sender),
+ proposals,
+ commit_time,
+ )
+ .await
+ }
+ Sender::NewMemberCommit => {
+ self.apply_proposals_from_new_member(proposals, commit_time)
+ .await
+ }
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::External(_) => Err(MlsError::ExternalSenderCannotCommit),
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::NewMemberProposal => Err(MlsError::ExternalSenderCannotCommit),
+ }?;
+
+ // Rebind mutably so the by-ref filter below can drop entries from `applied_proposals`.
+ #[cfg(all(feature = "by_ref_proposal", feature = "custom_proposal"))]
+ let mut output = output;
+
+ // Custom proposal support is checked against the tree *after* the proposals are applied.
+ #[cfg(all(feature = "by_ref_proposal", feature = "custom_proposal"))]
+ filter_out_unsupported_custom_proposals(
+ &mut output.applied_proposals,
+ &output.new_tree,
+ strategy,
+ )?;
+
+ #[cfg(all(not(feature = "by_ref_proposal"), feature = "custom_proposal"))]
+ filter_out_unsupported_custom_proposals(proposals, &output.new_tree)?;
+
+ Ok(output)
+ }
+
+ /// Validates and applies the proposals of an external (`NewMemberCommit`) commit.
+ ///
+ /// Requires `self.external_leaf`, exactly one `ExternalInit`, and at most one `Remove`. That
+ /// `Remove` must target a member of which the external leaf is a valid successor. Only
+ /// allowed default proposal types may appear, and no default proposal may be by reference.
+ /// Proposer and PSK filtering use `FilterStrategy::IgnoreNone`. The external leaf is
+ /// inserted into the new tree after all other changes, and its index is reported in
+ /// `external_init_index`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ // The lint below is triggered by the `proposals` parameter which may or may not be a borrow.
+ #[allow(clippy::needless_borrow)]
+ async fn apply_proposals_from_new_member(
+ &self,
+ #[cfg(not(feature = "by_ref_proposal"))] proposals: &ProposalBundle,
+ #[cfg(feature = "by_ref_proposal")] proposals: ProposalBundle,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ let external_leaf = self
+ .external_leaf
+ .ok_or(MlsError::ExternalCommitMustHaveNewLeaf)?;
+
+ ensure_exactly_one_external_init(&proposals)?;
+
+ ensure_at_most_one_removal_for_self(
+ &proposals,
+ external_leaf,
+ self.original_tree,
+ self.identity_provider,
+ self.original_group_extensions,
+ )
+ .await?;
+
+ ensure_proposals_in_external_commit_are_allowed(&proposals)?;
+ ensure_no_proposal_by_ref(&proposals)?;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let mut proposals = filter_out_invalid_proposers(FilterStrategy::IgnoreNone, proposals)?;
+
+ filter_out_invalid_psks(
+ #[cfg(feature = "by_ref_proposal")]
+ FilterStrategy::IgnoreNone,
+ self.cipher_suite_provider,
+ #[cfg(feature = "by_ref_proposal")]
+ &mut proposals,
+ #[cfg(not(feature = "by_ref_proposal"))]
+ proposals,
+ self.psk_storage,
+ )
+ .await?;
+
+ let mut output = self
+ .apply_proposal_changes(
+ #[cfg(feature = "by_ref_proposal")]
+ FilterStrategy::IgnoreNone,
+ proposals,
+ commit_time,
+ )
+ .await?;
+
+ // The external committer's own leaf is added last, on top of the applied changes.
+ output.external_init_index = Some(
+ insert_external_leaf(
+ &mut output.new_tree,
+ external_leaf.clone(),
+ self.identity_provider,
+ self.original_group_extensions,
+ )
+ .await?,
+ );
+
+ Ok(output)
+ }
+
+ /// Applies `proposals` using the extensions of `group_context_extensions_proposal`, then
+ /// checks that every non-empty leaf of the resulting tree supports them.
+ ///
+ /// With `by_ref_proposal`, if that check fails and `strategy` ignores the extensions
+ /// proposal, the proposals are re-applied without it under the original group extensions;
+ /// otherwise the check's error is returned.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn apply_proposals_with_new_capabilities(
+ &self,
+ #[cfg(feature = "by_ref_proposal")] strategy: FilterStrategy,
+ #[cfg(not(feature = "by_ref_proposal"))] proposals: &ProposalBundle,
+ #[cfg(feature = "by_ref_proposal")] proposals: ProposalBundle,
+ group_context_extensions_proposal: ProposalInfo<ExtensionList>,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError>
+ where
+ C: IdentityProvider,
+ {
+ // Kept for the fallback path below, which needs the bundle after it was consumed.
+ #[cfg(feature = "by_ref_proposal")]
+ let mut proposals_clone = proposals.clone();
+
+ // Apply adds, updates etc. in the context of new extensions
+ let output = self
+ .apply_tree_changes(
+ #[cfg(feature = "by_ref_proposal")]
+ strategy,
+ proposals,
+ &group_context_extensions_proposal.proposal,
+ commit_time,
+ )
+ .await?;
+
+ // Verify that capabilities and extensions are supported after modifications.
+ // TODO: The newly inserted nodes have already been validated by `apply_tree_changes`
+ // above. We should investigate if there is an easy way to avoid the double check.
+ let must_check = group_context_extensions_proposal
+ .proposal
+ .has_extension(RequiredCapabilitiesExt::extension_type());
+
+ #[cfg(feature = "by_ref_proposal")]
+ let must_check = must_check
+ || group_context_extensions_proposal
+ .proposal
+ .has_extension(ExternalSendersExt::extension_type());
+
+ let new_capabilities_supported = if must_check {
+ let leaf_validator = LeafNodeValidator::new(
+ self.cipher_suite_provider,
+ self.identity_provider,
+ Some(&group_context_extensions_proposal.proposal),
+ );
+
+ output
+ .new_tree
+ .non_empty_leaves()
+ .try_for_each(|(_, leaf)| {
+ leaf_validator.validate_required_capabilities(leaf)?;
+
+ #[cfg(feature = "by_ref_proposal")]
+ leaf_validator.validate_external_senders_ext_credentials(leaf)?;
+
+ Ok(())
+ })
+ } else {
+ Ok(())
+ };
+
+ // Every non-default extension type in the proposal must be listed in the capabilities
+ // of every non-empty leaf; the first one that is not is reported.
+ let new_extensions_supported = group_context_extensions_proposal
+ .proposal
+ .iter()
+ .map(|extension| extension.extension_type)
+ .filter(|&ext_type| !ext_type.is_default())
+ .find(|ext_type| {
+ !output
+ .new_tree
+ .non_empty_leaves()
+ .all(|(_, leaf)| leaf.capabilities.extensions.contains(ext_type))
+ })
+ .map_or(Ok(()), |ext| Err(MlsError::UnsupportedGroupExtension(ext)));
+
+ #[cfg(not(feature = "by_ref_proposal"))]
+ {
+ new_capabilities_supported.and(new_extensions_supported)?;
+ Ok(output)
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ // If extensions are good, return `Ok`. If not and the strategy is to filter, remove the group
+ // context extensions proposal and try applying all proposals again in the context of the old
+ // extensions. Else, return an error.
+ match new_capabilities_supported.and(new_extensions_supported) {
+ Ok(()) => Ok(output),
+ Err(e) => {
+ if strategy.ignore(group_context_extensions_proposal.is_by_reference()) {
+ proposals_clone.group_context_extensions.clear();
+
+ self.apply_tree_changes(
+ strategy,
+ proposals_clone,
+ self.original_group_extensions,
+ commit_time,
+ )
+ .await
+ } else {
+ Err(e)
+ }
+ }
+ }
+ }
+
+ /// Validates the leaf node and the key package properties of a key package being added,
+ /// running the two checks one after the other.
+ #[cfg(any(mls_build_async, not(feature = "rayon")))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn validate_new_node<Ip: IdentityProvider, Cp: CipherSuiteProvider>(
+ &self,
+ leaf_node_validator: &LeafNodeValidator<'_, Ip, Cp>,
+ key_package: &KeyPackage,
+ commit_time: Option<MlsTime>,
+ ) -> Result<(), MlsError> {
+ leaf_node_validator
+ .check_if_valid(&key_package.leaf_node, ValidationContext::Add(commit_time))
+ .await?;
+
+ validate_key_package_properties(
+ key_package,
+ self.protocol_version,
+ self.cipher_suite_provider,
+ )
+ .await
+ }
+
+ /// Same checks as the sequential variant, run concurrently with `rayon::join`.
+ ///
+ /// If both checks fail, the leaf node validation error is the one returned.
+ #[cfg(all(not(mls_build_async), feature = "rayon"))]
+ pub fn validate_new_node<Ip: IdentityProvider, Cp: CipherSuiteProvider>(
+ &self,
+ leaf_node_validator: &LeafNodeValidator<'_, Ip, Cp>,
+ key_package: &KeyPackage,
+ commit_time: Option<MlsTime>,
+ ) -> Result<(), MlsError> {
+ let (a, b) = rayon::join(
+ || {
+ leaf_node_validator
+ .check_if_valid(&key_package.leaf_node, ValidationContext::Add(commit_time))
+ },
+ || {
+ validate_key_package_properties(
+ key_package,
+ self.protocol_version,
+ self.cipher_suite_provider,
+ )
+ },
+ );
+ a?;
+ b
+ }
+}
+
+/// Checks every PSK proposal in `proposals`.
+///
+/// A PSK proposal is invalid when:
+/// - its key id is neither external nor a resumption PSK with `Application` usage,
+/// - its nonce length differs from the cipher suite's `kdf_extract_size`,
+/// - an identical `psk` value (key id and nonce) appeared earlier in the bundle, or
+/// - it names an external PSK that `psk_storage` does not contain.
+///
+/// Without `by_ref_proposal` the first invalid proposal aborts with its error. With it, each
+/// result goes through `apply_strategy`, and proposals it does not keep are removed.
+#[cfg(feature = "psk")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn filter_out_invalid_psks<P, CP>(
+ #[cfg(feature = "by_ref_proposal")] strategy: FilterStrategy,
+ cipher_suite_provider: &CP,
+ #[cfg(not(feature = "by_ref_proposal"))] proposals: &ProposalBundle,
+ #[cfg(feature = "by_ref_proposal")] proposals: &mut ProposalBundle,
+ psk_storage: &P,
+) -> Result<(), MlsError>
+where
+ P: PreSharedKeyStorage,
+ CP: CipherSuiteProvider,
+{
+ let kdf_extract_size = cipher_suite_provider.kdf_extract_size();
+
+ // Duplicate detection: a hash set with `std`, a linear-scan vector otherwise.
+ #[cfg(feature = "std")]
+ let mut ids_seen = HashSet::new();
+
+ #[cfg(not(feature = "std"))]
+ let mut ids_seen = Vec::new();
+
+ // Indices are collected first and removed after the loop, because the loop reads the bundle.
+ #[cfg(feature = "by_ref_proposal")]
+ let mut bad_indices = Vec::new();
+
+ for i in 0..proposals.psk_proposals().len() {
+ let p = &proposals.psks[i];
+
+ let valid = matches!(
+ p.proposal.psk.key_id,
+ JustPreSharedKeyID::External(_)
+ | JustPreSharedKeyID::Resumption(ResumptionPsk {
+ usage: ResumptionPSKUsage::Application,
+ ..
+ })
+ );
+
+ let nonce_length = p.proposal.psk.psk_nonce.0.len();
+ let nonce_valid = nonce_length == kdf_extract_size;
+
+ #[cfg(feature = "std")]
+ let is_new_id = ids_seen.insert(p.proposal.psk.clone());
+
+ #[cfg(not(feature = "std"))]
+ let is_new_id = !ids_seen.contains(&p.proposal.psk);
+
+ // Only external PSKs are looked up in storage; resumption PSKs pass this check.
+ let external_id_is_valid = match &p.proposal.psk.key_id {
+ JustPreSharedKeyID::External(id) => psk_storage
+ .contains(id)
+ .await
+ .map_err(|e| MlsError::PskStoreError(e.into_any_error()))
+ .and_then(|found| {
+ if found {
+ Ok(())
+ } else {
+ Err(MlsError::MissingRequiredPsk)
+ }
+ }),
+ JustPreSharedKeyID::Resumption(_) => Ok(()),
+ };
+
+ #[cfg(not(feature = "by_ref_proposal"))]
+ if !valid {
+ return Err(MlsError::InvalidTypeOrUsageInPreSharedKeyProposal);
+ } else if !nonce_valid {
+ return Err(MlsError::InvalidPskNonceLength);
+ } else if !is_new_id {
+ return Err(MlsError::DuplicatePskIds);
+ } else if external_id_is_valid.is_err() {
+ return external_id_is_valid;
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ {
+ let res = if !valid {
+ Err(MlsError::InvalidTypeOrUsageInPreSharedKeyProposal)
+ } else if !nonce_valid {
+ Err(MlsError::InvalidPskNonceLength)
+ } else if !is_new_id {
+ Err(MlsError::DuplicatePskIds)
+ } else {
+ external_id_is_valid
+ };
+
+ if !apply_strategy(strategy, p.is_by_reference(), res)? {
+ bad_indices.push(i)
+ }
+ }
+
+ #[cfg(not(feature = "std"))]
+ ids_seen.push(p.proposal.psk.clone());
+ }
+
+ // Remove highest indices first so earlier removals do not shift the remaining ones.
+ #[cfg(feature = "by_ref_proposal")]
+ bad_indices
+ .into_iter()
+ .rev()
+ .for_each(|i| proposals.remove::<PreSharedKeyProposal>(i));
+
+ Ok(())
+}
+
+/// Stub used when the `psk` feature is disabled: every bundle is accepted unchanged.
+#[cfg(not(feature = "psk"))]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn filter_out_invalid_psks<P, CP>(
+    #[cfg(feature = "by_ref_proposal")] _strategy: FilterStrategy,
+    _cipher_suite_provider: &CP,
+    #[cfg(not(feature = "by_ref_proposal"))] _proposals: &ProposalBundle,
+    #[cfg(feature = "by_ref_proposal")] _proposals: &mut ProposalBundle,
+    _psk_storage: &P,
+) -> Result<(), MlsError>
+where
+    P: PreSharedKeyStorage,
+    CP: CipherSuiteProvider,
+{
+    Ok(())
+}
+
+/// Requires an external commit to carry exactly one `ExternalInit` proposal.
+fn ensure_exactly_one_external_init(proposals: &ProposalBundle) -> Result<(), MlsError> {
+    let external_init_count = proposals.by_type::<ExternalInit>().count();
+
+    if external_init_count == 1 {
+        Ok(())
+    } else {
+        Err(MlsError::ExternalCommitMustHaveExactlyOneExternalInit)
+    }
+}
+
+/// Rejects an external commit that contains a default proposal type other than
+/// `ExternalInit`, `Remove` or `PSK`.
+///
+/// Non-default proposal types are allowed by default. Custom `MlsRules` may disallow
+/// specific custom proposals in external commits.
+fn ensure_proposals_in_external_commit_are_allowed(
+    proposals: &ProposalBundle,
+) -> Result<(), MlsError> {
+    let allowed_default_types = [
+        ProposalType::EXTERNAL_INIT,
+        ProposalType::REMOVE,
+        ProposalType::PSK,
+    ];
+
+    let disallowed = proposals.proposal_types().find(|ty| {
+        let is_default_type = ProposalType::DEFAULT.contains(ty);
+        is_default_type && !allowed_default_types.contains(ty)
+    });
+
+    if let Some(kind) = disallowed {
+        return Err(MlsError::InvalidProposalTypeInExternalCommit(kind));
+    }
+
+    Ok(())
+}
+
+/// Allows at most one `Remove` proposal in an external commit.
+///
+/// If there is one, `ensure_removal_is_for_self` must accept it; two or more removals
+/// yield `ExternalCommitWithMoreThanOneRemove`.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn ensure_at_most_one_removal_for_self<C>(
+    proposals: &ProposalBundle,
+    external_leaf: &LeafNode,
+    tree: &TreeKemPublic,
+    identity_provider: &C,
+    extensions: &ExtensionList,
+) -> Result<(), MlsError>
+where
+    C: IdentityProvider,
+{
+    let mut removals = proposals.by_type::<RemoveProposal>();
+
+    let Some(only_removal) = removals.next() else {
+        return Ok(());
+    };
+
+    if removals.next().is_some() {
+        return Err(MlsError::ExternalCommitWithMoreThanOneRemove);
+    }
+
+    ensure_removal_is_for_self(
+        &only_removal.proposal,
+        external_leaf,
+        tree,
+        identity_provider,
+        extensions,
+    )
+    .await
+}
+
+/// Checks that the member removed by `removal` is one of which `external_leaf` is a valid
+/// successor, as decided by `identity_provider.valid_successor`.
+///
+/// Errors from the identity provider are wrapped in `IdentityProviderError`; a negative
+/// answer yields `ExternalCommitRemovesOtherIdentity`.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn ensure_removal_is_for_self<C>(
+    removal: &RemoveProposal,
+    external_leaf: &LeafNode,
+    tree: &TreeKemPublic,
+    identity_provider: &C,
+    extensions: &ExtensionList,
+) -> Result<(), MlsError>
+where
+    C: IdentityProvider,
+{
+    let removed_leaf = tree.get_leaf_node(removal.to_remove)?;
+
+    let is_successor = identity_provider
+        .valid_successor(
+            &removed_leaf.signing_identity,
+            &external_leaf.signing_identity,
+            extensions,
+        )
+        .await
+        .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+
+    if is_successor {
+        Ok(())
+    } else {
+        Err(MlsError::ExternalCommitRemovesOtherIdentity)
+    }
+}
+
+/// Rejects any default proposal type that is included by reference.
+///
+/// Non-default by-ref proposal types are allowed by default. Custom `MlsRules` may disallow
+/// specific custom by-ref proposals.
+fn ensure_no_proposal_by_ref(proposals: &ProposalBundle) -> Result<(), MlsError> {
+    let has_default_proposal_by_ref = proposals.iter_proposals().any(|p| {
+        !p.is_by_value() && ProposalType::DEFAULT.contains(&p.proposal.proposal_type())
+    });
+
+    if has_default_proposal_by_ref {
+        Err(MlsError::OnlyMembersCanCommitProposalsByRef)
+    } else {
+        Ok(())
+    }
+}
+
+/// Adds the external committer's leaf to `tree` and returns the index it was placed at.
+///
+/// NOTE(review): `None` is forwarded as `add_leaf`'s last argument; confirm its meaning
+/// against `TreeKemPublic::add_leaf`.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn insert_external_leaf<I: IdentityProvider>(
+    tree: &mut TreeKemPublic,
+    leaf_node: LeafNode,
+    identity_provider: &I,
+    extensions: &ExtensionList,
+) -> Result<LeafIndex, MlsError> {
+    let new_leaf_index = tree
+        .add_leaf(leaf_node, identity_provider, extensions, None)
+        .await?;
+
+    Ok(new_leaf_index)
+}
diff --git a/src/group/proposal_filter/filtering_lite.rs b/src/group/proposal_filter/filtering_lite.rs
new file mode 100644
index 0000000..09ca389
--- /dev/null
+++ b/src/group/proposal_filter/filtering_lite.rs
@@ -0,0 +1,225 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::MlsError,
+ group::proposal_filter::ProposalBundle,
+ iter::wrap_iter,
+ protocol_version::ProtocolVersion,
+ time::MlsTime,
+ tree_kem::{leaf_node_validator::LeafNodeValidator, node::LeafIndex},
+ CipherSuiteProvider, ExtensionList,
+};
+
+use super::filtering_common::{filter_out_invalid_psks, ApplyProposalsOutput, ProposalApplier};
+
+#[cfg(feature = "by_ref_proposal")]
+use {crate::extension::ExternalSendersExt, mls_rs_core::error::IntoAnyError};
+
+use mls_rs_core::{identity::IdentityProvider, psk::PreSharedKeyStorage};
+
+#[cfg(feature = "custom_proposal")]
+use itertools::Itertools;
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+use rayon::prelude::*;
+
+#[cfg(mls_build_async)]
+use futures::{StreamExt, TryStreamExt};
+
+#[cfg(feature = "custom_proposal")]
+use crate::tree_kem::TreeKemPublic;
+
+#[cfg(feature = "psk")]
+use crate::group::{
+ proposal::PreSharedKeyProposal, JustPreSharedKeyID, ResumptionPSKUsage, ResumptionPsk,
+};
+
+#[cfg(all(feature = "std", feature = "psk"))]
+use std::collections::HashSet;
+
+// "Lite" proposal application: every check below rejects the whole commit with an error
+// instead of dropping individual proposals, and the bundle is only borrowed.
+impl<'a, C, P, CSP> ProposalApplier<'a, C, P, CSP>
+where
+ C: IdentityProvider,
+ P: PreSharedKeyStorage,
+ CSP: CipherSuiteProvider,
+{
+ /// Validates a member's commit (no self-removal, valid PSKs, at most one group context
+ /// extensions proposal, valid `ReInit`, `ReInit` alone) and then applies it.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn apply_proposals_from_member(
+ &self,
+ commit_sender: LeafIndex,
+ proposals: &ProposalBundle,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ filter_out_removal_of_committer(commit_sender, proposals)?;
+ filter_out_invalid_psks(self.cipher_suite_provider, proposals, self.psk_storage).await?;
+
+ #[cfg(feature = "by_ref_proposal")]
+ filter_out_invalid_group_extensions(proposals, self.identity_provider, commit_time).await?;
+
+ filter_out_extra_group_context_extensions(proposals)?;
+ filter_out_invalid_reinit(proposals, self.protocol_version)?;
+ filter_out_reinit_if_other_proposals(proposals)?;
+
+ self.apply_proposal_changes(proposals, commit_time).await
+ }
+
+ /// Applies the bundle under the extensions of its group context extensions proposal if it
+ /// has one, otherwise under the original group extensions.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn apply_proposal_changes(
+ &self,
+ proposals: &ProposalBundle,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ match proposals.group_context_extensions_proposal().cloned() {
+ Some(p) => {
+ self.apply_proposals_with_new_capabilities(proposals, p, commit_time)
+ .await
+ }
+ None => {
+ self.apply_tree_changes(proposals, self.original_group_extensions, commit_time)
+ .await
+ }
+ }
+ }
+
+ /// Validates added key packages, then edits a clone of the original tree with the bundle.
+ ///
+ /// `new_context_extensions` is taken from the first group context extensions proposal.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn apply_tree_changes(
+ &self,
+ proposals: &ProposalBundle,
+ group_extensions_in_use: &ExtensionList,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ self.validate_new_nodes(proposals, group_extensions_in_use, commit_time)
+ .await?;
+
+ // The original tree is never modified; all edits go to this copy.
+ let mut new_tree = self.original_tree.clone();
+
+ let added = new_tree
+ .batch_edit_lite(
+ proposals,
+ group_extensions_in_use,
+ self.identity_provider,
+ self.cipher_suite_provider,
+ )
+ .await?;
+
+ let new_context_extensions = proposals
+ .group_context_extensions
+ .first()
+ .map(|gce| gce.proposal.clone());
+
+ Ok(ApplyProposalsOutput {
+ new_tree,
+ indexes_of_added_kpkgs: added,
+ external_init_index: None,
+ new_context_extensions,
+ })
+ }
+
+ /// Runs `validate_new_node` on the key package of every `Add` proposal, stopping at the
+ /// first failure.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn validate_new_nodes(
+ &self,
+ proposals: &ProposalBundle,
+ group_extensions_in_use: &ExtensionList,
+ commit_time: Option<MlsTime>,
+ ) -> Result<(), MlsError> {
+ let leaf_node_validator = &LeafNodeValidator::new(
+ self.cipher_suite_provider,
+ self.identity_provider,
+ Some(group_extensions_in_use),
+ );
+
+ let adds = wrap_iter(proposals.add_proposals());
+
+ // In async builds `adds` is presumably a stream; mapping items to `Ok` lets
+ // `TryStreamExt::try_for_each` consume it.
+ #[cfg(mls_build_async)]
+ let adds = adds.map(Ok);
+
+ // NOTE(review): `{ adds }` turns the binding into a value expression; presumably needed
+ // so `try_for_each` resolves for the iterator/stream type of each build — confirm.
+ { adds }
+ .try_for_each(|p| {
+ self.validate_new_node(leaf_node_validator, &p.proposal.key_package, commit_time)
+ })
+ .await
+ }
+}
+
+/// Rejects the commit if any `Remove` proposal targets the committer's own leaf.
+fn filter_out_removal_of_committer(
+    commit_sender: LeafIndex,
+    proposals: &ProposalBundle,
+) -> Result<(), MlsError> {
+    let removes_committer = proposals
+        .removals
+        .iter()
+        .any(|removal| removal.proposal.to_remove == commit_sender);
+
+    if removes_committer {
+        Err(MlsError::CommitterSelfRemoval)
+    } else {
+        Ok(())
+    }
+}
+
+/// Verifies the `ExternalSendersExt` carried by the first group context extensions proposal,
+/// if there is one, using `identity_provider` at `commit_time`.
+///
+/// Verification failures are wrapped in `MlsError::IdentityProviderError`.
+#[cfg(feature = "by_ref_proposal")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn filter_out_invalid_group_extensions<C>(
+    proposals: &ProposalBundle,
+    identity_provider: &C,
+    commit_time: Option<MlsTime>,
+) -> Result<(), MlsError>
+where
+    C: IdentityProvider,
+{
+    let Some(gce_proposal) = proposals.group_context_extensions.first() else {
+        return Ok(());
+    };
+
+    let Some(external_senders) = gce_proposal.proposal.get_as::<ExternalSendersExt>()? else {
+        return Ok(());
+    };
+
+    external_senders
+        .verify_all(identity_provider, commit_time, gce_proposal.proposal())
+        .await
+        .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+
+    Ok(())
+}
+
+/// Rejects a commit carrying more than one group context extensions proposal.
+fn filter_out_extra_group_context_extensions(proposals: &ProposalBundle) -> Result<(), MlsError> {
+    if proposals.group_context_extensions.len() >= 2 {
+        return Err(MlsError::MoreThanOneGroupContextExtensionsProposal);
+    }
+
+    Ok(())
+}
+
+/// Rejects a `ReInit` proposal (only the first one is inspected) whose version is lower than
+/// the group's current `protocol_version`.
+fn filter_out_invalid_reinit(
+    proposals: &ProposalBundle,
+    protocol_version: ProtocolVersion,
+) -> Result<(), MlsError> {
+    match proposals.reinitializations.first() {
+        Some(reinit) if reinit.proposal.version < protocol_version => {
+            Err(MlsError::InvalidProtocolVersionInReInit)
+        }
+        _ => Ok(()),
+    }
+}
+
+/// Rejects a commit in which a `ReInit` proposal is accompanied by any other proposal.
+fn filter_out_reinit_if_other_proposals(proposals: &ProposalBundle) -> Result<(), MlsError> {
+    let reinit_mixed_with_others =
+        !proposals.reinitializations.is_empty() && proposals.length() != 1;
+
+    if reinit_mixed_with_others {
+        return Err(MlsError::OtherProposalWithReInit);
+    }
+
+    Ok(())
+}
+
+/// Rejects the commit if it contains a custom proposal whose type `tree` cannot support
+/// (per `TreeKemPublic::can_support_proposal`); the first offending type is reported.
+#[cfg(feature = "custom_proposal")]
+pub(super) fn filter_out_unsupported_custom_proposals(
+    proposals: &ProposalBundle,
+    tree: &TreeKemPublic,
+) -> Result<(), MlsError> {
+    let supported_types = proposals
+        .custom_proposal_types()
+        .filter(|t| tree.can_support_proposal(*t))
+        .collect_vec();
+
+    let first_unsupported = proposals
+        .custom_proposals
+        .iter()
+        .map(|custom| custom.proposal.proposal_type())
+        .find(|proposal_type| !supported_types.contains(proposal_type));
+
+    match first_unsupported {
+        Some(proposal_type) => Err(MlsError::UnsupportedCustomProposal(proposal_type)),
+        None => Ok(()),
+    }
+}
diff --git a/src/group/proposal_ref.rs b/src/group/proposal_ref.rs
new file mode 100644
index 0000000..c97c9a1
--- /dev/null
+++ b/src/group/proposal_ref.rs
@@ -0,0 +1,226 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use core::ops::Deref;
+
+use super::*;
+use crate::hash_reference::HashReference;
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// Unique identifier for a proposal message.
+///
+/// The value is a `HashReference` computed over the MLS encoding of the
+/// `AuthenticatedContent` that carried the proposal, with the label
+/// `MLS 1.0 Proposal Reference` (see `ProposalRef::from_content`). It dereferences to the
+/// raw hash bytes.
+pub struct ProposalRef(HashReference);
+
+impl Deref for ProposalRef {
+    type Target = [u8];
+
+    /// Exposes the raw bytes of the underlying hash reference.
+    fn deref(&self) -> &[u8] {
+        &self.0[..]
+    }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl ProposalRef {
+    /// Computes the reference of a proposal by hashing the MLS encoding of `content` with
+    /// the label `MLS 1.0 Proposal Reference`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub(crate) async fn from_content<CS: CipherSuiteProvider>(
+        cipher_suite_provider: &CS,
+        content: &AuthenticatedContent,
+    ) -> Result<Self, MlsError> {
+        let encoded_content = content.mls_encode_to_vec()?;
+
+        let hash_ref = HashReference::compute(
+            &encoded_content,
+            b"MLS 1.0 Proposal Reference",
+            cipher_suite_provider,
+        )
+        .await?;
+
+        Ok(Self(hash_ref))
+    }
+
+    /// Returns the raw bytes of this reference.
+    pub fn as_slice(&self) -> &[u8] {
+        &self.0
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use super::*;
+    use crate::group::test_utils::{random_bytes, TEST_GROUP};
+    use alloc::boxed::Box;
+
+    impl ProposalRef {
+        /// Wraps arbitrary bytes as a proposal reference without hashing anything.
+        pub fn new_fake(bytes: Vec<u8>) -> Self {
+            Self(HashReference::from(bytes))
+        }
+    }
+
+    /// Builds a `PublicMessage` `AuthenticatedContent` for `proposal` in epoch 0 of
+    /// `TEST_GROUP`, signed with 128 random bytes instead of a real signature.
+    pub fn auth_content_from_proposal<S>(proposal: Proposal, sender: S) -> AuthenticatedContent
+    where
+        S: Into<Sender>,
+    {
+        let framed_content = FramedContent {
+            group_id: TEST_GROUP.to_vec(),
+            epoch: 0,
+            sender: sender.into(),
+            authenticated_data: vec![],
+            content: Content::Proposal(Box::new(proposal)),
+        };
+
+        let auth_data = FramedContentAuthData {
+            signature: MessageSignature::from(random_bytes(128)),
+            confirmation_tag: None,
+        };
+
+        AuthenticatedContent {
+            wire_format: WireFormat::PublicMessage,
+            content: framed_content,
+            auth: auth_data,
+        }
+    }
+}
+
+// Known-answer tests for `ProposalRef::from_content`.
+#[cfg(test)]
+mod test {
+ use super::test_utils::auth_content_from_proposal;
+ use super::*;
+ use crate::{
+ crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
+ key_package::test_utils::test_key_package,
+ tree_kem::leaf_node::test_utils::get_basic_test_node,
+ };
+ use alloc::boxed::Box;
+
+ use crate::extension::RequiredCapabilitiesExt;
+
+ /// Extension list holding a `RequiredCapabilitiesExt` that requires extension type 42.
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn get_test_extension_list() -> ExtensionList {
+ let test_extension = RequiredCapabilitiesExt {
+ extensions: vec![42.into()],
+ proposals: Default::default(),
+ credentials: vec![],
+ };
+
+ let mut extension_list = ExtensionList::new();
+ extension_list.set_from(test_extension).unwrap();
+
+ extension_list
+ }
+
+ /// One serialized test vector.
+ #[derive(serde::Serialize, serde::Deserialize)]
+ struct TestCase {
+ /// Cipher suite identifier used to compute `output`.
+ cipher_suite: u16,
+ /// MLS-encoded `AuthenticatedContent` carrying the proposal (hex in JSON).
+ #[serde(with = "hex::serde")]
+ input: Vec<u8>,
+ /// Expected proposal reference bytes (hex in JSON).
+ #[serde(with = "hex::serde")]
+ output: Vec<u8>,
+ }
+
+ /// Builds Add, Update, Remove and GroupContextExtensions vectors for every
+ /// (protocol version, cipher suite) pair.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn generate_proposal_test_cases() -> Vec<TestCase> {
+ let mut test_cases = Vec::new();
+
+ for (protocol_version, cipher_suite) in
+ ProtocolVersion::all().flat_map(|p| CipherSuite::all().map(move |cs| (p, cs)))
+ {
+ let sender = LeafIndex(0);
+
+ let add = auth_content_from_proposal(
+ Proposal::Add(Box::new(AddProposal {
+ key_package: test_key_package(protocol_version, cipher_suite, "alice").await,
+ })),
+ sender,
+ );
+
+ let update = auth_content_from_proposal(
+ Proposal::Update(UpdateProposal {
+ leaf_node: get_basic_test_node(cipher_suite, "foo").await,
+ }),
+ sender,
+ );
+
+ let remove = auth_content_from_proposal(
+ Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(1),
+ }),
+ sender,
+ );
+
+ let group_context_ext = auth_content_from_proposal(
+ Proposal::GroupContextExtensions(get_test_extension_list()),
+ sender,
+ );
+
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+ test_cases.push(TestCase {
+ cipher_suite: cipher_suite.into(),
+ input: add.mls_encode_to_vec().unwrap(),
+ output: ProposalRef::from_content(&cipher_suite_provider, &add)
+ .await
+ .unwrap()
+ .to_vec(),
+ });
+
+ test_cases.push(TestCase {
+ cipher_suite: cipher_suite.into(),
+ input: update.mls_encode_to_vec().unwrap(),
+ output: ProposalRef::from_content(&cipher_suite_provider, &update)
+ .await
+ .unwrap()
+ .to_vec(),
+ });
+
+ test_cases.push(TestCase {
+ cipher_suite: cipher_suite.into(),
+ input: remove.mls_encode_to_vec().unwrap(),
+ output: ProposalRef::from_content(&cipher_suite_provider, &remove)
+ .await
+ .unwrap()
+ .to_vec(),
+ });
+
+ test_cases.push(TestCase {
+ cipher_suite: cipher_suite.into(),
+ input: group_context_ext.mls_encode_to_vec().unwrap(),
+ output: ProposalRef::from_content(&cipher_suite_provider, &group_context_ext)
+ .await
+ .unwrap()
+ .to_vec(),
+ });
+ }
+
+ test_cases
+ }
+
+ // Loads the `proposal_ref` vectors through `load_test_case_json!`, which is also given the
+ // generator (presumably used when no stored vectors exist — confirm in the macro).
+ #[cfg(mls_build_async)]
+ async fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(proposal_ref, generate_proposal_test_cases().await)
+ }
+
+ #[cfg(not(mls_build_async))]
+ fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(proposal_ref, generate_proposal_test_cases())
+ }
+
+ /// Recomputes each vector's reference from its decoded input and compares it with the
+ /// expected output; vectors for unavailable cipher suites are skipped.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_proposal_ref() {
+ let test_cases = load_test_cases().await;
+
+ for one_case in test_cases {
+ let Some(cs_provider) = try_test_cipher_suite_provider(one_case.cipher_suite) else {
+ continue;
+ };
+
+ let proposal_content =
+ AuthenticatedContent::mls_decode(&mut one_case.input.as_slice()).unwrap();
+
+ let proposal_ref = ProposalRef::from_content(&cs_provider, &proposal_content)
+ .await
+ .unwrap();
+
+ let expected_out = ProposalRef(HashReference::from(one_case.output));
+
+ assert_eq!(expected_out, proposal_ref);
+ }
+ }
+}
diff --git a/src/group/resumption.rs b/src/group/resumption.rs
new file mode 100644
index 0000000..3478ef3
--- /dev/null
+++ b/src/group/resumption.rs
@@ -0,0 +1,299 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+
+use mls_rs_core::{
+ crypto::{CipherSuite, SignatureSecretKey},
+ extension::ExtensionList,
+ identity::SigningIdentity,
+ protocol_version::ProtocolVersion,
+};
+
+use crate::{client::MlsError, Client, Group, MlsMessage};
+
+use super::{
+ proposal::ReInitProposal, ClientConfig, ExportedTree, JustPreSharedKeyID, MessageProcessor,
+ NewMemberInfo, PreSharedKeyID, PskGroupId, PskSecretInput, ResumptionPSKUsage, ResumptionPsk,
+};
+
+/// Parameters used to create a group through resumption (branch or reinit), or
+/// expected of a group joined through resumption.
+struct ResumptionGroupParameters<'a> {
+ /// ID of the new group. `join_subgroup` passes an empty slice because it does not
+ /// ask `resumption_join_group` to verify the group ID.
+ group_id: &'a [u8],
+ cipher_suite: CipherSuite,
+ version: ProtocolVersion,
+ /// Group context extensions of the new group.
+ extensions: &'a ExtensionList,
+}
+
+/// Client returned by [`Group::get_reinit_client`] for creating or joining the group
+/// described by a committed [`ReInitProposal`].
+pub struct ReinitClient<C: ClientConfig + Clone> {
+ /// Client built with the reinit protocol version and cipher suite, and with the signer
+ /// and signing identity chosen in `get_reinit_client` (both always `Some`).
+ client: Client<C>,
+ /// The pending reinit proposal taken from the old group.
+ reinit: ReInitProposal,
+ /// Reinit resumption PSK derived from the old group's epoch at the time the client was created.
+ psk_input: PskSecretInput,
+}
+
+impl<C> Group<C>
+where
+    C: ClientConfig + Clone,
+{
+    /// Create a sub-group from a subset of the current group members.
+    ///
+    /// Membership within the resulting sub-group is indicated by providing a
+    /// key package that produces the same
+    /// [identity](crate::IdentityProvider::identity) value
+    /// as an existing group member. The identity value of each key package
+    /// is determined using the
+    /// [`IdentityProvider`](crate::IdentityProvider)
+    /// that is currently in use by this group instance.
+    ///
+    /// The sub-group uses `sub_group_id` together with this group's cipher suite,
+    /// protocol version and group context extensions. Its first commit adds
+    /// `new_key_packages` and injects a branch resumption PSK for the current epoch.
+    ///
+    /// Returns the new group and the welcome messages produced by that commit.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn branch(
+        &self,
+        sub_group_id: Vec<u8>,
+        new_key_packages: Vec<MlsMessage>,
+    ) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
+        let new_group_params = ResumptionGroupParameters {
+            group_id: &sub_group_id,
+            cipher_suite: self.cipher_suite(),
+            version: self.protocol_version(),
+            extensions: &self.group_state().context.extensions,
+        };
+
+        resumption_create_group(
+            self.config.clone(),
+            new_key_packages,
+            &new_group_params,
+            // TODO investigate if it's worth updating your own signing identity here
+            self.current_member_signing_identity()?.clone(),
+            self.signer.clone(),
+            // Not feature-gated: `resumption_create_group` always takes a PSK input, so
+            // gating this argument would leave the call one argument short.
+            self.resumption_psk_input(ResumptionPSKUsage::Branch)?,
+        )
+        .await
+    }
+
+    /// Join a subgroup that was created by [`Group::branch`].
+    ///
+    /// The joined group must use this group's protocol version, cipher suite and group
+    /// context extensions. Its group ID is not checked. The branch resumption PSK for this
+    /// group's current epoch is supplied when processing the welcome.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn join_subgroup(
+        &self,
+        welcome: &MlsMessage,
+        tree_data: Option<ExportedTree<'_>>,
+    ) -> Result<(Group<C>, NewMemberInfo), MlsError> {
+        let expected_params = ResumptionGroupParameters {
+            group_id: &[],
+            cipher_suite: self.cipher_suite(),
+            version: self.protocol_version(),
+            extensions: &self.group_state().context.extensions,
+        };
+
+        resumption_join_group(
+            self.config.clone(),
+            self.signer.clone(),
+            welcome,
+            tree_data,
+            expected_params,
+            false,
+            self.resumption_psk_input(ResumptionPSKUsage::Branch)?,
+        )
+        .await
+    }
+
+    /// Generate a [`ReinitClient`] that can be used to create or join a new group
+    /// that is based on properties defined by a [`ReInitProposal`]
+    /// committed in a previously accepted commit. This is the only action available
+    /// after accepting such a commit. The old group can no longer be used according to the RFC.
+    ///
+    /// If the [`ReInitProposal`] changes the ciphersuite, then `new_signer`
+    /// and `new_signing_identity` must be set and match the new ciphersuite, as indicated by
+    /// [`pending_reinit_ciphersuite`](crate::group::StateUpdate::pending_reinit_ciphersuite)
+    /// of the [`StateUpdate`](crate::group::StateUpdate) outputted after processing the
+    /// commit to the reinit proposal. The value of [identity](crate::IdentityProvider::identity)
+    /// must be the same for `new_signing_identity` and the current identity in use by this
+    /// group instance.
+    ///
+    /// When `new_signer` / `new_signing_identity` are `None`, this group's current signer
+    /// and signing identity are reused.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MlsError::PendingReInitNotFound`] if no reinit proposal is pending.
+    pub fn get_reinit_client(
+        self,
+        new_signer: Option<SignatureSecretKey>,
+        new_signing_identity: Option<SigningIdentity>,
+    ) -> Result<ReinitClient<C>, MlsError> {
+        // Derive the PSK while `self` is still intact; fields are moved out of it below.
+        let psk_input = self.resumption_psk_input(ResumptionPSKUsage::Reinit)?;
+
+        let new_signing_identity = match new_signing_identity {
+            Some(identity) => identity,
+            None => self.current_member_signing_identity()?.clone(),
+        };
+
+        let reinit = self
+            .state
+            .pending_reinit
+            .ok_or(MlsError::PendingReInitNotFound)?;
+
+        let new_signer = new_signer.unwrap_or(self.signer);
+
+        let client = Client::new(
+            self.config,
+            Some(new_signer),
+            Some((new_signing_identity, reinit.new_cipher_suite())),
+            reinit.new_version(),
+        );
+
+        Ok(ReinitClient {
+            client,
+            reinit,
+            psk_input,
+        })
+    }
+
+    /// Build the resumption PSK input for this group's current epoch.
+    ///
+    /// The PSK ID is a resumption ID carrying `usage`, this group's ID and the current
+    /// epoch. The PSK value is the current epoch's resumption secret.
+    fn resumption_psk_input(&self, usage: ResumptionPSKUsage) -> Result<PskSecretInput, MlsError> {
+        let psk = self.epoch_secrets.resumption_secret.clone();
+
+        let id = JustPreSharedKeyID::Resumption(ResumptionPsk {
+            usage,
+            psk_group_id: PskGroupId(self.group_id().to_vec()),
+            psk_epoch: self.current_epoch(),
+        });
+
+        let id = PreSharedKeyID::new(id, self.cipher_suite_provider())?;
+        Ok(PskSecretInput { id, psk })
+    }
+}
+
+/// A [`Client`] that can be used to create or join a new group
+/// that is based on properties defined by a [`ReInitProposal`]
+/// committed in a previously accepted commit.
+impl<C: ClientConfig + Clone> ReinitClient<C> {
+    /// Generate a key package for the new group. The key package can
+    /// be used in [`ReinitClient::commit`].
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn generate_key_package(&self) -> Result<MlsMessage, MlsError> {
+        self.client.generate_key_package_message().await
+    }
+
+    /// Create the new group using new key packages of all group members, possibly
+    /// generated by [`ReinitClient::generate_key_package`].
+    ///
+    /// The group ID, cipher suite, protocol version and group context extensions are
+    /// taken from the reinit proposal. The reinit resumption PSK is injected in the
+    /// commit that adds `new_key_packages`.
+    ///
+    /// # Warning
+    ///
+    /// This function will fail if the number of members in the reinitialized
+    /// group is not the same as the prior group roster.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn commit(
+        self,
+        new_key_packages: Vec<MlsMessage>,
+    ) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
+        let new_group_params = ResumptionGroupParameters {
+            group_id: self.reinit.group_id(),
+            cipher_suite: self.reinit.new_cipher_suite(),
+            version: self.reinit.new_version(),
+            extensions: self.reinit.new_group_context_extensions(),
+        };
+
+        resumption_create_group(
+            self.client.config,
+            new_key_packages,
+            &new_group_params,
+            // These private fields are created with `Some(x)` by `get_reinit_client`
+            self.client
+                .signing_identity
+                .expect("get_reinit_client always sets a signing identity")
+                .0,
+            self.client
+                .signer
+                .expect("get_reinit_client always sets a signer"),
+            // Not feature-gated: `resumption_create_group` always takes a PSK input.
+            self.psk_input,
+        )
+        .await
+    }
+
+    /// Join a reinitialized group that was created by [`ReinitClient::commit`].
+    ///
+    /// The joined group must match the reinit proposal's group ID, cipher suite,
+    /// protocol version and group context extensions.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn join(
+        self,
+        welcome: &MlsMessage,
+        tree_data: Option<ExportedTree<'_>>,
+    ) -> Result<(Group<C>, NewMemberInfo), MlsError> {
+        let reinit = self.reinit;
+
+        let expected_group_params = ResumptionGroupParameters {
+            group_id: reinit.group_id(),
+            cipher_suite: reinit.new_cipher_suite(),
+            version: reinit.new_version(),
+            extensions: reinit.new_group_context_extensions(),
+        };
+
+        resumption_join_group(
+            self.client.config,
+            // This private field is created with `Some(x)` by `get_reinit_client`
+            self.client
+                .signer
+                .expect("get_reinit_client always sets a signer"),
+            welcome,
+            tree_data,
+            expected_group_params,
+            true,
+            self.psk_input,
+        )
+        .await
+    }
+}
+
+/// Create a group from `new_group_params`, add `new_key_packages` in its first commit, and
+/// apply that commit with `psk_input` installed as the resumption PSK.
+///
+/// Returns the new group and the welcome messages for the added members.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn resumption_create_group<C: ClientConfig + Clone>(
+    config: C,
+    new_key_packages: Vec<MlsMessage>,
+    new_group_params: &ResumptionGroupParameters<'_>,
+    signing_identity: SigningIdentity,
+    signer: SignatureSecretKey,
+    psk_input: PskSecretInput,
+) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
+    let group_id = new_group_params.group_id.to_vec();
+
+    let mut group = Group::new(
+        config,
+        Some(group_id),
+        new_group_params.cipher_suite,
+        new_group_params.version,
+        signing_identity,
+        new_group_params.extensions.clone(),
+        signer,
+    )
+    .await?;
+
+    // The resumption PSK only needs to be visible while the first commit is built and applied.
+    group.previous_psk = Some(psk_input);
+
+    let builder = new_key_packages
+        .into_iter()
+        .try_fold(group.commit_builder(), |builder, key_package| {
+            builder.add_member(key_package)
+        })?;
+
+    let commit_output = builder.build().await?;
+    group.apply_pending_commit().await?;
+
+    // On any earlier failure the half-built group is simply dropped, so cleanup is only
+    // needed on the success path.
+    group.previous_psk = None;
+
+    Ok((group, commit_output.welcome_messages))
+}
+
+/// Join a group from `welcome` using `psk_input` as the resumption PSK, then check the
+/// result against `expected_new_group_params`.
+///
+/// The checks run in order: protocol version, cipher suite, group ID (only when
+/// `verify_group_id` is set), then group context extensions. The first mismatch is returned.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn resumption_join_group<C: ClientConfig + Clone>(
+    config: C,
+    signer: SignatureSecretKey,
+    welcome: &MlsMessage,
+    tree_data: Option<ExportedTree<'_>>,
+    expected_new_group_params: ResumptionGroupParameters<'_>,
+    verify_group_id: bool,
+    psk_input: PskSecretInput,
+) -> Result<(Group<C>, NewMemberInfo), MlsError> {
+    let (group, new_member_info) =
+        Group::<C>::from_welcome_message(welcome, tree_data, config, signer, Some(psk_input))
+            .await?;
+
+    if group.protocol_version() != expected_new_group_params.version {
+        return Err(MlsError::ProtocolVersionMismatch);
+    }
+
+    if group.cipher_suite() != expected_new_group_params.cipher_suite {
+        return Err(MlsError::CipherSuiteMismatch);
+    }
+
+    if verify_group_id && group.group_id() != expected_new_group_params.group_id {
+        return Err(MlsError::GroupIdMismatch);
+    }
+
+    if &group.group_state().context.extensions != expected_new_group_params.extensions {
+        return Err(MlsError::ReInitExtensionsMismatch);
+    }
+
+    Ok((group, new_member_info))
+}
diff --git a/src/group/roster.rs b/src/group/roster.rs
new file mode 100644
index 0000000..dd0a9f0
--- /dev/null
+++ b/src/group/roster.rs
@@ -0,0 +1,91 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::*;
+
+pub use mls_rs_core::group::Member;
+
+#[cfg(feature = "state_update")]
+/// Describe the owner of `key_package` as a [`Member`] placed at `index`, using the
+/// key package's leaf node.
+pub(crate) fn member_from_key_package(key_package: &KeyPackage, index: LeafIndex) -> Member {
+    let leaf_node = &key_package.leaf_node;
+    member_from_leaf_node(leaf_node, index)
+}
+
+/// Build a [`Member`] from a leaf node at `leaf_index`.
+///
+/// The member gets a clone of the leaf's signing identity plus its capabilities and
+/// extensions with GREASE values removed.
+pub(crate) fn member_from_leaf_node(leaf_node: &LeafNode, leaf_index: LeafIndex) -> Member {
+    let index = *leaf_index;
+    let signing_identity = leaf_node.signing_identity.clone();
+    let capabilities = leaf_node.ungreased_capabilities();
+    let extensions = leaf_node.ungreased_extensions();
+
+    Member::new(index, signing_identity, capabilities, extensions)
+}
+
+/// Read-only view of a group's members, backed by a borrowed `TreeKemPublic`.
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Debug)]
+pub struct Roster<'a> {
+ pub(crate) public_tree: &'a TreeKemPublic,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl<'a> Roster<'a> {
+    /// Iterator over the current roster that lazily copies data out of the
+    /// internal group state.
+    ///
+    /// # Warning
+    ///
+    /// The indexes within this iterator do not correlate with indexes of users
+    /// within [`ReceivedMessage`] content descriptions due to the layout of
+    /// member information within an MLS group state.
+    #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn members_iter(&self) -> impl Iterator<Item = Member> + 'a {
+        self.public_tree
+            .non_empty_leaves()
+            .map(|(index, node)| member_from_leaf_node(node, index))
+    }
+
+    /// The current set of group members. This function makes a clone of
+    /// member information from the internal group state.
+    ///
+    /// # Warning
+    ///
+    /// The indexes within this roster do not correlate with indexes of users
+    /// within [`ReceivedMessage`] content descriptions due to the layout of
+    /// member information within an MLS group state.
+    pub fn members(&self) -> Vec<Member> {
+        self.members_iter().collect()
+    }
+
+    /// Retrieve the member with given `index` within the group in time `O(1)`.
+    /// This index does correlate with indexes of users within [`ReceivedMessage`]
+    /// content descriptions.
+    ///
+    /// # Errors
+    ///
+    /// Propagates the error from the tree's leaf lookup (presumably when no leaf
+    /// node exists at `index`).
+    pub fn member_with_index(&self, index: u32) -> Result<Member, MlsError> {
+        let index = LeafIndex(index);
+
+        self.public_tree
+            .get_leaf_node(index)
+            .map(|l| member_from_leaf_node(l, index))
+    }
+
+    /// Iterator over member's signing identities.
+    ///
+    /// The yielded references borrow the underlying tree (lifetime `'a`), just like
+    /// [`Roster::members_iter`], so they may outlive this `Roster` value.
+    ///
+    /// # Warning
+    ///
+    /// The indexes within this iterator do not correlate with indexes of users
+    /// within [`ReceivedMessage`] content descriptions due to the layout of
+    /// member information within an MLS group state.
+    #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+    pub fn member_identities_iter(&self) -> impl Iterator<Item = &'a SigningIdentity> + 'a {
+        self.public_tree
+            .non_empty_leaves()
+            .map(|(_, node)| &node.signing_identity)
+    }
+}
+
+impl TreeKemPublic {
+    /// View this public tree as a [`Roster`] of the group's members.
+    pub(crate) fn roster(&self) -> Roster<'_> {
+        let public_tree = self;
+        Roster { public_tree }
+    }
+}
diff --git a/src/group/secret_tree.rs b/src/group/secret_tree.rs
new file mode 100644
index 0000000..df0c30f
--- /dev/null
+++ b/src/group/secret_tree.rs
@@ -0,0 +1,1115 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::{Deref, DerefMut},
+};
+
+use zeroize::Zeroizing;
+
+use crate::{client::MlsError, tree_kem::math::TreeIndex, CipherSuiteProvider};
+
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+
+#[cfg(feature = "std")]
+use std::collections::HashMap;
+
+#[cfg(not(feature = "std"))]
+use alloc::collections::BTreeMap;
+
+use super::key_schedule::kdf_expand_with_label;
+
+/// Maximum distance, in generations, that a single `SecretKeyRatchet::get_message_key`
+/// lookup may be ahead of the ratchet's current generation; larger requests are
+/// rejected with `InvalidFutureGeneration`.
+pub(crate) const MAX_RATCHET_BACK_HISTORY: u32 = 1024;
+
+/// Node stored in the secret tree: either a secret not yet split into its children,
+/// or (at a leaf) the ratchets derived from that leaf's secret.
+// NOTE(review): the explicit `u8` discriminants presumably serve as the encoded variant
+// tag, so variants must not be renumbered or reordered — confirm with mls_rs_codec.
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+enum SecretTreeNode {
+ Secret(TreeSecret) = 0u8,
+ Ratchet(SecretRatchets) = 1u8,
+}
+
+impl SecretTreeNode {
+    /// Unwrap a plain secret node; ratchet nodes yield `None`.
+    fn into_secret(self) -> Option<TreeSecret> {
+        match self {
+            SecretTreeNode::Secret(secret) => Some(secret),
+            SecretTreeNode::Ratchet(_) => None,
+        }
+    }
+}
+
+/// Secret held by a secret-tree node, wrapped in [`Zeroizing`] so the bytes are
+/// wiped when dropped.
+#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+struct TreeSecret(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ Zeroizing<Vec<u8>>,
+);
+
+impl Debug for TreeSecret {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Render the bytes through the crate's shared pretty-printer, labelled with the
+        // type name.
+        let secret_bytes: &[u8] = &self.0;
+
+        mls_rs_core::debug::pretty_bytes(secret_bytes)
+            .named("TreeSecret")
+            .fmt(f)
+    }
+}
+
+impl Deref for TreeSecret {
+    type Target = Vec<u8>;
+
+    /// Borrow the secret bytes.
+    fn deref(&self) -> &Self::Target {
+        let TreeSecret(bytes) = self;
+        bytes
+    }
+}
+
+impl DerefMut for TreeSecret {
+    /// Mutably borrow the secret bytes.
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        let TreeSecret(bytes) = self;
+        bytes
+    }
+}
+
+impl AsRef<[u8]> for TreeSecret {
+    /// View the secret as a byte slice.
+    fn as_ref(&self) -> &[u8] {
+        self.0.as_slice()
+    }
+}
+
+impl From<Vec<u8>> for TreeSecret {
+    /// Take ownership of raw bytes and wrap them so they are zeroized on drop.
+    fn from(vec: Vec<u8>) -> Self {
+        let wrapped = Zeroizing::new(vec);
+        Self(wrapped)
+    }
+}
+
+impl From<Zeroizing<Vec<u8>>> for TreeSecret {
+    /// Adopt bytes that are already zeroize-protected.
+    fn from(vec: Zeroizing<Vec<u8>>) -> Self {
+        Self(vec)
+    }
+}
+
+/// Sparse storage for the secret-tree nodes currently known, keyed by node index.
+///
+/// With `std` this is a `HashMap`; otherwise a `Vec` of `(index, node)` pairs that
+/// is searched linearly.
+#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+struct TreeSecretsVec<T: TreeIndex> {
+ #[cfg(feature = "std")]
+ inner: HashMap<T, SecretTreeNode>,
+ #[cfg(not(feature = "std"))]
+ inner: Vec<(T, SecretTreeNode)>,
+}
+
+#[cfg(feature = "std")]
+impl<T: TreeIndex> TreeSecretsVec<T> {
+    /// Store `value` at `index`, discarding any node previously stored there.
+    fn set_node(&mut self, index: T, value: SecretTreeNode) {
+        drop(self.inner.insert(index, value));
+    }
+
+    /// Remove and return the node stored at `index`, if any.
+    fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> {
+        self.inner.remove_entry(index).map(|(_, node)| node)
+    }
+}
+
+#[cfg(not(feature = "std"))]
+impl<T: TreeIndex> TreeSecretsVec<T> {
+    /// Store `value` at `index`, overwriting an existing entry or appending a new one.
+    fn set_node(&mut self, index: T, value: SecretTreeNode) {
+        match self.find_node(&index) {
+            Some(position) => self.inner[position] = (index, value),
+            None => self.inner.push((index, value)),
+        }
+    }
+
+    /// Remove and return the node stored at `index`, if any.
+    fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> {
+        self.find_node(index)
+            .map(|position| self.inner.remove(position).1)
+    }
+
+    /// Position of the entry for `index` in `inner`, found by linear scan.
+    fn find_node(&self, index: &T) -> Option<usize> {
+        // Plain `Iterator::position`; no need for `itertools::find_position`, which also
+        // returns the element only to have it discarded.
+        self.inner.iter().position(|(i, _)| i == index)
+    }
+}
+
+/// Secret tree over `leaf_count` leaves.
+///
+/// Only the nodes currently held in `known_secrets` are stored. A parent's secret is
+/// removed when it is split into its children, and a leaf's secret is turned into
+/// per-leaf ratchets on first use.
+#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct SecretTree<T: TreeIndex> {
+ known_secrets: TreeSecretsVec<T>,
+ leaf_count: T,
+}
+
+impl<T: TreeIndex> SecretTree<T> {
+    /// A tree with zero leaves and no stored secrets.
+    pub(crate) fn empty() -> SecretTree<T> {
+        let leaf_count = T::zero();
+        let known_secrets = TreeSecretsVec::default();
+
+        SecretTree {
+            known_secrets,
+            leaf_count,
+        }
+    }
+}
+
+/// The two key ratchets of one leaf: one for application messages and one for
+/// handshake messages (see [`KeyType`]).
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct SecretRatchets {
+ pub application: SecretKeyRatchet,
+ pub handshake: SecretKeyRatchet,
+}
+
+impl SecretRatchets {
+    /// Fetch the key for `generation` from the ratchet selected by `key_type`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn message_key_generation<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite_provider: &P,
+        generation: u32,
+        key_type: KeyType,
+    ) -> Result<MessageKeyData, MlsError> {
+        let ratchet = match key_type {
+            KeyType::Handshake => &mut self.handshake,
+            KeyType::Application => &mut self.application,
+        };
+
+        ratchet
+            .get_message_key(cipher_suite_provider, generation)
+            .await
+    }
+
+    /// Produce the key for the current generation of the ratchet selected by `key_type`,
+    /// advancing that ratchet by one.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn next_message_key<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite: &P,
+        key_type: KeyType,
+    ) -> Result<MessageKeyData, MlsError> {
+        let ratchet = match key_type {
+            KeyType::Handshake => &mut self.handshake,
+            KeyType::Application => &mut self.application,
+        };
+
+        ratchet.next_message_key(cipher_suite).await
+    }
+}
+
+impl<T: TreeIndex> SecretTree<T> {
+    /// Create a secret tree with `leaf_count` leaves whose root holds `encryption_secret`.
+    ///
+    /// Only the root secret is stored; every other node is derived on demand when a
+    /// leaf's ratchets are first requested.
+    pub fn new(leaf_count: T, encryption_secret: Zeroizing<Vec<u8>>) -> SecretTree<T> {
+        let mut known_secrets = TreeSecretsVec::default();
+
+        let root_secret = SecretTreeNode::Secret(TreeSecret::from(encryption_secret));
+        known_secrets.set_node(leaf_count.root(), root_secret);
+
+        Self {
+            known_secrets,
+            leaf_count,
+        }
+    }
+
+    /// Remove the node at `index` and, if it held a secret, replace it with its two children.
+    ///
+    /// The children are `ExpandWithLabel(secret, "tree", "left" | "right")`. Nothing is
+    /// inserted if the node was absent or was not a plain secret.
+    ///
+    /// # Errors
+    ///
+    /// `LeafNodeNoChildren` if `index` has no children, or any KDF error.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn consume_node<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite_provider: &P,
+        index: &T,
+    ) -> Result<(), MlsError> {
+        let node = self.known_secrets.take_node(index);
+
+        if let Some(secret) = node.and_then(|n| n.into_secret()) {
+            let left_index = index.left().ok_or(MlsError::LeafNodeNoChildren)?;
+            let right_index = index.right().ok_or(MlsError::LeafNodeNoChildren)?;
+
+            let left_secret =
+                kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"left", None)
+                    .await?;
+
+            let right_secret =
+                kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"right", None)
+                    .await?;
+
+            self.known_secrets
+                .set_node(left_index, SecretTreeNode::Secret(left_secret.into()));
+
+            self.known_secrets
+                .set_node(right_index, SecretTreeNode::Secret(right_secret.into()));
+        }
+
+        Ok(())
+    }
+
+    /// Remove and return the ratchets for `leaf_index`, deriving them if necessary.
+    ///
+    /// If the leaf is not stored yet, intermediate secrets are split from the root down
+    /// to the leaf first. A leaf secret is turned into fresh application and handshake
+    /// ratchets. The caller is responsible for storing the ratchets back.
+    ///
+    /// # Errors
+    ///
+    /// `InvalidLeafConsumption` if no secret for the leaf can be found or derived.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn take_leaf_ratchet<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite: &P,
+        leaf_index: &T,
+    ) -> Result<SecretRatchets, MlsError> {
+        let node = match self.known_secrets.take_node(leaf_index) {
+            Some(node) => node,
+            None => {
+                // Start at the root node and work your way down consuming any intermediates needed
+                for i in leaf_index.direct_copath(&self.leaf_count).into_iter().rev() {
+                    self.consume_node(cipher_suite, &i.path).await?;
+                }
+
+                self.known_secrets
+                    .take_node(leaf_index)
+                    .ok_or(MlsError::InvalidLeafConsumption)?
+            }
+        };
+
+        Ok(match node {
+            SecretTreeNode::Ratchet(ratchet) => ratchet,
+            SecretTreeNode::Secret(secret) => SecretRatchets {
+                application: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Application)
+                    .await?,
+                handshake: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Handshake).await?,
+            },
+        })
+    }
+
+    /// Produce the next key of type `key_type` for `leaf_index`, advancing that ratchet.
+    ///
+    /// The leaf's ratchets are always stored back, even if key derivation fails.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn next_message_key<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite: &P,
+        leaf_index: T,
+        key_type: KeyType,
+    ) -> Result<MessageKeyData, MlsError> {
+        let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
+        let res = ratchet.next_message_key(cipher_suite, key_type).await;
+
+        // `take_leaf_ratchet` removed the ratchets, and the ancestors they were derived from
+        // have already been consumed. Dropping them on an error path would lose this leaf's
+        // key state for good, so restore them unconditionally.
+        self.known_secrets
+            .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));
+
+        res
+    }
+
+    /// Fetch the key of type `key_type` for `leaf_index` at `generation`.
+    ///
+    /// The leaf's ratchets are always stored back, including when the lookup fails
+    /// (e.g. `KeyMissing` for an already used or skipped generation, or
+    /// `InvalidFutureGeneration`). Otherwise a single bad or replayed message would
+    /// discard the sender's ratchet state.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    pub async fn message_key_generation<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite: &P,
+        leaf_index: T,
+        key_type: KeyType,
+        generation: u32,
+    ) -> Result<MessageKeyData, MlsError> {
+        let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
+
+        let res = ratchet
+            .message_key_generation(cipher_suite, generation, key_type)
+            .await;
+
+        self.known_secrets
+            .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));
+
+        res
+    }
+}
+
+/// Selects which of a leaf's two ratchets to use. It also selects the label
+/// (`"handshake"` or `"application"`) used when that ratchet is created.
+#[derive(Clone, Copy)]
+pub enum KeyType {
+ Handshake,
+ Application,
+}
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// AEAD key derived by the MLS secret tree.
+pub struct MessageKeyData {
+ /// AEAD nonce (`aead_nonce_size` bytes), zeroized on drop.
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ pub(crate) nonce: Zeroizing<Vec<u8>>,
+ /// AEAD key (`aead_key_size` bytes), zeroized on drop.
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ pub(crate) key: Zeroizing<Vec<u8>>,
+ /// Ratchet generation this key and nonce were derived for.
+ pub(crate) generation: u32,
+}
+
+impl Debug for MessageKeyData {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let nonce = mls_rs_core::debug::pretty_bytes(&self.nonce);
+        let key = mls_rs_core::debug::pretty_bytes(&self.key);
+
+        let mut out = f.debug_struct("MessageKeyData");
+        out.field("nonce", &nonce);
+        out.field("key", &key);
+        out.field("generation", &self.generation);
+        out.finish()
+    }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl MessageKeyData {
+    /// The AEAD nonce for this generation.
+    #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
+    pub fn nonce(&self) -> &[u8] {
+        self.nonce.as_slice()
+    }
+
+    /// The AEAD key for this generation.
+    #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
+    pub fn key(&self) -> &[u8] {
+        self.key.as_slice()
+    }
+
+    /// Which generation of the key schedule produced this key.
+    #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
+    pub fn generation(&self) -> u32 {
+        let Self { generation, .. } = self;
+        *generation
+    }
+}
+
+/// Key ratchet for one content type of one leaf.
+///
+/// `secret` is the ratchet secret for `generation`. Each `next_message_key` call
+/// derives that generation's key and nonce and then advances both fields. With
+/// `out_of_order`, keys skipped over by `get_message_key` are kept in `history`
+/// (keyed by generation) and removed when requested.
+#[derive(Debug, Clone, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct SecretKeyRatchet {
+ secret: TreeSecret,
+ generation: u32,
+ #[cfg(all(feature = "out_of_order", feature = "std"))]
+ history: HashMap<u32, MessageKeyData>,
+ #[cfg(all(feature = "out_of_order", not(feature = "std")))]
+ history: BTreeMap<u32, MessageKeyData>,
+}
+
+impl MlsSize for SecretKeyRatchet {
+    fn mls_encoded_len(&self) -> usize {
+        let secret_len = mls_rs_codec::byte_vec::mls_encoded_len(&self.secret);
+        let generation_len = self.generation.mls_encoded_len();
+
+        // The history is only part of the encoding when out-of-order support is compiled in.
+        #[cfg(feature = "out_of_order")]
+        let history_len = mls_rs_codec::iter::mls_encoded_len(self.history.values());
+        #[cfg(not(feature = "out_of_order"))]
+        let history_len = 0;
+
+        secret_len + generation_len + history_len
+    }
+}
+
+/// Encodes `secret` (via `byte_vec`), then `generation`, then the stored history keys.
+// NOTE(review): with `std`, `history` is a `HashMap` with the default hasher, so the
+// history entries are written in the map's iteration order, which is not stable across
+// runs. `mls_decode` re-keys entries by generation, so round-trips are unaffected, but the
+// byte encoding is not canonical. Consider sorting by generation if byte-stable state matters.
+#[cfg(feature = "out_of_order")]
+impl MlsEncode for SecretKeyRatchet {
+ fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+ mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?;
+ self.generation.mls_encode(writer)?;
+ mls_rs_codec::iter::mls_encode(self.history.values(), writer)
+ }
+}
+
+/// Encodes `secret` (via `byte_vec`) followed by `generation`; there is no history
+/// without `out_of_order`. The statement order here is the wire format.
+#[cfg(not(feature = "out_of_order"))]
+impl MlsEncode for SecretKeyRatchet {
+ fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+ mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?;
+ self.generation.mls_encode(writer)
+ }
+}
+
+/// Decodes the layout written by `MlsEncode`: `secret`, `generation`, and (with
+/// `out_of_order`) a collection of `MessageKeyData` entries.
+///
+/// History entries are re-keyed by their own `generation`. If the input contains the
+/// same generation twice, the later entry replaces the earlier one.
+impl MlsDecode for SecretKeyRatchet {
+ fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
+ Ok(Self {
+ secret: mls_rs_codec::byte_vec::mls_decode(reader)?,
+ generation: u32::mls_decode(reader)?,
+ // std build: history stored in a `HashMap`.
+ #[cfg(all(feature = "std", feature = "out_of_order"))]
+ history: mls_rs_codec::iter::mls_decode_collection(reader, |data| {
+ let mut items = HashMap::default();
+
+ while !data.is_empty() {
+ let item = MessageKeyData::mls_decode(data)?;
+ items.insert(item.generation, item);
+ }
+
+ Ok(items)
+ })?,
+ // no_std build: identical logic, stored in a `BTreeMap`.
+ #[cfg(all(not(feature = "std"), feature = "out_of_order"))]
+ history: mls_rs_codec::iter::mls_decode_collection(reader, |data| {
+ let mut items = alloc::collections::BTreeMap::default();
+
+ while !data.is_empty() {
+ let item = MessageKeyData::mls_decode(data)?;
+ items.insert(item.generation, item);
+ }
+
+ Ok(items)
+ })?,
+ })
+ }
+}
+
+impl SecretKeyRatchet {
+    /// Create the ratchet for one content type from a leaf's tree secret.
+    ///
+    /// The initial ratchet secret is `ExpandWithLabel(secret, label, "", None)`. The label
+    /// is `"handshake"` or `"application"` according to `key_type`. The ratchet starts at
+    /// generation 0 with an empty history.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn new<P: CipherSuiteProvider>(
+        cipher_suite_provider: &P,
+        secret: &[u8],
+        key_type: KeyType,
+    ) -> Result<Self, MlsError> {
+        let label = match key_type {
+            KeyType::Handshake => b"handshake".as_slice(),
+            KeyType::Application => b"application".as_slice(),
+        };
+
+        let secret = kdf_expand_with_label(cipher_suite_provider, secret, label, &[], None)
+            .await
+            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+        Ok(Self {
+            secret: TreeSecret::from(secret),
+            generation: 0,
+            #[cfg(feature = "out_of_order")]
+            history: Default::default(),
+        })
+    }
+
+    /// Return the key for `generation`, ratcheting forward to it if necessary.
+    ///
+    /// Keys for generations skipped on the way are stored in `history` with
+    /// `out_of_order`, and discarded without it.
+    ///
+    /// # Errors
+    ///
+    /// * `KeyMissing` if `generation` is behind the current generation and its key is
+    ///   not (or no longer) available.
+    /// * `InvalidFutureGeneration` if `generation` is more than
+    ///   `MAX_RATCHET_BACK_HISTORY` generations ahead of the current one.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn get_message_key<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite_provider: &P,
+        generation: u32,
+    ) -> Result<MessageKeyData, MlsError> {
+        // Past generations come only from history, and each stored key is handed out once.
+        #[cfg(feature = "out_of_order")]
+        if generation < self.generation {
+            return self
+                .history
+                .remove(&generation)
+                .ok_or(MlsError::KeyMissing(generation));
+        }
+
+        #[cfg(not(feature = "out_of_order"))]
+        if generation < self.generation {
+            return Err(MlsError::KeyMissing(generation));
+        }
+
+        // Saturate so a ratchet close to `u32::MAX` cannot overflow: that would panic in
+        // debug builds and wrap in release, wrongly rejecting every future generation.
+        let max_generation_allowed = self.generation.saturating_add(MAX_RATCHET_BACK_HISTORY);
+
+        if generation > max_generation_allowed {
+            return Err(MlsError::InvalidFutureGeneration(generation));
+        }
+
+        // Skipped keys are derived and dropped. The `.await` is required: in async builds
+        // `next_message_key` returns a future, and `?` cannot be applied to it directly.
+        #[cfg(not(feature = "out_of_order"))]
+        while self.generation < generation {
+            self.next_message_key(cipher_suite_provider).await?;
+        }
+
+        #[cfg(feature = "out_of_order")]
+        while self.generation < generation {
+            let key_data = self.next_message_key(cipher_suite_provider).await?;
+            self.history.insert(key_data.generation, key_data);
+        }
+
+        self.next_message_key(cipher_suite_provider).await
+    }
+
+    /// Derive the key and nonce for the current generation, then advance the ratchet.
+    ///
+    /// The nonce, key and next secret are each
+    /// `ExpandWithLabel(secret, "nonce" | "key" | "secret", generation, len)`. On failure
+    /// the ratchet is left unchanged, because `secret` and `generation` are only updated
+    /// after all three derivations succeed.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn next_message_key<P: CipherSuiteProvider>(
+        &mut self,
+        cipher_suite_provider: &P,
+    ) -> Result<MessageKeyData, MlsError> {
+        let generation = self.generation;
+
+        let key = MessageKeyData {
+            nonce: self
+                .derive_secret(
+                    cipher_suite_provider,
+                    b"nonce",
+                    cipher_suite_provider.aead_nonce_size(),
+                )
+                .await?,
+            key: self
+                .derive_secret(
+                    cipher_suite_provider,
+                    b"key",
+                    cipher_suite_provider.aead_key_size(),
+                )
+                .await?,
+            generation,
+        };
+
+        self.secret = self
+            .derive_secret(
+                cipher_suite_provider,
+                b"secret",
+                cipher_suite_provider.kdf_extract_size(),
+            )
+            .await?
+            .into();
+
+        self.generation = generation + 1;
+
+        Ok(key)
+    }
+
+    /// `ExpandWithLabel(secret, label, generation as big-endian u32, len)`.
+    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+    async fn derive_secret<P: CipherSuiteProvider>(
+        &self,
+        cipher_suite_provider: &P,
+        label: &[u8],
+        len: usize,
+    ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+        kdf_expand_with_label(
+            cipher_suite_provider,
+            self.secret.as_ref(),
+            label,
+            &self.generation.to_be_bytes(),
+            Some(len),
+        )
+        .await
+        .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use alloc::{string::String, vec::Vec};
+    use mls_rs_core::crypto::CipherSuiteProvider;
+    use zeroize::Zeroizing;
+
+    use crate::{crypto::test_utils::try_test_cipher_suite_provider, tree_kem::math::TreeIndex};
+
+    use super::{KeyType, SecretKeyRatchet, SecretTree};
+
+    /// Build a secret tree over `leaf_count` leaves whose root secret is `secret`.
+    pub(crate) fn get_test_tree<T: TreeIndex>(secret: Vec<u8>, leaf_count: T) -> SecretTree<T> {
+        let root_secret = Zeroizing::new(secret);
+        SecretTree::new(leaf_count, root_secret)
+    }
+
+    impl SecretTree<u32> {
+        /// Copy of the root secret, read from a clone so the tree itself is untouched.
+        ///
+        /// Panics if the root is missing or no longer a plain secret.
+        pub(crate) fn get_root_secret(&self) -> Vec<u8> {
+            let root_index = self.leaf_count.root();
+            let mut snapshot = self.known_secrets.clone();
+
+            let root_node = snapshot.take_node(&root_index).unwrap();
+            root_node.into_secret().unwrap().to_vec()
+        }
+    }
+
+    #[derive(Debug, serde::Serialize, serde::Deserialize)]
+    pub struct RatchetInteropTestCase {
+        #[serde(with = "hex::serde")]
+        secret: Vec<u8>,
+        label: String,
+        generation: u32,
+        length: usize,
+        #[serde(with = "hex::serde")]
+        out: Vec<u8>,
+    }
+
+    #[derive(Debug, serde::Serialize, serde::Deserialize)]
+    pub struct InteropTestCase {
+        cipher_suite: u16,
+        derive_tree_secret: RatchetInteropTestCase,
+    }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn test_basic_crypto_test_vectors() {
+        let cases: Vec<InteropTestCase> =
+            load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());
+
+        for case in cases {
+            // Cipher suites without a test provider are skipped.
+            let Some(provider) = try_test_cipher_suite_provider(case.cipher_suite) else {
+                continue;
+            };
+
+            case.derive_tree_secret.verify(&provider).await;
+        }
+    }
+
+    impl RatchetInteropTestCase {
+        /// Check that `derive_secret` reproduces `out` from the recorded secret, label,
+        /// generation and length.
+        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+        pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
+            // `new` only supplies a ratchet value; its derived state is overwritten below.
+            let mut ratchet = SecretKeyRatchet::new(cs, &self.secret, KeyType::Application)
+                .await
+                .unwrap();
+
+            ratchet.generation = self.generation;
+            ratchet.secret = self.secret.clone().into();
+
+            let expected = &self.out;
+
+            let computed = ratchet
+                .derive_secret(cs, self.label.as_bytes(), self.length)
+                .await
+                .unwrap()
+                .to_vec();
+
+            assert_eq!(&computed, expected);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use alloc::vec;
+
+ use crate::{
+ cipher_suite::CipherSuite,
+ client::test_utils::TEST_CIPHER_SUITE,
+ crypto::test_utils::{
+ test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
+ },
+ tree_kem::node::NodeIndex,
+ };
+
+ #[cfg(not(mls_build_async))]
+ use crate::group::test_utils::random_bytes;
+
+ use super::{test_utils::get_test_tree, *};
+
+ use assert_matches::assert_matches;
+
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_secret_tree() {
+ test_secret_tree_custom(16u32, (0..16).map(|i| 2 * i).collect(), true).await;
+ test_secret_tree_custom(1u64 << 62, (1..62).map(|i| 1u64 << i).collect(), false).await;
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_secret_tree_custom<T: TreeIndex>(
+ leaf_count: T,
+ leaves_to_check: Vec<T>,
+ all_deleted: bool,
+ ) {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let cs_provider = test_cipher_suite_provider(cipher_suite);
+
+ let test_secret = vec![0u8; cs_provider.kdf_extract_size()];
+ let mut test_tree = get_test_tree(test_secret, leaf_count.clone());
+
+ let mut secrets = Vec::<SecretRatchets>::new();
+
+ for i in &leaves_to_check {
+ let secret = test_tree
+ .take_leaf_ratchet(&test_cipher_suite_provider(cipher_suite), i)
+ .await
+ .unwrap();
+
+ secrets.push(secret);
+ }
+
+ // Verify the tree is now completely empty
+ assert!(!all_deleted || test_tree.known_secrets.inner.is_empty());
+
+ // Verify that all the secrets are unique
+ let count = secrets.len();
+ secrets.dedup();
+ assert_eq!(count, secrets.len());
+ }
+ }
+
+ /// Application and handshake ratchets created from the same secret must produce
+ /// different keys, and each ratchet must produce a different key per generation.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_secret_key_ratchet() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let mut app_ratchet = SecretKeyRatchet::new(
+ &provider,
+ &vec![0u8; provider.kdf_extract_size()],
+ KeyType::Application,
+ )
+ .await
+ .unwrap();
+
+ let mut handshake_ratchet = SecretKeyRatchet::new(
+ &provider,
+ &vec![0u8; provider.kdf_extract_size()],
+ KeyType::Handshake,
+ )
+ .await
+ .unwrap();
+
+ let app_key_one = app_ratchet.next_message_key(&provider).await.unwrap();
+ let app_key_two = app_ratchet.next_message_key(&provider).await.unwrap();
+ let app_keys = vec![app_key_one, app_key_two];
+
+ let handshake_key_one = handshake_ratchet.next_message_key(&provider).await.unwrap();
+ let handshake_key_two = handshake_ratchet.next_message_key(&provider).await.unwrap();
+ let handshake_keys = vec![handshake_key_one, handshake_key_two];
+
+ // Verify that the keys differ at *every* generation due to their different labels.
+ // Comparing the whole vectors would pass as soon as a single pair differed.
+ for (app_key, handshake_key) in app_keys.iter().zip(&handshake_keys) {
+ assert_ne!(app_key, handshake_key);
+ }
+
+ // Verify that the keys at each generation are different, for both ratchets
+ assert_ne!(app_keys[0], app_keys[1]);
+ assert_ne!(handshake_keys[0], handshake_keys[1]);
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_get_key() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let provider = test_cipher_suite_provider(cipher_suite);
+ let zero_secret = vec![0u8; provider.kdf_extract_size()];
+
+ let mut ratchet = SecretKeyRatchet::new(&provider, &zero_secret, KeyType::Application)
+ .await
+ .unwrap();
+
+ let mut advanced = ratchet.clone();
+
+ // Step the copy through generations 0 and 1, keeping the key for generation 1
+ let _ = advanced.next_message_key(&provider).await.unwrap();
+ let expected = advanced.next_message_key(&provider).await.unwrap();
+
+ // A generation the copy has already moved past can no longer be requested
+ let rewind = advanced.get_message_key(&provider, 0).await;
+ assert!(rewind.is_err());
+
+ // Jumping to a generation must give the same key as stepping there with `next`
+ let target_generation = advanced.generation - 1;
+ let jumped = ratchet
+ .get_message_key(&provider, target_generation)
+ .await
+ .unwrap();
+
+ assert_eq!(expected, jumped)
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_secret_ratchet() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let provider = test_cipher_suite_provider(cipher_suite);
+ let initial_secret = vec![0u8; provider.kdf_extract_size()];
+
+ let mut ratchet = SecretKeyRatchet::new(&provider, &initial_secret, KeyType::Application)
+ .await
+ .unwrap();
+
+ // Producing a message key must replace the ratchet's secret with a different one
+ let before = ratchet.secret.clone();
+ let _ = ratchet.next_message_key(&provider).await.unwrap();
+ assert_ne!(before, ratchet.secret)
+ }
+ }
+
+ #[cfg(feature = "out_of_order")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_out_of_order_keys() {
+ let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let mut in_order = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
+ .await
+ .unwrap();
+ let mut out_of_order = in_order.clone();
+
+ // Reference keys: request every generation up to MAX_RATCHET_BACK_HISTORY in order
+ let mut expected_keys = Vec::<MessageKeyData>::new();
+
+ for generation in 0..=MAX_RATCHET_BACK_HISTORY {
+ let key = in_order.get_message_key(&provider, generation).await.unwrap();
+ expected_keys.push(key);
+ }
+
+ // Jump the copy straight to generation MAX_RATCHET_BACK_HISTORY
+ let newest = out_of_order
+ .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY)
+ .await
+ .unwrap();
+
+ assert_eq!(&newest, expected_keys.last().unwrap());
+
+ // Then go back and fetch the earlier generations that were skipped over
+ let history_len = MAX_RATCHET_BACK_HISTORY - 1;
+ let mut recovered_keys = Vec::<MessageKeyData>::new();
+
+ for generation in 0..history_len {
+ let key = out_of_order.get_message_key(&provider, generation).await.unwrap();
+ recovered_keys.push(key);
+ }
+
+ assert_eq!(recovered_keys, expected_keys[..history_len as usize]);
+ }
+
+ #[cfg(not(feature = "out_of_order"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn out_of_order_keys_should_throw_error() {
+ let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
+ .await
+ .unwrap();
+
+ // Moving forward to generation 10 succeeds...
+ ratchet.get_message_key(&provider, 10).await.unwrap();
+
+ // ...but without out-of-order support an earlier generation is gone
+ let earlier = ratchet.get_message_key(&provider, 9).await;
+ assert_matches!(earlier, Err(MlsError::KeyMissing(9)))
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_too_out_of_order() {
+ let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
+ .await
+ .unwrap();
+
+ // One generation beyond MAX_RATCHET_BACK_HISTORY, which a fresh ratchet must reject
+ let invalid_generation = MAX_RATCHET_BACK_HISTORY + 1;
+ let res = ratchet.get_message_key(&provider, invalid_generation).await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::InvalidFutureGeneration(invalid)) if invalid == invalid_generation
+ )
+ }
+
+ /// Expected output recorded for one leaf in the `secret_tree` test vector, as produced by
+ /// `get_ratchet_data`: two lists of successive MLS-encoded message keys.
+ #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
+ struct Ratchet {
+ application_keys: Vec<Vec<u8>>,
+ handshake_keys: Vec<Vec<u8>>,
+ }
+
+ /// One case of the `secret_tree` test vectors loaded by `test_secret_tree_test_vectors`.
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ struct TestCase {
+ /// Raw cipher suite identifier.
+ cipher_suite: u16,
+ /// Root secret of a 16-leaf secret tree; hex-encoded in the JSON.
+ #[serde(with = "hex::serde")]
+ encryption_secret: Vec<u8>,
+ /// Expected ratchet output, one entry per leaf (see `get_ratchet_data`).
+ ratchets: Vec<Ratchet>,
+ }
+
+ /// Takes the leaf ratchets for node indices `0, 2, ..., 30` out of `secret_tree` and
+ /// records 20 + 20 successive MLS-encoded message keys for each of them.
+ ///
+ /// NOTE(review): both loops below draw from `ratchets.handshake`, so `application_keys`
+ /// actually holds handshake generations 0..20 and `handshake_keys` holds handshake
+ /// generations 20..40; the application ratchet is never exercised. The stored
+ /// `secret_tree` test vector appears to have been generated by this same function, so
+ /// switching the first loop to `ratchets.application` would presumably also require
+ /// regenerating that vector — confirm before changing.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn get_ratchet_data(
+ secret_tree: &mut SecretTree<NodeIndex>,
+ cipher_suite: CipherSuite,
+ ) -> Vec<Ratchet> {
+ let provider = test_cipher_suite_provider(cipher_suite);
+ let mut ratchet_data = Vec::new();
+
+ for index in 0..16 {
+ // Leaf `index` sits at node index `index * 2`
+ let mut ratchets = secret_tree
+ .take_leaf_ratchet(&provider, &(index * 2))
+ .await
+ .unwrap();
+
+ let mut application_keys = Vec::new();
+
+ for _ in 0..20 {
+ // NOTE(review): reads the handshake ratchet, not `ratchets.application`
+ let key = ratchets
+ .handshake
+ .next_message_key(&provider)
+ .await
+ .unwrap()
+ .mls_encode_to_vec()
+ .unwrap();
+
+ application_keys.push(key);
+ }
+
+ let mut handshake_keys = Vec::new();
+
+ for _ in 0..20 {
+ let key = ratchets
+ .handshake
+ .next_message_key(&provider)
+ .await
+ .unwrap()
+ .mls_encode_to_vec()
+ .unwrap();
+
+ handshake_keys.push(key);
+ }
+
+ ratchet_data.push(Ratchet {
+ application_keys,
+ handshake_keys,
+ });
+ }
+
+ ratchet_data
+ }
+
+ #[cfg(not(mls_build_async))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_test_vector() -> Vec<TestCase> {
+ let mut test_cases = Vec::new();
+
+ // One case per cipher suite, each rooted at a fresh random encryption secret
+ for cipher_suite in CipherSuite::all() {
+ let provider = test_cipher_suite_provider(cipher_suite);
+ let encryption_secret = random_bytes(provider.kdf_extract_size());
+
+ let mut secret_tree = SecretTree::new(16, Zeroizing::new(encryption_secret.clone()));
+ let ratchets = get_ratchet_data(&mut secret_tree, cipher_suite);
+
+ test_cases.push(TestCase {
+ cipher_suite: cipher_suite.into(),
+ encryption_secret,
+ ratchets,
+ });
+ }
+
+ test_cases
+ }
+
+ /// Async-build stand-in for the generator: the real one is only compiled without
+ /// `mls_build_async`, so this panics if the vector ever needs generating in async mode
+ /// (presumably only when `load_test_case_json!` finds no stored vector — confirm).
+ #[cfg(mls_build_async)]
+ fn generate_test_vector() -> Vec<TestCase> {
+ panic!("Tests cannot be generated in async mode");
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_secret_tree_test_vectors() {
+ let test_cases: Vec<TestCase> = load_test_case_json!(secret_tree, generate_test_vector());
+
+ for TestCase {
+ cipher_suite,
+ encryption_secret,
+ ratchets: expected,
+ } in test_cases
+ {
+ // Cipher suites the test crypto provider doesn't support are skipped
+ let cs_provider = match try_test_cipher_suite_provider(cipher_suite) {
+ Some(provider) => provider,
+ None => continue,
+ };
+
+ let mut secret_tree = SecretTree::new(16, Zeroizing::new(encryption_secret));
+ let actual = get_ratchet_data(&mut secret_tree, cs_provider.cipher_suite()).await;
+
+ assert_eq!(actual, expected);
+ }
+ }
+}
+
+#[cfg(all(test, feature = "rfc_compliant", feature = "std"))]
+mod interop_tests {
+ #[cfg(not(mls_build_async))]
+ use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};
+ use zeroize::Zeroizing;
+
+ use crate::{
+ crypto::test_utils::try_test_cipher_suite_provider,
+ group::{ciphertext_processor::InteropSenderData, secret_tree::KeyType},
+ };
+
+ use super::SecretTree;
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn interop_test_vector() {
+ // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/secret-tree.json
+ let test_cases = load_interop_test_cases();
+
+ for case in test_cases {
+ // Skip cipher suites the test crypto provider does not support
+ let cs = match try_test_cipher_suite_provider(case.cipher_suite) {
+ Some(cs) => cs,
+ None => continue,
+ };
+
+ case.sender_data.verify(&cs).await;
+
+ let mut tree = SecretTree::new(
+ case.leaves.len() as u32,
+ Zeroizing::new(case.encryption_secret),
+ );
+
+ for (leaf_index, generations) in case.leaves.iter().enumerate() {
+ // Leaf `n` sits at node index `2 * n`
+ let node_index = (leaf_index as u32) * 2;
+
+ for leaf in generations {
+ // Application key first, then handshake key, for each listed generation
+ let expectations = [
+ (KeyType::Application, &leaf.application_key, &leaf.application_nonce),
+ (KeyType::Handshake, &leaf.handshake_key, &leaf.handshake_nonce),
+ ];
+
+ for (key_type, expected_key, expected_nonce) in expectations {
+ let key = tree
+ .message_key_generation(&cs, node_index, key_type, leaf.generation)
+ .await
+ .unwrap();
+
+ assert_eq!(&key.key.to_vec(), expected_key);
+ assert_eq!(&key.nonce.to_vec(), expected_nonce);
+ }
+ }
+ }
+ }
+ }
+
+ /// One case of the cross-implementation secret-tree test vector.
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ struct InteropTestCase {
+ /// Raw cipher suite identifier.
+ cipher_suite: u16,
+ /// Root secret of the secret tree; hex-encoded in the JSON.
+ #[serde(with = "hex::serde")]
+ encryption_secret: Vec<u8>,
+ /// Sender-data key/nonce expectations, checked via `InteropSenderData::verify`.
+ sender_data: InteropSenderData,
+ /// One entry per leaf, each listing the expected keys at several generations.
+ leaves: Vec<Vec<InteropLeaf>>,
+ }
+
+ /// Expected application and handshake key/nonce of one leaf at one generation; all byte
+ /// fields are hex-encoded in the JSON.
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ struct InteropLeaf {
+ generation: u32,
+ #[serde(with = "hex::serde")]
+ application_key: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ application_nonce: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ handshake_key: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ handshake_nonce: Vec<u8>,
+ }
+
+ /// Loads the `secret_tree_interop` test cases, falling back to `generate_test_vector`
+ /// as the generator argument of `load_test_case_json!`.
+ fn load_interop_test_cases() -> Vec<InteropTestCase> {
+ load_test_case_json!(secret_tree_interop, generate_test_vector())
+ }
+
+ /// Generates the `secret_tree_interop` vector: for every cipher suite the test crypto
+ /// provider supports and each tree size (1, 8 and 32 leaves), a random encryption secret,
+ /// sender data, and every leaf's handshake/application key and nonce at generations 0
+ /// and 15.
+ #[cfg(not(mls_build_async))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_test_vector() -> Vec<InteropTestCase> {
+ let mut test_cases = vec![];
+
+ for cs in CipherSuite::all() {
+ // Skip cipher suites the test crypto provider does not support
+ let Some(cs) = try_test_cipher_suite_provider(*cs) else {
+ continue;
+ };
+
+ let gens = [0, 15];
+ let tree_sizes = [1, 8, 32];
+
+ for n_leaves in tree_sizes {
+ let encryption_secret = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();
+
+ let mut tree = SecretTree::new(n_leaves, Zeroizing::new(encryption_secret.clone()));
+
+ let leaves = (0..n_leaves)
+ .map(|leaf| {
+ gens.into_iter()
+ .map(|gen| {
+ // Leaf `leaf` sits at node index `2 * leaf`
+ let index = leaf * 2u32;
+
+ let handshake_key = tree
+ .message_key_generation(&cs, index, KeyType::Handshake, gen)
+ .unwrap();
+
+ let app_key = tree
+ .message_key_generation(&cs, index, KeyType::Application, gen)
+ .unwrap();
+
+ InteropLeaf {
+ generation: gen,
+ application_key: app_key.key.to_vec(),
+ application_nonce: app_key.nonce.to_vec(),
+ handshake_key: handshake_key.key.to_vec(),
+ handshake_nonce: handshake_key.nonce.to_vec(),
+ }
+ })
+ .collect()
+ })
+ .collect();
+
+ let case = InteropTestCase {
+ cipher_suite: *cs.cipher_suite(),
+ encryption_secret,
+ sender_data: InteropSenderData::new(&cs),
+ leaves,
+ };
+
+ test_cases.push(case);
+ }
+ }
+
+ test_cases
+ }
+
+ /// Async-build stand-in: the real generator is only compiled without `mls_build_async`,
+ /// so this panics if the interop vector ever needs generating in async mode.
+ #[cfg(mls_build_async)]
+ fn generate_test_vector() -> Vec<InteropTestCase> {
+ panic!("Tests cannot be generated in async mode");
+ }
+}
diff --git a/src/group/snapshot.rs b/src/group/snapshot.rs
new file mode 100644
index 0000000..dca64f8
--- /dev/null
+++ b/src/group/snapshot.rs
@@ -0,0 +1,325 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::MlsError,
+ client_config::ClientConfig,
+ group::{
+ key_schedule::KeySchedule, CommitGeneration, ConfirmationTag, Group, GroupContext,
+ GroupState, InterimTranscriptHash, ReInitProposal, TreeKemPublic,
+ },
+ tree_kem::TreeKemPrivate,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::{
+ crypto::{HpkePublicKey, HpkeSecretKey},
+ group::ProposalRef,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use super::proposal_cache::{CachedProposal, ProposalCache};
+
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+use mls_rs_core::crypto::SignatureSecretKey;
+#[cfg(feature = "tree_index")]
+use mls_rs_core::identity::IdentityProvider;
+
+#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+use std::collections::HashMap;
+
+#[cfg(all(feature = "by_ref_proposal", not(feature = "std")))]
+use alloc::vec::Vec;
+
+use super::{cipher_suite_provider, epoch::EpochSecrets, state_repo::GroupStateRepository};
+
+ /// Complete persisted form of a [`Group`]: produced by `Group::snapshot` and turned back
+ /// into a group by `Group::from_snapshot`.
+ ///
+ /// The snapshot is MLS-encoded when written to group state storage. The derived codec
+ /// presumably follows field declaration order, so reordering, adding or removing fields
+ /// likely changes the stored format — TODO confirm before editing fields.
+ #[derive(Debug, PartialEq, Clone, MlsEncode, MlsDecode, MlsSize)]
+ #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+ pub(crate) struct Snapshot {
+ /// Format version; `Group::snapshot` always writes 1 and `Group::from_snapshot`
+ /// does not inspect it.
+ version: u16,
+ pub(crate) state: RawGroupState,
+ private_tree: TreeKemPrivate,
+ epoch_secrets: EpochSecrets,
+ key_schedule: KeySchedule,
+ /// Secret keys for this member's own pending update proposals, keyed by the HPKE public
+ /// key (a map with `std`, a vector of pairs without it).
+ #[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+ pending_updates: HashMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>,
+ #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
+ pending_updates: Vec<(HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>))>,
+ pending_commit: Option<CommitGeneration>,
+ signer: SignatureSecretKey,
+ }
+
+ /// Storable form of [`GroupState`], produced by [`RawGroupState::export`] and turned back
+ /// into a `GroupState` by `RawGroupState::import`.
+ ///
+ /// Like [`Snapshot`], this is MLS-encoded, so field order is presumably part of the
+ /// stored format.
+ #[derive(Debug, MlsEncode, MlsDecode, MlsSize, PartialEq, Clone)]
+ #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+ pub(crate) struct RawGroupState {
+ pub(crate) context: GroupContext,
+ /// Cached by-reference proposals: a map with `std`, a vector of pairs without it.
+ #[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+ pub(crate) proposals: HashMap<ProposalRef, CachedProposal>,
+ #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
+ pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>,
+ pub(crate) public_tree: TreeKemPublic,
+ pub(crate) interim_transcript_hash: InterimTranscriptHash,
+ pub(crate) pending_reinit: Option<ReInitProposal>,
+ pub(crate) confirmation_tag: ConfirmationTag,
+ }
+
+ impl RawGroupState {
+ /// Copies the storable parts of `state`.
+ ///
+ /// With `tree_index` the public tree is cloned as-is; without it only the tree's nodes
+ /// are copied into a fresh `TreeKemPublic`.
+ pub(crate) fn export(state: &GroupState) -> Self {
+ #[cfg(feature = "tree_index")]
+ let public_tree = state.public_tree.clone();
+
+ #[cfg(not(feature = "tree_index"))]
+ let public_tree = {
+ let mut tree = TreeKemPublic::new();
+ tree.nodes = state.public_tree.nodes.clone();
+ tree
+ };
+
+ Self {
+ context: state.context.clone(),
+ #[cfg(feature = "by_ref_proposal")]
+ proposals: state.proposals.proposals.clone(),
+ public_tree,
+ interim_transcript_hash: state.interim_transcript_hash.clone(),
+ pending_reinit: state.pending_reinit.clone(),
+ confirmation_tag: state.confirmation_tag.clone(),
+ }
+ }
+
+ /// Rebuilds a [`GroupState`], recreating the proposal cache from the stored proposals
+ /// and letting the public tree initialize its index (if necessary) using
+ /// `identity_provider`.
+ ///
+ /// # Errors
+ /// Propagates failures from `initialize_index_if_necessary`.
+ #[cfg(feature = "tree_index")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn import<C>(self, identity_provider: &C) -> Result<GroupState, MlsError>
+ where
+ C: IdentityProvider,
+ {
+ let context = self.context;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let proposals = ProposalCache::import(
+ context.protocol_version,
+ context.group_id.clone(),
+ self.proposals,
+ );
+
+ let mut public_tree = self.public_tree;
+
+ public_tree
+ .initialize_index_if_necessary(identity_provider, &context.extensions)
+ .await?;
+
+ Ok(GroupState {
+ #[cfg(feature = "by_ref_proposal")]
+ proposals,
+ context,
+ public_tree,
+ interim_transcript_hash: self.interim_transcript_hash,
+ pending_reinit: self.pending_reinit,
+ confirmation_tag: self.confirmation_tag,
+ })
+ }
+
+ /// Rebuilds a [`GroupState`] without a tree index, recreating the proposal cache from
+ /// the stored proposals.
+ ///
+ /// This variant never returns an error; the `Result` presumably keeps callers identical
+ /// across the `tree_index` feature.
+ #[cfg(not(feature = "tree_index"))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn import(self) -> Result<GroupState, MlsError> {
+ let context = self.context;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let proposals = ProposalCache::import(
+ context.protocol_version,
+ context.group_id.clone(),
+ self.proposals,
+ );
+
+ Ok(GroupState {
+ #[cfg(feature = "by_ref_proposal")]
+ proposals,
+ context,
+ public_tree: self.public_tree,
+ interim_transcript_hash: self.interim_transcript_hash,
+ pending_reinit: self.pending_reinit,
+ confirmation_tag: self.confirmation_tag,
+ })
+ }
+ }
+
+ impl<C> Group<C>
+ where
+ C: ClientConfig + Clone,
+ {
+ /// Write the current state of the group to the
+ /// [`GroupStateStorage`](crate::GroupStateStorage)
+ /// that is currently in use by the group.
+ ///
+ /// The group is captured with `Group::snapshot` and handed to the group's state
+ /// repository.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn write_to_storage(&mut self) -> Result<(), MlsError> {
+ self.state_repo.write_to_storage(self.snapshot()).await
+ }
+
+ /// Captures the group's current state as a [`Snapshot`] (format `version` 1).
+ pub(crate) fn snapshot(&self) -> Snapshot {
+ Snapshot {
+ state: RawGroupState::export(&self.state),
+ private_tree: self.private_tree.clone(),
+ key_schedule: self.key_schedule.clone(),
+ #[cfg(feature = "by_ref_proposal")]
+ pending_updates: self.pending_updates.clone(),
+ pending_commit: self.pending_commit.clone(),
+ epoch_secrets: self.epoch_secrets.clone(),
+ version: 1,
+ signer: self.signer.clone(),
+ }
+ }
+
+ /// Restores a group from `snapshot` using `config`.
+ ///
+ /// The restored group starts with no pending key package removal and no
+ /// `previous_psk`.
+ ///
+ /// # Errors
+ /// Propagates failures from looking up the cipher suite provider for the snapshot's
+ /// cipher suite, from creating the state repository, and from importing the state.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn from_snapshot(config: C, snapshot: Snapshot) -> Result<Self, MlsError> {
+ let cipher_suite_provider = cipher_suite_provider(
+ config.crypto_provider(),
+ snapshot.state.context.cipher_suite,
+ )?;
+
+ #[cfg(feature = "tree_index")]
+ let identity_provider = config.identity_provider();
+
+ let state_repo = GroupStateRepository::new(
+ #[cfg(feature = "prior_epoch")]
+ snapshot.state.context.group_id.clone(),
+ config.group_state_storage(),
+ config.key_package_repo(),
+ // No key package to delete when restoring from a snapshot
+ None,
+ )?;
+
+ Ok(Group {
+ config,
+ state: snapshot
+ .state
+ .import(
+ #[cfg(feature = "tree_index")]
+ &identity_provider,
+ )
+ .await?,
+ private_tree: snapshot.private_tree,
+ key_schedule: snapshot.key_schedule,
+ #[cfg(feature = "by_ref_proposal")]
+ pending_updates: snapshot.pending_updates,
+ pending_commit: snapshot.pending_commit,
+ #[cfg(test)]
+ commit_modifiers: Default::default(),
+ epoch_secrets: snapshot.epoch_secrets,
+ state_repo,
+ cipher_suite_provider,
+ #[cfg(feature = "psk")]
+ previous_psk: None,
+ signer: snapshot.signer,
+ })
+ }
+ }
+
+ #[cfg(test)]
+ pub(crate) mod test_utils {
+ use alloc::vec;
+
+ use crate::{
+ cipher_suite::CipherSuite,
+ crypto::test_utils::test_cipher_suite_provider,
+ group::{
+ confirmation_tag::ConfirmationTag, epoch::test_utils::get_test_epoch_secrets,
+ key_schedule::test_utils::get_test_key_schedule, test_utils::get_test_group_context,
+ transcript_hash::InterimTranscriptHash,
+ },
+ tree_kem::{node::LeafIndex, TreeKemPrivate},
+ };
+
+ use super::{RawGroupState, Snapshot};
+
+ /// Builds a minimal [`Snapshot`] for `cipher_suite` at epoch `epoch_id`: a default
+ /// public tree, no cached proposals, no pending commit or updates, an empty transcript
+ /// hash, and an empty signer key.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn get_test_snapshot(cipher_suite: CipherSuite, epoch_id: u64) -> Snapshot {
+ let context = get_test_group_context(epoch_id, cipher_suite).await;
+ let provider = test_cipher_suite_provider(cipher_suite);
+ let confirmation_tag = ConfirmationTag::empty(&provider).await;
+
+ let state = RawGroupState {
+ context,
+ #[cfg(feature = "by_ref_proposal")]
+ proposals: Default::default(),
+ public_tree: Default::default(),
+ interim_transcript_hash: InterimTranscriptHash::from(vec![]),
+ pending_reinit: None,
+ confirmation_tag,
+ };
+
+ Snapshot {
+ version: 1,
+ state,
+ private_tree: TreeKemPrivate::new(LeafIndex(0)),
+ epoch_secrets: get_test_epoch_secrets(cipher_suite),
+ key_schedule: get_test_key_schedule(cipher_suite),
+ #[cfg(feature = "by_ref_proposal")]
+ pending_updates: Default::default(),
+ pending_commit: None,
+ signer: vec![].into(),
+ }
+ }
+ }
+
+ #[cfg(test)]
+ mod tests {
+ use alloc::vec;
+ use mls_rs_codec::{MlsDecode, MlsEncode};
+
+ use crate::{
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ group::{
+ test_utils::{test_group, TestGroup},
+ Group,
+ },
+ };
+
+ use super::Snapshot;
+
+ /// Snapshots `group`, round-trips the snapshot through its MLS encoding (the form in
+ /// which `GroupStateRepository::write_to_storage` persists it), restores a group from
+ /// the decoded snapshot, and checks that the restored group matches the original.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn snapshot_restore(group: TestGroup) {
+ let snapshot = group.group.snapshot();
+
+ // Restoring straight from the in-memory value would never exercise the snapshot's
+ // serialization, which is what the tests below are named after.
+ let encoded = snapshot.mls_encode_to_vec().unwrap();
+ let decoded = Snapshot::mls_decode(&mut &*encoded).unwrap();
+
+ let group_restored = Group::from_snapshot(group.group.config.clone(), decoded)
+ .await
+ .unwrap();
+
+ assert!(Group::equal_group_state(&group.group, &group_restored));
+
+ #[cfg(feature = "tree_index")]
+ assert!(group_restored
+ .state
+ .public_tree
+ .equal_internals(&group.group.state.public_tree))
+ }
+
+ /// A group holding a pending commit survives a snapshot round trip.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn snapshot_with_pending_commit_can_be_serialized_to_json() {
+ let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ group.group.commit(vec![]).await.unwrap();
+
+ snapshot_restore(group).await
+ }
+
+ /// A group holding a pending update (and the cached proposal) survives a snapshot
+ /// round trip.
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn snapshot_with_pending_updates_can_be_serialized_to_json() {
+ let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ // Creating the update proposal will add it to pending updates
+ let update_proposal = group.update_proposal().await;
+
+ // This will insert the proposal into the internal proposal cache
+ let _ = group.group.proposal_message(update_proposal, vec![]).await;
+
+ snapshot_restore(group).await
+ }
+
+ /// A freshly created group survives a snapshot round trip.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn snapshot_can_be_serialized_to_json_with_internals() {
+ let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ snapshot_restore(group).await
+ }
+
+ /// The serde representation of a snapshot round-trips through JSON.
+ #[cfg(feature = "serde")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn serde() {
+ let snapshot = super::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, 5).await;
+ let json = serde_json::to_string_pretty(&snapshot).unwrap();
+ let recovered = serde_json::from_str(&json).unwrap();
+ assert_eq!(snapshot, recovered);
+ }
+ }
diff --git a/src/group/state.rs b/src/group/state.rs
new file mode 100644
index 0000000..4d97a04
--- /dev/null
+++ b/src/group/state.rs
@@ -0,0 +1,43 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::{
+ confirmation_tag::ConfirmationTag, proposal::ReInitProposal,
+ transcript_hash::InterimTranscriptHash,
+};
+use crate::group::{GroupContext, TreeKemPublic};
+
+ /// In-memory state of a group at its current epoch: the group context, the public
+ /// ratchet tree, transcript/confirmation data and any pending re-init proposal.
+ #[derive(Clone, Debug, PartialEq)]
+ #[non_exhaustive]
+ pub struct GroupState {
+ /// Proposals received by reference, awaiting a commit.
+ #[cfg(feature = "by_ref_proposal")]
+ pub(crate) proposals: crate::group::ProposalCache,
+ pub(crate) context: GroupContext,
+ pub(crate) public_tree: TreeKemPublic,
+ pub(crate) interim_transcript_hash: InterimTranscriptHash,
+ pub(crate) pending_reinit: Option<ReInitProposal>,
+ pub(crate) confirmation_tag: ConfirmationTag,
+ }
+
+ impl GroupState {
+ /// Assembles a group state from its parts, with no pending re-init proposal and (with
+ /// `by_ref_proposal`) a new proposal cache bound to the context's protocol version and
+ /// group id.
+ pub(crate) fn new(
+ context: GroupContext,
+ current_tree: TreeKemPublic,
+ interim_transcript_hash: InterimTranscriptHash,
+ confirmation_tag: ConfirmationTag,
+ ) -> Self {
+ #[cfg(feature = "by_ref_proposal")]
+ let proposals =
+ crate::group::ProposalCache::new(context.protocol_version, context.group_id.clone());
+
+ GroupState {
+ #[cfg(feature = "by_ref_proposal")]
+ proposals,
+ context,
+ public_tree: current_tree,
+ pending_reinit: None,
+ interim_transcript_hash,
+ confirmation_tag,
+ }
+ }
+ }
diff --git a/src/group/state_repo.rs b/src/group/state_repo.rs
new file mode 100644
index 0000000..6e33b0a
--- /dev/null
+++ b/src/group/state_repo.rs
@@ -0,0 +1,573 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::{group::PriorEpoch, key_package::KeyPackageRef};
+
+use alloc::collections::VecDeque;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode};
+use mls_rs_core::group::{EpochRecord, GroupState};
+use mls_rs_core::{error::IntoAnyError, group::GroupStateStorage, key_package::KeyPackageStorage};
+
+use super::snapshot::Snapshot;
+
+#[cfg(feature = "psk")]
+use crate::group::ResumptionPsk;
+
+#[cfg(feature = "psk")]
+use mls_rs_core::psk::PreSharedKey;
+
+ /// A set of changes to apply to a `GroupStateStorage` implementation. These changes MUST
+ /// be made in a single transaction to avoid creating invalid states.
+ ///
+ /// `GroupStateRepository::write_to_storage` hands them to storage together with the
+ /// group snapshot in a single `write` call.
+ #[derive(Default, Clone, Debug)]
+ struct EpochStorageCommit {
+ /// New prior epochs, in ascending, contiguous epoch-id order (enforced by
+ /// `GroupStateRepository::insert`).
+ pub(crate) inserts: VecDeque<PriorEpoch>,
+ /// Already-stored epochs loaded through `get_epoch_mut`, to be written back.
+ pub(crate) updates: Vec<PriorEpoch>,
+ }
+
+ /// Tracks the prior epochs of a single group and persists them, together with group
+ /// snapshots, through a [`GroupStateStorage`]. It also deletes the key package used to
+ /// join the group once state is written.
+ #[derive(Clone)]
+ pub(crate) struct GroupStateRepository<S, K>
+ where
+ S: GroupStateStorage,
+ K: KeyPackageStorage,
+ {
+ /// Epoch changes not yet written to `storage`.
+ pending_commit: EpochStorageCommit,
+ /// Key package to delete from `key_package_repo` on the next state write.
+ pending_key_package_removal: Option<KeyPackageRef>,
+ /// Id of the group whose epochs this repository manages.
+ group_id: Vec<u8>,
+ storage: S,
+ key_package_repo: K,
+ }
+
+ impl<S, K> Debug for GroupStateRepository<S, K>
+ where
+ S: GroupStateStorage + Debug,
+ K: KeyPackageStorage + Debug,
+ {
+ /// Formats every field, rendering the group id through `pretty_group_id` instead of as
+ /// raw bytes.
+ fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let group_id = mls_rs_core::debug::pretty_group_id(&self.group_id);
+
+ let mut out = formatter.debug_struct("GroupStateRepository");
+ out.field("pending_commit", &self.pending_commit);
+ out.field("pending_key_package_removal", &self.pending_key_package_removal);
+ out.field("group_id", &group_id);
+ out.field("storage", &self.storage);
+ out.field("key_package_repo", &self.key_package_repo);
+ out.finish()
+ }
+ }
+
+ impl<S, K> GroupStateRepository<S, K>
+ where
+ S: GroupStateStorage,
+ K: KeyPackageStorage,
+ {
+ /// Creates a repository for `group_id` with no pending epoch changes.
+ ///
+ /// If `key_package_to_remove` is set, that key package is deleted from
+ /// `key_package_repo` by the next successful [`Self::write_to_storage`].
+ pub fn new(
+ group_id: Vec<u8>,
+ storage: S,
+ key_package_repo: K,
+ // Set to `None` if restoring from snapshot; set to `Some` when joining a group.
+ key_package_to_remove: Option<KeyPackageRef>,
+ ) -> Result<GroupStateRepository<S, K>, MlsError> {
+ Ok(GroupStateRepository {
+ group_id,
+ storage,
+ pending_key_package_removal: key_package_to_remove,
+ pending_commit: Default::default(),
+ key_package_repo,
+ })
+ }
+
+ /// Returns the highest known epoch id for this group: the newest pending insert if
+ /// there is one, otherwise the maximum epoch id reported by storage.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn find_max_id(&self) -> Result<Option<u64>, MlsError> {
+ if let Some(max) = self.pending_commit.inserts.back().map(|e| e.epoch_id()) {
+ Ok(Some(max))
+ } else {
+ self.storage
+ .max_epoch_id(&self.group_id)
+ .await
+ .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))
+ }
+ }
+
+ /// Looks up the resumption secret of the epoch identified by `psk_id`, returning
+ /// `Ok(None)` if that epoch is not known.
+ ///
+ /// Pending inserts and updates only ever hold epochs of this repository's own group
+ /// (`insert` rejects other groups, and `get_epoch_mut` loads by `self.group_id`). They
+ /// are therefore consulted only when `psk_id` names that group; any other group's epoch
+ /// is looked up in storage directly.
+ #[cfg(feature = "psk")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn resumption_secret(
+ &self,
+ psk_id: &ResumptionPsk,
+ ) -> Result<Option<PreSharedKey>, MlsError> {
+ // The local caches are indexed by epoch id only. Without this check, a PSK that
+ // names another group's epoch could be answered with this group's secret.
+ if psk_id.psk_group_id.0 == self.group_id {
+ // Search the local inserts cache
+ if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) {
+ if psk_id.psk_epoch >= min {
+ return Ok(self
+ .pending_commit
+ .inserts
+ .get((psk_id.psk_epoch - min) as usize)
+ .map(|e| e.secrets.resumption_secret.clone()));
+ }
+ }
+
+ // Search the local updates cache
+ if let Some(pending) = self.find_pending(psk_id.psk_epoch) {
+ return Ok(Some(
+ self.pending_commit.updates[pending]
+ .secrets
+ .resumption_secret
+ .clone(),
+ ));
+ }
+ }
+
+ // Search the stored cache
+ self.storage
+ .epoch(&psk_id.psk_group_id.0, psk_id.psk_epoch)
+ .await
+ .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
+ .map(|e| Ok(PriorEpoch::mls_decode(&mut &*e)?.secrets.resumption_secret))
+ .transpose()
+ }
+
+ /// Returns a mutable reference to prior epoch `epoch_id` of this group, or `Ok(None)`
+ /// if it is not known.
+ ///
+ /// An epoch loaded from storage is cached in the pending updates, so changes made
+ /// through the returned reference are written back by the next `write_to_storage`.
+ #[cfg(feature = "private_message")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_epoch_mut(
+ &mut self,
+ epoch_id: u64,
+ ) -> Result<Option<&mut PriorEpoch>, MlsError> {
+ // Search the local inserts cache
+ if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) {
+ if epoch_id >= min {
+ return Ok(self
+ .pending_commit
+ .inserts
+ .get_mut((epoch_id - min) as usize));
+ }
+ }
+
+ // Look in the cached updates map, and if not found look in disk storage
+ // and insert into the updates map for future caching
+ match self.find_pending(epoch_id) {
+ Some(i) => self.pending_commit.updates.get_mut(i).map(Ok),
+ None => self
+ .storage
+ .epoch(&self.group_id, epoch_id)
+ .await
+ .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
+ .and_then(|epoch| {
+ PriorEpoch::mls_decode(&mut &*epoch)
+ .map(|epoch| {
+ self.pending_commit.updates.push(epoch);
+ self.pending_commit.updates.last_mut()
+ })
+ .transpose()
+ }),
+ }
+ .transpose()
+ .map_err(Into::into)
+ }
+
+ /// Queues `epoch` to be inserted by the next `write_to_storage`.
+ ///
+ /// # Errors
+ /// * `GroupIdMismatch` if `epoch` belongs to a different group.
+ /// * `InvalidEpoch` if its id is not exactly one greater than the highest known epoch
+ /// (when any epoch is known).
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn insert(&mut self, epoch: PriorEpoch) -> Result<(), MlsError> {
+ if epoch.group_id() != self.group_id {
+ return Err(MlsError::GroupIdMismatch);
+ }
+
+ let epoch_id = epoch.epoch_id();
+
+ if let Some(expected_id) = self.find_max_id().await?.map(|id| id + 1) {
+ if epoch_id != expected_id {
+ return Err(MlsError::InvalidEpoch);
+ }
+ }
+
+ self.pending_commit.inserts.push_back(epoch);
+
+ Ok(())
+ }
+
+ /// Persists `group_snapshot` together with all pending epoch inserts and updates in a
+ /// single `GroupStateStorage::write` call, then deletes the pending key package (if
+ /// any).
+ ///
+ /// The pending epoch changes are dropped as soon as the storage write succeeds, so a
+ /// failure while deleting the key package cannot cause them to be submitted again by a
+ /// later call. The key package removal stays pending until it succeeds and is not
+ /// repeated afterwards.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> {
+ let inserts = self
+ .pending_commit
+ .inserts
+ .iter()
+ .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?)))
+ .collect::<Result<_, MlsError>>()?;
+
+ let updates = self
+ .pending_commit
+ .updates
+ .iter()
+ .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?)))
+ .collect::<Result<_, MlsError>>()?;
+
+ let group_state = GroupState {
+ data: group_snapshot.mls_encode_to_vec()?,
+ id: group_snapshot.state.context.group_id,
+ };
+
+ self.storage
+ .write(group_state, inserts, updates)
+ .await
+ .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?;
+
+ // Storage has accepted these changes; clear them before any further fallible step so
+ // they are never submitted twice.
+ self.pending_commit.inserts.clear();
+ self.pending_commit.updates.clear();
+
+ if let Some(ref key_package_ref) = self.pending_key_package_removal {
+ self.key_package_repo
+ .delete(key_package_ref)
+ .await
+ .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;
+ }
+
+ // Any pending key package has now been deleted; don't delete it again on every write.
+ self.pending_key_package_removal = None;
+
+ Ok(())
+ }
+
+ /// Position of epoch `epoch_id` within the pending updates, if it is cached there.
+ #[cfg(any(feature = "psk", feature = "private_message"))]
+ fn find_pending(&self, epoch_id: u64) -> Option<usize> {
+ self.pending_commit
+ .updates
+ .iter()
+ .position(|ep| ep.context.epoch == epoch_id)
+ }
+ }
+
+#[cfg(test)]
+mod tests {
+ use alloc::vec;
+ use mls_rs_codec::MlsEncode;
+
+ use crate::{
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ group::{
+ epoch::{test_utils::get_test_epoch_with_id, SenderDataSecret},
+ test_utils::{random_bytes, test_member, TEST_GROUP},
+ PskGroupId, ResumptionPSKUsage,
+ },
+ storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage},
+ };
+
+ use super::*;
+
+ /// Repository for `TEST_GROUP` backed by in-memory storages, with the group storage's
+ /// epoch retention set to `retention_limit` and no key package to remove.
+ fn test_group_state_repo(
+ retention_limit: usize,
+ ) -> GroupStateRepository<InMemoryGroupStateStorage, InMemoryKeyPackageStorage> {
+ let group_id = TEST_GROUP.to_vec();
+
+ let storage = InMemoryGroupStateStorage::new()
+ .with_max_epoch_retention(retention_limit)
+ .unwrap();
+
+ let key_packages = InMemoryKeyPackageStorage::default();
+
+ GroupStateRepository::new(group_id, storage, key_packages, None).unwrap()
+ }
+
+ /// Test prior epoch `epoch_id` of `TEST_GROUP` for `TEST_CIPHER_SUITE`.
+ fn test_epoch(epoch_id: u64) -> PriorEpoch {
+ get_test_epoch_with_id(TEST_GROUP.to_vec(), TEST_CIPHER_SUITE, epoch_id)
+ }
+
+ /// Minimal test snapshot for `TEST_CIPHER_SUITE` at epoch `epoch_id`.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_snapshot(epoch_id: u64) -> Snapshot {
+ crate::group::snapshot::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await
+ }
+
+ /// An inserted epoch stays in memory, readable through `resumption_secret` and
+ /// `get_epoch_mut`, until `write_to_storage` persists it together with the snapshot and
+ /// clears the in-memory cache.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_epoch_inserts() {
+ let mut test_repo = test_group_state_repo(1);
+ let test_epoch = test_epoch(0);
+
+ test_repo.insert(test_epoch.clone()).await.unwrap();
+
+ // Check the in-memory state
+ assert_eq!(
+ test_repo.pending_commit.inserts.back().unwrap(),
+ &test_epoch
+ );
+
+ assert!(test_repo.pending_commit.updates.is_empty());
+
+ // Nothing has reached storage yet (the lock API differs between std and no_std)
+ #[cfg(feature = "std")]
+ assert!(test_repo.storage.inner.lock().unwrap().is_empty());
+ #[cfg(not(feature = "std"))]
+ assert!(test_repo.storage.inner.lock().is_empty());
+
+ let psk_id = ResumptionPsk {
+ psk_epoch: 0,
+ psk_group_id: PskGroupId(test_repo.group_id.clone()),
+ usage: ResumptionPSKUsage::Application,
+ };
+
+ // Make sure you can recall an epoch sitting as a pending insert
+ let resumption = test_repo.resumption_secret(&psk_id).await.unwrap();
+ let prior_epoch = test_repo.get_epoch_mut(0).await.unwrap().cloned();
+
+ assert_eq!(
+ prior_epoch.clone().unwrap().secrets.resumption_secret,
+ resumption.unwrap()
+ );
+
+ assert_eq!(prior_epoch.unwrap(), test_epoch);
+
+ // Write to the storage
+ let snapshot = test_snapshot(test_epoch.epoch_id()).await;
+ test_repo.write_to_storage(snapshot.clone()).await.unwrap();
+
+ // Make sure the memory cache cleared
+ assert!(test_repo.pending_commit.inserts.is_empty());
+ assert!(test_repo.pending_commit.updates.is_empty());
+
+ // Make sure the storage was written
+ #[cfg(feature = "std")]
+ let storage = test_repo.storage.inner.lock().unwrap();
+ #[cfg(not(feature = "std"))]
+ let storage = test_repo.storage.inner.lock();
+
+ assert_eq!(storage.len(), 1);
+
+ let stored = storage.get(TEST_GROUP).unwrap();
+
+ // The stored group state is exactly the MLS encoding of the snapshot
+ assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());
+
+ assert_eq!(stored.epoch_data.len(), 1);
+
+ assert_eq!(
+ stored.epoch_data.back().unwrap(),
+ &EpochRecord::new(
+ test_epoch.epoch_id(),
+ test_epoch.mls_encode_to_vec().unwrap()
+ )
+ );
+ }
+
+ // Mutating an already-stored epoch through `get_epoch_mut` records it as a
+ // pending update. The update is readable before the next write, and after that
+ // write it replaces the stored epoch record.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_updates() {
+ let mut test_repo = test_group_state_repo(2);
+ let test_epoch_0 = test_epoch(0);
+
+ test_repo.insert(test_epoch_0.clone()).await.unwrap();
+
+ test_repo
+ .write_to_storage(test_snapshot(0).await)
+ .await
+ .unwrap();
+
+ // Update the stored epoch
+ let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap();
+ assert_eq!(to_update, &test_epoch_0);
+
+ let new_sender_secret = random_bytes(32);
+ to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret);
+ // Take an owned copy so the mutable borrow of `test_repo` ends here.
+ let to_update = to_update.clone();
+
+ assert_eq!(test_repo.pending_commit.updates.len(), 1);
+ assert!(test_repo.pending_commit.inserts.is_empty());
+
+ assert_eq!(
+ test_repo.pending_commit.updates.first().unwrap(),
+ &to_update
+ );
+
+ // Make sure you can access an epoch pending update
+ let psk_id = ResumptionPsk {
+ psk_epoch: 0,
+ psk_group_id: PskGroupId(test_repo.group_id.clone()),
+ usage: ResumptionPSKUsage::Application,
+ };
+
+ let owned = test_repo.resumption_secret(&psk_id).await.unwrap();
+ assert_eq!(owned.as_ref(), Some(&to_update.secrets.resumption_secret));
+
+ // Write the update to storage
+ let snapshot = test_snapshot(1).await;
+ test_repo.write_to_storage(snapshot.clone()).await.unwrap();
+
+ assert!(test_repo.pending_commit.updates.is_empty());
+ assert!(test_repo.pending_commit.inserts.is_empty());
+
+ // Make sure the storage was written
+ #[cfg(feature = "std")]
+ let storage = test_repo.storage.inner.lock().unwrap();
+ #[cfg(not(feature = "std"))]
+ let storage = test_repo.storage.inner.lock();
+
+ assert_eq!(storage.len(), 1);
+
+ let stored = storage.get(TEST_GROUP).unwrap();
+
+ assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());
+
+ // Still one epoch record, now holding the updated epoch.
+ assert_eq!(stored.epoch_data.len(), 1);
+
+ assert_eq!(
+ stored.epoch_data.back().unwrap(),
+ &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap())
+ );
+ }
+
+ // A pending update to epoch 0 and a pending insert of epoch 1 are both
+ // persisted by a single `write_to_storage`. The stored records are ordered
+ // with epoch 0 at the front and epoch 1 at the back.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_insert_and_update() {
+ let mut test_repo = test_group_state_repo(2);
+ let test_epoch_0 = test_epoch(0);
+
+ test_repo.insert(test_epoch_0).await.unwrap();
+
+ test_repo
+ .write_to_storage(test_snapshot(0).await)
+ .await
+ .unwrap();
+
+ // Update the stored epoch
+ let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap();
+ let new_sender_secret = random_bytes(32);
+ to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret);
+ // Owned copy; releases the mutable borrow of `test_repo`.
+ let to_update = to_update.clone();
+
+ // Insert another epoch
+ let test_epoch_1 = test_epoch(1);
+ test_repo.insert(test_epoch_1.clone()).await.unwrap();
+
+ test_repo
+ .write_to_storage(test_snapshot(1).await)
+ .await
+ .unwrap();
+
+ assert!(test_repo.pending_commit.inserts.is_empty());
+ assert!(test_repo.pending_commit.updates.is_empty());
+
+ // Make sure the storage was written
+ #[cfg(feature = "std")]
+ let storage = test_repo.storage.inner.lock().unwrap();
+ #[cfg(not(feature = "std"))]
+ let storage = test_repo.storage.inner.lock();
+
+ assert_eq!(storage.len(), 1);
+
+ let stored = storage.get(TEST_GROUP).unwrap();
+
+ assert_eq!(stored.epoch_data.len(), 2);
+
+ assert_eq!(
+ stored.epoch_data.front().unwrap(),
+ &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap())
+ );
+
+ assert_eq!(
+ stored.epoch_data.back().unwrap(),
+ &EpochRecord::new(
+ test_epoch_1.epoch_id(),
+ test_epoch_1.mls_encode_to_vec().unwrap()
+ )
+ );
+ }
+
+ // Ten epochs inserted and persisted in one write must all read back unchanged
+ // through `get_epoch_mut`.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_many_epochs_in_storage() {
+ let mut test_repo = test_group_state_repo(10);
+ let epochs: Vec<PriorEpoch> = (0..10).map(test_epoch).collect();
+
+ for epoch in &epochs {
+ test_repo.insert(epoch.clone()).await.unwrap();
+ }
+
+ test_repo
+ .write_to_storage(test_snapshot(9).await)
+ .await
+ .unwrap();
+
+ for mut expected in epochs {
+ let id = expected.epoch_id();
+ let found = test_repo.get_epoch_mut(id).await.unwrap();
+ assert_eq!(found, Some(&mut expected));
+ }
+ }
+
+ // After one write, the storage lists exactly the group id of the written epoch.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_stored_groups_list() {
+ let mut test_repo = test_group_state_repo(2);
+ let epoch = test_epoch(0);
+ let expected_group_id = epoch.context.group_id.clone();
+
+ test_repo.insert(epoch).await.unwrap();
+
+ let snapshot = test_snapshot(0).await;
+ test_repo.write_to_storage(snapshot).await.unwrap();
+
+ assert_eq!(test_repo.storage.stored_groups(), vec![expected_group_id]);
+ }
+
+ // A repository rebuilt on top of existing storage must not return epoch 0 once
+ // epochs 0 and 1 were saved under a retention limit of 1.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn reducing_retention_limit_takes_effect_on_epoch_access() {
+ let mut original = test_group_state_repo(1);
+
+ for id in [0, 1] {
+ original.insert(test_epoch(id)).await.unwrap();
+ }
+
+ original
+ .write_to_storage(test_snapshot(0).await)
+ .await
+ .unwrap();
+
+ // Keep the populated storage; take every other field from a fresh repository.
+ let mut reopened = GroupStateRepository {
+ storage: original.storage,
+ ..test_group_state_repo(1)
+ };
+
+ assert!(reopened.get_epoch_mut(0).await.unwrap().is_none());
+ }
+
+ // With a retention limit of 1, saving epochs 0 and 1 one after another must
+ // leave a single epoch record in storage.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn in_memory_storage_obeys_retention_limit_after_saving() {
+ let mut repo = test_group_state_repo(1);
+
+ for id in 0..2 {
+ repo.insert(test_epoch(id)).await.unwrap();
+ repo.write_to_storage(test_snapshot(id).await).await.unwrap();
+ }
+
+ #[cfg(feature = "std")]
+ let lock = repo.storage.inner.lock().unwrap();
+ #[cfg(not(feature = "std"))]
+ let lock = repo.storage.inner.lock();
+
+ let stored_epochs = lock.get(TEST_GROUP).unwrap().epoch_data.len();
+ assert_eq!(stored_epochs, 1);
+ }
+
+ // A key package passed as `key_package_to_remove` must be gone from the key
+ // package storage after the first `write_to_storage`.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn used_key_package_is_deleted() {
+ let (key_package, _) = test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"member").await;
+ let key_package_ref = key_package.reference.clone();
+
+ let key_package_repo = InMemoryKeyPackageStorage::default();
+ let (id, data) = key_package.to_storage().unwrap();
+ key_package_repo.insert(id, data);
+
+ let mut repo = GroupStateRepository::new(
+ TEST_GROUP.to_vec(),
+ InMemoryGroupStateStorage::new(),
+ key_package_repo,
+ Some(key_package_ref.clone()),
+ )
+ .unwrap();
+
+ assert!(repo.key_package_repo.get(&key_package_ref).is_some());
+
+ repo.write_to_storage(test_snapshot(4).await).await.unwrap();
+
+ assert!(repo.key_package_repo.get(&key_package_ref).is_none());
+ }
+}
diff --git a/src/group/state_repo_light.rs b/src/group/state_repo_light.rs
new file mode 100644
index 0000000..76d1fb6
--- /dev/null
+++ b/src/group/state_repo_light.rs
@@ -0,0 +1,132 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::key_package::KeyPackageRef;
+
+use alloc::vec::Vec;
+use mls_rs_codec::MlsEncode;
+use mls_rs_core::{
+ error::IntoAnyError,
+ group::{GroupState, GroupStateStorage},
+ key_package::KeyPackageStorage,
+};
+
+use super::snapshot::Snapshot;
+
+/// Lightweight group state repository: `write_to_storage` persists only the
+/// encoded group snapshot (it passes empty epoch insert/update lists) and then
+/// deletes the key package scheduled for removal, if any.
+#[derive(Debug, Clone)]
+pub(crate) struct GroupStateRepository<S, K>
+where
+ S: GroupStateStorage,
+ K: KeyPackageStorage,
+{
+ /// Key package to delete from `key_package_repo` in `write_to_storage`.
+ pending_key_package_removal: Option<KeyPackageRef>,
+ /// Destination for the serialized group state.
+ storage: S,
+ /// Key package storage; only used here to delete `pending_key_package_removal`.
+ key_package_repo: K,
+}
+
+impl<S, K> GroupStateRepository<S, K>
+where
+ S: GroupStateStorage,
+ K: KeyPackageStorage,
+{
+ /// Creates a repository over `storage` and `key_package_repo`.
+ ///
+ /// `key_package_to_remove` is deleted from `key_package_repo` by the first
+ /// successful [`Self::write_to_storage`]. This constructor never fails; the
+ /// `Result` return type is kept for API compatibility.
+ pub fn new(
+ storage: S,
+ key_package_repo: K,
+ // Set to `None` if restoring from snapshot; set to `Some` when joining a group.
+ key_package_to_remove: Option<KeyPackageRef>,
+ ) -> Result<GroupStateRepository<S, K>, MlsError> {
+ Ok(GroupStateRepository {
+ storage,
+ pending_key_package_removal: key_package_to_remove,
+ key_package_repo,
+ })
+ }
+
+ /// Encodes `group_snapshot`, writes it to the group state storage with no
+ /// epoch records, and then deletes the pending key package, if any.
+ ///
+ /// # Errors
+ ///
+ /// - Encoding errors from `mls_encode_to_vec`.
+ /// - [`MlsError::GroupStorageError`] if the storage write fails.
+ /// - [`MlsError::KeyPackageRepoError`] if the key package deletion fails. The
+ ///   removal then stays pending and is retried on the next call.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> {
+ let group_state = GroupState {
+ data: group_snapshot.mls_encode_to_vec()?,
+ id: group_snapshot.state.context.group_id,
+ };
+
+ self.storage
+ .write(group_state, Vec::new(), Vec::new())
+ .await
+ .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?;
+
+ if let Some(key_package_ref) = &self.pending_key_package_removal {
+ self.key_package_repo
+ .delete(key_package_ref)
+ .await
+ .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;
+ }
+
+ // Reached only if the delete above succeeded (or nothing was pending).
+ // Clearing the field stops every later write from re-issuing a redundant
+ // delete for a key package that no longer exists.
+ self.pending_key_package_removal = None;
+
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::{
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ group::{
+ snapshot::{test_utils::get_test_snapshot, Snapshot},
+ test_utils::{test_member, TEST_GROUP},
+ },
+ storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage},
+ };
+
+ use alloc::vec;
+
+ use super::GroupStateRepository;
+
+ /// Group snapshot fixture at `epoch_id` for the test cipher suite.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_snapshot(epoch_id: u64) -> Snapshot {
+ get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await
+ }
+
+ // After one write, the in-memory storage lists exactly `TEST_GROUP`.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_stored_groups_list() {
+ let group_storage = InMemoryGroupStateStorage::default();
+ let key_package_storage = InMemoryKeyPackageStorage::default();
+
+ let mut test_repo =
+ GroupStateRepository::new(group_storage, key_package_storage, None).unwrap();
+
+ let snapshot = test_snapshot(0).await;
+ test_repo.write_to_storage(snapshot).await.unwrap();
+
+ assert_eq!(test_repo.storage.stored_groups(), vec![TEST_GROUP]);
+ }
+
+ // A key package passed as `key_package_to_remove` is deleted by the write.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn used_key_package_is_deleted() {
+ let (key_package, _) = test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"member").await;
+ let key_package_ref = key_package.reference.clone();
+ let (id, data) = key_package.to_storage().unwrap();
+
+ let key_package_repo = InMemoryKeyPackageStorage::default();
+ key_package_repo.insert(id, data);
+
+ let mut repo = GroupStateRepository::new(
+ InMemoryGroupStateStorage::default(),
+ key_package_repo,
+ Some(key_package_ref.clone()),
+ )
+ .unwrap();
+
+ assert!(repo.key_package_repo.get(&key_package_ref).is_some());
+
+ repo.write_to_storage(test_snapshot(4).await).await.unwrap();
+
+ assert!(repo.key_package_repo.get(&key_package_ref).is_none());
+ }
+}
diff --git a/src/group/test_utils.rs b/src/group/test_utils.rs
new file mode 100644
index 0000000..764d5e6
--- /dev/null
+++ b/src/group/test_utils.rs
@@ -0,0 +1,521 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use core::ops::{Deref, DerefMut};
+
+use alloc::format;
+use rand::RngCore;
+
+use super::*;
+use crate::{
+ client::{
+ test_utils::{
+ test_client_with_key_pkg, test_client_with_key_pkg_custom, TEST_CIPHER_SUITE,
+ TEST_PROTOCOL_VERSION,
+ },
+ MlsError,
+ },
+ client_builder::test_utils::{TestClientBuilder, TestClientConfig},
+ crypto::test_utils::test_cipher_suite_provider,
+ extension::ExtensionType,
+ identity::basic::BasicIdentityProvider,
+ identity::test_utils::get_test_signing_identity,
+ key_package::{KeyPackageGeneration, KeyPackageGenerator},
+ mls_rules::{CommitOptions, DefaultMlsRules},
+ tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
+};
+
+use crate::extension::RequiredCapabilitiesExt;
+
+#[cfg(not(feature = "by_ref_proposal"))]
+use crate::crypto::HpkePublicKey;
+
+/// Group id used by the group fixtures in this module.
+pub const TEST_GROUP: &[u8] = b"group";
+
+/// Test wrapper around a [`Group`] configured with [`TestClientConfig`].
+#[derive(Clone)]
+pub(crate) struct TestGroup {
+ /// The wrapped group.
+ pub group: Group<TestClientConfig>,
+}
+
+impl TestGroup {
+ /// Builds a proposal message carrying `proposal`; panics on failure.
+ #[cfg(feature = "external_client")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn propose(&mut self, proposal: Proposal) -> MlsMessage {
+ self.group.proposal_message(proposal, vec![]).await.unwrap()
+ }
+
+ /// Creates an update proposal with `None` for both optional arguments;
+ /// panics on failure.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn update_proposal(&mut self) -> Proposal {
+ self.group.update_proposal(None, None).await.unwrap()
+ }
+
+ /// Adds a new client named `name` to this group and joins it from the new
+ /// member's side.
+ ///
+ /// When `custom_kp` is true, the new client and its key package are built
+ /// with `config` applied (`test_client_with_key_pkg_custom`). In either case
+ /// `config` is applied to the new client's configuration before it joins.
+ /// The add commit is applied to `self`. Returns the new member's group and
+ /// the commit message for the other existing members to process.
+ ///
+ /// # Errors
+ /// Only errors from `Group::join` are returned. Failures while building or
+ /// applying the commit panic.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn join_with_custom_config<F>(
+ &mut self,
+ name: &str,
+ custom_kp: bool,
+ mut config: F,
+ ) -> Result<(TestGroup, MlsMessage), MlsError>
+ where
+ F: FnMut(&mut TestClientConfig),
+ {
+ let (mut new_client, new_key_package) = if custom_kp {
+ test_client_with_key_pkg_custom(
+ self.group.protocol_version(),
+ self.group.cipher_suite(),
+ name,
+ &mut config,
+ )
+ .await
+ } else {
+ test_client_with_key_pkg(
+ self.group.protocol_version(),
+ self.group.cipher_suite(),
+ name,
+ )
+ .await
+ };
+
+ // Add new member to the group
+ let CommitOutput {
+ welcome_messages,
+ ratchet_tree,
+ commit_message,
+ ..
+ } = self
+ .group
+ .commit_builder()
+ .add_member(new_key_package)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ // Apply the commit to the original group
+ self.group.apply_pending_commit().await.unwrap();
+
+ config(&mut new_client.config);
+
+ // Group from new member's perspective
+ let (new_group, _) = Group::join(
+ &welcome_messages[0],
+ ratchet_tree,
+ new_client.config.clone(),
+ new_client.signer.clone().unwrap(),
+ )
+ .await?;
+
+ let new_test_group = TestGroup { group: new_group };
+
+ Ok((new_test_group, commit_message))
+ }
+
+ /// Like [`Self::join_with_custom_config`] with the default key package and
+ /// no config changes; panics on any error.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn join(&mut self, name: &str) -> (TestGroup, MlsMessage) {
+ self.join_with_custom_config(name, false, |_| ())
+ .await
+ .unwrap()
+ }
+
+ /// Applies this group's pending commit.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn process_pending_commit(
+ &mut self,
+ ) -> Result<CommitMessageDescription, MlsError> {
+ self.group.apply_pending_commit().await
+ }
+
+ /// Processes an incoming `message` with this group.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn process_message(
+ &mut self,
+ message: MlsMessage,
+ ) -> Result<ReceivedMessage, MlsError> {
+ self.group.process_incoming_message(message).await
+ }
+
+ /// Signs `content` as this member with wire format `PublicMessage` and then
+ /// formats it via `format_for_wire`; panics on failure.
+ #[cfg(feature = "private_message")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn make_plaintext(&mut self, content: Content) -> MlsMessage {
+ let auth_content = AuthenticatedContent::new_signed(
+ &self.group.cipher_suite_provider,
+ &self.group.state.context,
+ Sender::Member(*self.group.private_tree.self_index),
+ content,
+ &self.group.signer,
+ WireFormat::PublicMessage,
+ Vec::new(),
+ )
+ .await
+ .unwrap();
+
+ self.group.format_for_wire(auth_content).await.unwrap()
+ }
+}
+
+/// Builds a [`GroupContext`] for `TEST_GROUP` at `epoch`. The tree hash and the
+/// confirmed transcript hash are placeholders (hashes of fixed byte strings),
+/// and the extension list is empty.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn get_test_group_context(epoch: u64, cipher_suite: CipherSuite) -> GroupContext {
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let tree_hash = provider.hash(&[1, 2, 3]).await.unwrap();
+ let transcript_hash = provider.hash(&[3, 2, 1]).await.unwrap();
+
+ GroupContext {
+ protocol_version: TEST_PROTOCOL_VERSION,
+ cipher_suite,
+ group_id: TEST_GROUP.to_vec(),
+ epoch,
+ tree_hash,
+ confirmed_transcript_hash: transcript_hash.into(),
+ extensions: ExtensionList::from(vec![]),
+ }
+}
+
+/// Synchronous [`GroupContext`] fixture for an arbitrary `group_id` at `epoch`,
+/// with an empty tree hash, an empty confirmed transcript hash and no extensions.
+#[cfg(feature = "prior_epoch")]
+pub(crate) fn get_test_group_context_with_id(
+ group_id: Vec<u8>,
+ epoch: u64,
+ cipher_suite: CipherSuite,
+) -> GroupContext {
+ let empty_transcript_hash = ConfirmedTranscriptHash::from(vec![]);
+
+ GroupContext {
+ extensions: ExtensionList::from(vec![]),
+ confirmed_transcript_hash: empty_transcript_hash,
+ tree_hash: vec![],
+ epoch,
+ group_id,
+ cipher_suite,
+ protocol_version: TEST_PROTOCOL_VERSION,
+ }
+}
+
+/// Group extension list containing only a default [`RequiredCapabilitiesExt`].
+pub(crate) fn group_extensions() -> ExtensionList {
+ let mut list = ExtensionList::new();
+ list.set_from(RequiredCapabilitiesExt::default()).unwrap();
+ list
+}
+
+/// Key package lifetime used by the test fixtures: `Lifetime::years(1)`.
+pub(crate) fn lifetime() -> Lifetime {
+ Lifetime::years(1).unwrap()
+}
+
+/// Generates a key package for a test member whose signing identity is derived
+/// from `identifier`, and returns it together with the member's signing key.
+/// Panics if generation fails.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn test_member(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ identifier: &[u8],
+) -> (KeyPackageGeneration, SignatureSecretKey) {
+ let (signing_identity, signing_key) = get_test_signing_identity(cipher_suite, identifier).await;
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let generator = KeyPackageGenerator {
+ protocol_version,
+ cipher_suite_provider: &provider,
+ signing_identity: &signing_identity,
+ signing_key: &signing_key,
+ identity_provider: &BasicIdentityProvider,
+ };
+
+ let key_package = generator
+ .generate(
+ lifetime(),
+ get_test_capabilities(),
+ ExtensionList::default(),
+ ExtensionList::default(),
+ )
+ .await
+ .unwrap();
+
+ (key_package, signing_key)
+}
+
+/// Creates a single-member [`TestGroup`] with group id `TEST_GROUP`, whose
+/// creator's signing identity is derived from `b"member"`.
+///
+/// * `extension_types` - passed to `TestClientBuilder::extension_types`.
+/// * `leaf_extensions` - leaf node extensions; empty when `None`.
+/// * `commit_options` - commit options for [`DefaultMlsRules`]; default when `None`.
+///
+/// Panics if the group cannot be created.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn test_group_custom(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ extension_types: Vec<ExtensionType>,
+ leaf_extensions: Option<ExtensionList>,
+ commit_options: Option<CommitOptions>,
+) -> TestGroup {
+ let leaf_extensions = leaf_extensions.unwrap_or_default();
+ let commit_options = commit_options.unwrap_or_default();
+
+ let (signing_identity, secret_key) = get_test_signing_identity(cipher_suite, b"member").await;
+
+ // `signing_identity` is not used again, so it is moved into the builder
+ // rather than cloned.
+ let group = TestClientBuilder::new_for_test()
+ .leaf_node_extensions(leaf_extensions)
+ .mls_rules(DefaultMlsRules::default().with_commit_options(commit_options))
+ .extension_types(extension_types)
+ .protocol_versions(ProtocolVersion::all())
+ .used_protocol_version(protocol_version)
+ .signing_identity(signing_identity, secret_key, cipher_suite)
+ .build()
+ .create_group_with_id(TEST_GROUP.to_vec(), group_extensions())
+ .await
+ .unwrap();
+
+ TestGroup { group }
+}
+
+/// Creates a single-member [`TestGroup`] with no extra extension types, leaf
+/// extensions or commit options.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn test_group(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+) -> TestGroup {
+ let extension_types: Vec<ExtensionType> = Vec::new();
+
+ test_group_custom(protocol_version, cipher_suite, extension_types, None, None).await
+}
+
+/// Creates a single-member [`TestGroup`] with group id `TEST_GROUP`, after
+/// `custom` has adjusted a test client builder preset with `protocol_version`.
+///
+/// The signing identity (derived from `b"member"`) is set after `custom` runs,
+/// so it overrides any identity `custom` configures. Panics if the group cannot
+/// be created.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn test_group_custom_config<F>(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ custom: F,
+) -> TestGroup
+where
+ F: FnOnce(TestClientBuilder) -> TestClientBuilder,
+{
+ let (signing_identity, secret_key) = get_test_signing_identity(cipher_suite, b"member").await;
+
+ let client_builder = TestClientBuilder::new_for_test().used_protocol_version(protocol_version);
+
+ // `signing_identity` is not used again; move it instead of cloning.
+ let group = custom(client_builder)
+ .signing_identity(signing_identity, secret_key, cipher_suite)
+ .build()
+ .create_group_with_id(TEST_GROUP.to_vec(), group_extensions())
+ .await
+ .unwrap();
+
+ TestGroup { group }
+}
+
+/// Builds a group with `num_members` members and returns each member's view of
+/// it. Index 0 is the creator, and it adds every other member (`"name {i}"`)
+/// through [`TestGroup::join`]. Each add commit is processed by the existing
+/// members other than index 0, which has already applied it.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn test_n_member_group(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ num_members: usize,
+) -> Vec<TestGroup> {
+ let mut groups = vec![test_group(protocol_version, cipher_suite).await];
+
+ for i in 1..num_members {
+ let name = format!("name {i}");
+ let (joined, commit) = groups[0].join(&name).await;
+
+ process_commit(&mut groups, commit, 0).await;
+ groups.push(joined);
+ }
+
+ groups
+}
+
+/// Makes every group in `groups` whose own member index differs from `excluded`
+/// process `commit`; panics if any of them fails to process it.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn process_commit(groups: &mut [TestGroup], commit: MlsMessage, excluded: u32) {
+ for member in groups.iter_mut() {
+ if member.group.current_member_index() == excluded {
+ continue;
+ }
+
+ member.process_message(commit.clone()).await.unwrap();
+ }
+}
+
+/// Returns a 32-byte HPKE public key whose bytes all equal `key_byte`.
+pub(crate) fn get_test_25519_key(key_byte: u8) -> HpkePublicKey {
+ let bytes: Vec<u8> = core::iter::repeat(key_byte).take(32).collect();
+ bytes.into()
+}
+
+/// Creates `n` clients (`"member{i}"`) that all support extension type 999 and
+/// use `leaf_extensions`, then builds group `b"TEST GROUP"` with `extensions`.
+///
+/// Client 0 creates the group and adds the others one at a time. After each add
+/// it applies the commit, the previously added members process it, and the new
+/// member joins from the welcome message. Returns every member's group in join
+/// order. Panics if `n == 0` or any step fails.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn get_test_groups_with_features(
+ n: usize,
+ extensions: ExtensionList,
+ leaf_extensions: ExtensionList,
+) -> Vec<Group<TestClientConfig>> {
+ let mut clients = Vec::new();
+
+ for index in 0..n {
+ let name = format!("member{index}");
+ let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, name.as_bytes()).await;
+
+ let client = TestClientBuilder::new_for_test()
+ .extension_type(999.into())
+ .leaf_node_extensions(leaf_extensions.clone())
+ .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
+ .build();
+
+ clients.push(client);
+ }
+
+ let creator_group = clients[0]
+ .create_group_with_id(b"TEST GROUP".to_vec(), extensions)
+ .await
+ .unwrap();
+
+ let mut groups = vec![creator_group];
+
+ for joiner in &clients[1..] {
+ let key_package = joiner.generate_key_package_message().await.unwrap();
+
+ let CommitOutput {
+ commit_message,
+ welcome_messages,
+ ..
+ } = groups[0]
+ .commit_builder()
+ .add_member(key_package)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ groups[0].apply_pending_commit().await.unwrap();
+
+ // Index 0 has applied its own commit; the other members process it.
+ for existing in groups.iter_mut().skip(1) {
+ existing
+ .process_incoming_message(commit_message.clone())
+ .await
+ .unwrap();
+ }
+
+ let (joined, _) = joiner
+ .join_group(None, &welcome_messages[0])
+ .await
+ .unwrap();
+
+ groups.push(joined);
+ }
+
+ groups
+}
+
+/// Returns `count` bytes drawn from `rand::thread_rng()`.
+pub fn random_bytes(count: usize) -> Vec<u8> {
+ let mut rng = rand::thread_rng();
+ let mut out = vec![0u8; count];
+ rng.fill_bytes(&mut out);
+ out
+}
+
+/// Test wrapper around a [`Group`]. Its `MessageProcessor::update_key_schedule`
+/// only records the secrets and provisional state it receives in the public
+/// fields below, instead of forwarding them to the wrapped group.
+pub(crate) struct GroupWithoutKeySchedule {
+ /// Wrapped group that the other `MessageProcessor` methods delegate to.
+ inner: Group<TestClientConfig>,
+ /// Secrets passed to the most recent `update_key_schedule` call.
+ pub secrets: Option<(TreeKemPrivate, PathSecret)>,
+ /// Provisional state passed to the most recent `update_key_schedule` call.
+ pub provisional_public_state: Option<ProvisionalState>,
+}
+
+// Exposes the wrapped group's API directly on the test wrapper.
+impl Deref for GroupWithoutKeySchedule {
+ type Target = Group<TestClientConfig>;
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn deref(&self) -> &Self::Target {
+ &self.inner
+ }
+}
+
+// Mutable access to the wrapped group's API.
+impl DerefMut for GroupWithoutKeySchedule {
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.inner
+ }
+}
+
+#[cfg(feature = "rfc_compliant")]
+impl GroupWithoutKeySchedule {
+ /// Wraps a new single-member test group for cipher suite `cs`. No secrets
+ /// or provisional state are recorded yet.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn new(cs: CipherSuite) -> Self {
+ let TestGroup { group } = test_group(TEST_PROTOCOL_VERSION, cs).await;
+
+ Self {
+ inner: group,
+ secrets: None,
+ provisional_public_state: None,
+ }
+ }
+}
+
+// Forwards every `MessageProcessor` method to the wrapped group except
+// `update_key_schedule`, which only stores its inputs on `self`.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
+#[cfg_attr(
+ all(not(target_arch = "wasm32"), mls_build_async),
+ maybe_async::must_be_async
+)]
+impl MessageProcessor for GroupWithoutKeySchedule {
+ type CipherSuiteProvider = <Group<TestClientConfig> as MessageProcessor>::CipherSuiteProvider;
+ type OutputType = <Group<TestClientConfig> as MessageProcessor>::OutputType;
+ type PreSharedKeyStorage = <Group<TestClientConfig> as MessageProcessor>::PreSharedKeyStorage;
+ type IdentityProvider = <Group<TestClientConfig> as MessageProcessor>::IdentityProvider;
+ type MlsRules = <Group<TestClientConfig> as MessageProcessor>::MlsRules;
+
+ fn group_state(&self) -> &GroupState {
+ self.inner.group_state()
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn group_state_mut(&mut self) -> &mut GroupState {
+ self.inner.group_state_mut()
+ }
+
+ fn mls_rules(&self) -> Self::MlsRules {
+ self.inner.mls_rules()
+ }
+
+ fn identity_provider(&self) -> Self::IdentityProvider {
+ self.inner.identity_provider()
+ }
+
+ fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider {
+ self.inner.cipher_suite_provider()
+ }
+
+ fn psk_storage(&self) -> Self::PreSharedKeyStorage {
+ self.inner.psk_storage()
+ }
+
+ fn can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool {
+ self.inner.can_continue_processing(provisional_state)
+ }
+
+ #[cfg(feature = "private_message")]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn min_epoch_available(&self) -> Option<u64> {
+ self.inner.min_epoch_available()
+ }
+
+ async fn apply_update_path(
+ &mut self,
+ sender: LeafIndex,
+ update_path: &ValidatedUpdatePath,
+ provisional_state: &mut ProvisionalState,
+ ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
+ self.inner
+ .apply_update_path(sender, update_path, provisional_state)
+ .await
+ }
+
+ #[cfg(feature = "private_message")]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn process_ciphertext(
+ &mut self,
+ cipher_text: &PrivateMessage,
+ ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+ self.inner.process_ciphertext(cipher_text).await
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn verify_plaintext_authentication(
+ &self,
+ message: PublicMessage,
+ ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+ self.inner.verify_plaintext_authentication(message).await
+ }
+
+ // Deliberately does not touch the wrapped group: the secrets and provisional
+ // state are stored so tests can inspect them. The transcript hash and
+ // confirmation tag are ignored.
+ async fn update_key_schedule(
+ &mut self,
+ secrets: Option<(TreeKemPrivate, PathSecret)>,
+ _interim_transcript_hash: InterimTranscriptHash,
+ _confirmation_tag: &ConfirmationTag,
+ provisional_public_state: ProvisionalState,
+ ) -> Result<(), MlsError> {
+ self.provisional_public_state = Some(provisional_public_state);
+ self.secrets = secrets;
+ Ok(())
+ }
+
+ // Fully qualified call, so this resolves to the trait method of the
+ // wrapped group rather than any inherent method with the same name.
+ #[cfg(feature = "private_message")]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn self_index(&self) -> Option<LeafIndex> {
+ <Group<TestClientConfig> as MessageProcessor>::self_index(&self.inner)
+ }
+}
diff --git a/src/group/transcript_hash.rs b/src/group/transcript_hash.rs
new file mode 100644
index 0000000..c336dfa
--- /dev/null
+++ b/src/group/transcript_hash.rs
@@ -0,0 +1,293 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{crypto::CipherSuiteProvider, error::IntoAnyError};
+
+use crate::{
+ client::MlsError,
+ group::{framing::FramedContent, MessageSignature},
+ WireFormat,
+};
+
+use super::{AuthenticatedContent, ConfirmationTag};
+
+/// Confirmed transcript hash of a group, as produced by
+/// [`ConfirmedTranscriptHash::create`]. Wraps the raw hash bytes, which are
+/// MLS-encoded via `mls_rs_codec::byte_vec`.
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct ConfirmedTranscriptHash(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ Vec<u8>,
+);
+
+// Debug output uses `pretty_bytes` labelled with the type name.
+impl Debug for ConfirmedTranscriptHash {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("ConfirmedTranscriptHash")
+ .fmt(f)
+ }
+}
+
+// Read-only access to the raw hash bytes.
+impl Deref for ConfirmedTranscriptHash {
+ type Target = Vec<u8>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+// Wraps raw hash bytes without validation.
+impl From<Vec<u8>> for ConfirmedTranscriptHash {
+ fn from(value: Vec<u8>) -> Self {
+ Self(value)
+ }
+}
+
+impl ConfirmedTranscriptHash {
+ /// Computes `Hash(interim_transcript_hash || ConfirmedTranscriptHashInput)`,
+ /// where the input is the MLS encoding of `content`'s wire format, framed
+ /// content and signature.
+ ///
+ /// # Errors
+ /// Encoding errors, or [`MlsError::CryptoProviderError`] if hashing fails.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn create<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ interim_transcript_hash: &InterimTranscriptHash,
+ content: &AuthenticatedContent,
+ ) -> Result<Self, MlsError> {
+ #[derive(Debug, MlsSize, MlsEncode)]
+ struct ConfirmedTranscriptHashInput<'a> {
+ wire_format: WireFormat,
+ content: &'a FramedContent,
+ signature: &'a MessageSignature,
+ }
+
+ let encoded_input = ConfirmedTranscriptHashInput {
+ wire_format: content.wire_format,
+ content: &content.content,
+ signature: &content.auth.signature,
+ }
+ .mls_encode_to_vec()?;
+
+ let mut hash_input =
+ Vec::with_capacity(interim_transcript_hash.len() + encoded_input.len());
+ hash_input.extend_from_slice(interim_transcript_hash);
+ hash_input.extend_from_slice(&encoded_input);
+
+ let digest = cipher_suite_provider
+ .hash(&hash_input)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ Ok(digest.into())
+ }
+}
+
+/// Interim transcript hash of a group, as produced by
+/// [`InterimTranscriptHash::create`]. Wraps the raw hash bytes, which are
+/// MLS-encoded via `mls_rs_codec::byte_vec`.
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct InterimTranscriptHash(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ Vec<u8>,
+);
+
+// Debug output uses `pretty_bytes` labelled with the type name.
+impl Debug for InterimTranscriptHash {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("InterimTranscriptHash")
+ .fmt(f)
+ }
+}
+
+// Read-only access to the raw hash bytes.
+impl Deref for InterimTranscriptHash {
+ type Target = Vec<u8>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+// Wraps raw hash bytes without validation.
+impl From<Vec<u8>> for InterimTranscriptHash {
+ fn from(value: Vec<u8>) -> Self {
+ Self(value)
+ }
+}
+
+impl InterimTranscriptHash {
+ /// Computes `Hash(confirmed || InterimTranscriptHashInput)`, where the input
+ /// is the MLS encoding of a struct wrapping `confirmation_tag`.
+ ///
+ /// # Errors
+ /// Encoding errors, or [`MlsError::CryptoProviderError`] if hashing fails.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn create<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ confirmed: &ConfirmedTranscriptHash,
+ confirmation_tag: &ConfirmationTag,
+ ) -> Result<Self, MlsError> {
+ #[derive(Debug, MlsSize, MlsEncode)]
+ struct InterimTranscriptHashInput<'a> {
+ confirmation_tag: &'a ConfirmationTag,
+ }
+
+ let encoded_tag = InterimTranscriptHashInput { confirmation_tag }.mls_encode_to_vec()?;
+
+ let mut hash_input = confirmed.0.clone();
+ hash_input.extend_from_slice(&encoded_tag);
+
+ let digest = cipher_suite_provider
+ .hash(&hash_input)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ Ok(digest.into())
+ }
+}
+
+// Test vectors come from the MLS interop repository and contain a proposal by reference.
+#[cfg(feature = "by_ref_proposal")]
+#[cfg(test)]
+mod tests {
+ use alloc::vec::Vec;
+
+ use mls_rs_codec::MlsDecode;
+
+ use crate::{
+ crypto::test_utils::try_test_cipher_suite_provider,
+ group::{framing::ContentType, message_signature::AuthenticatedContent, transcript_hashes},
+ };
+
+ #[cfg(not(mls_build_async))]
+ use alloc::{boxed::Box, vec};
+
+ #[cfg(not(mls_build_async))]
+ use crate::{
+ crypto::test_utils::test_cipher_suite_provider,
+ group::{
+ confirmation_tag::ConfirmationTag,
+ framing::Content,
+ proposal::{Proposal, ProposalOrRef, RemoveProposal},
+ test_utils::get_test_group_context,
+ Commit, LeafIndex, Sender,
+ },
+ mls_rs_codec::MlsEncode,
+ CipherSuite, CipherSuiteProvider, WireFormat,
+ };
+
+ #[cfg(not(mls_build_async))]
+ use super::{ConfirmedTranscriptHash, InterimTranscriptHash};
+
+ /// One transcript-hash test vector; byte fields are hex-encoded in JSON.
+ #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+ struct TestCase {
+ pub cipher_suite: u16,
+
+ #[serde(with = "hex::serde")]
+ pub confirmation_key: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub authenticated_content: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub interim_transcript_hash_before: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ pub confirmed_transcript_hash_after: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub interim_transcript_hash_after: Vec<u8>,
+ }
+
+ // For each vector whose cipher suite has a test provider: the stored
+ // confirmation tag must verify, and `transcript_hashes` must reproduce both
+ // expected hashes.
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn transcript_hash() {
+ let test_cases: Vec<TestCase> =
+ load_test_case_json!(interop_transcript_hashes, generate_test_vector());
+
+ for case in test_cases {
+ // Skip suites for which `try_test_cipher_suite_provider` returns `None`.
+ let Some(cs) = try_test_cipher_suite_provider(case.cipher_suite) else {
+ continue;
+ };
+
+ let auth_content =
+ AuthenticatedContent::mls_decode(&mut &*case.authenticated_content).unwrap();
+
+ assert!(auth_content.content.content_type() == ContentType::Commit);
+
+ let conf_hash_after = case.confirmed_transcript_hash_after.into();
+ let conf_tag = auth_content.auth.confirmation_tag.clone().unwrap();
+
+ let tag_matches = conf_tag
+ .matches(&case.confirmation_key, &conf_hash_after, &cs)
+ .await
+ .unwrap();
+
+ assert!(tag_matches);
+
+ let (computed_interim, computed_confirmed) = transcript_hashes(
+ &cs,
+ &case.interim_transcript_hash_before.into(),
+ &auth_content,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(*computed_interim, case.interim_transcript_hash_after);
+ assert_eq!(computed_confirmed, conf_hash_after);
+ }
+ }
+
+ // Builds one vector per cipher suite from a signed commit that carries a
+ // single inline remove proposal, using random transcript and confirmation
+ // inputs.
+ #[cfg(not(mls_build_async))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_test_vector() -> Vec<TestCase> {
+ CipherSuite::all()
+ .map(|suite| {
+ let cs = test_cipher_suite_provider(suite);
+ let context = get_test_group_context(0x3456, cs.cipher_suite());
+
+ let remove = Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(1),
+ });
+
+ let commit = Commit {
+ proposals: vec![ProposalOrRef::Proposal(Box::new(remove))],
+ path: None,
+ };
+
+ let signer = cs.signature_key_generate().unwrap().0;
+
+ let mut auth_content = AuthenticatedContent::new_signed(
+ &cs,
+ &context,
+ Sender::Member(0),
+ Content::Commit(Box::new(commit)),
+ &signer,
+ WireFormat::PublicMessage,
+ vec![],
+ )
+ .unwrap();
+
+ let interim_hash_before: InterimTranscriptHash =
+ cs.random_bytes_vec(cs.kdf_extract_size()).unwrap().into();
+
+ let conf_hash_after =
+ ConfirmedTranscriptHash::create(&cs, &interim_hash_before, &auth_content).unwrap();
+
+ let conf_key = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();
+ let conf_tag = ConfirmationTag::create(&conf_key, &conf_hash_after, &cs).unwrap();
+
+ let interim_hash_after =
+ InterimTranscriptHash::create(&cs, &conf_hash_after, &conf_tag).unwrap();
+
+ auth_content.auth.confirmation_tag = Some(conf_tag);
+
+ TestCase {
+ cipher_suite: cs.cipher_suite().into(),
+ confirmation_key: conf_key,
+ authenticated_content: auth_content.mls_encode_to_vec().unwrap(),
+ interim_transcript_hash_before: interim_hash_before.0,
+ confirmed_transcript_hash_after: conf_hash_after.0,
+ interim_transcript_hash_after: interim_hash_after.0,
+ }
+ })
+ .collect()
+ }
+
+ #[cfg(mls_build_async)]
+ fn generate_test_vector() -> Vec<TestCase> {
+ panic!("Tests cannot be generated in async mode");
+ }
+}
diff --git a/src/group/util.rs b/src/group/util.rs
new file mode 100644
index 0000000..dadfafa
--- /dev/null
+++ b/src/group/util.rs
@@ -0,0 +1,202 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs_core::{
+ error::IntoAnyError, identity::IdentityProvider, key_package::KeyPackageStorage,
+};
+
+use crate::{
+ cipher_suite::CipherSuite,
+ client::MlsError,
+ extension::RatchetTreeExt,
+ key_package::KeyPackageGeneration,
+ protocol_version::ProtocolVersion,
+ signer::Signable,
+ tree_kem::{node::LeafIndex, tree_validator::TreeValidator, TreeKemPublic},
+ CipherSuiteProvider, CryptoProvider,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::extension::ExternalSendersExt;
+
+use super::{
+ framing::Sender, message_signature::AuthenticatedContent,
+ transcript_hash::InterimTranscriptHash, ConfirmedTranscriptHash, EncryptedGroupSecrets,
+ ExportedTree, GroupInfo, GroupState,
+};
+
+use super::message_processor::ProvisionalState;
+
+/// Runs the checks shared by members and joiners on a received `GroupInfo`.
+///
+/// The checks are performed in this order:
+/// 1. The `GroupInfo` context's protocol version must equal `msg_version`.
+/// 2. Its cipher suite must equal that of `cs`.
+/// 3. The signature must verify against the signing key of the leaf at `group_info.signer`
+///    in `tree`.
+///
+/// # Errors
+///
+/// - Returns `MlsError::ProtocolVersionMismatch` or `MlsError::CipherSuiteMismatch` when the
+///   context does not match.
+/// - Propagates errors from the leaf lookup and from signature verification.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn validate_group_info_common<C: CipherSuiteProvider>(
+    msg_version: ProtocolVersion,
+    group_info: &GroupInfo,
+    tree: &TreeKemPublic,
+    cs: &C,
+) -> Result<(), MlsError> {
+    let context = &group_info.group_context;
+
+    if msg_version != context.protocol_version {
+        return Err(MlsError::ProtocolVersionMismatch);
+    }
+
+    if context.cipher_suite != cs.cipher_suite() {
+        return Err(MlsError::CipherSuiteMismatch);
+    }
+
+    // Look up the signer's leaf in `tree`; a failed lookup is propagated as an error.
+    let signer_leaf = tree.get_leaf_node(group_info.signer)?;
+    let signer_key = &signer_leaf.signing_identity.signature_key;
+
+    group_info.verify(cs, signer_key, &()).await?;
+
+    Ok(())
+}
+
+/// Validates a `GroupInfo` received by an existing member against that member's own state.
+///
+/// First runs `validate_group_info_common` against the member's public tree. It then requires
+/// the following to equal the local copies in `self_state`:
+/// - any embedded ratchet tree;
+/// - the group context;
+/// - the confirmation tag.
+///
+/// # Errors
+///
+/// - Returns `MlsError::InvalidGroupInfo` on any mismatch.
+/// - Propagates errors from the common checks and from reading the ratchet tree extension.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn validate_group_info_member<C: CipherSuiteProvider>(
+    self_state: &GroupState,
+    msg_version: ProtocolVersion,
+    group_info: &GroupInfo,
+    cs: &C,
+) -> Result<(), MlsError> {
+    let local_tree = &self_state.public_tree;
+
+    validate_group_info_common(msg_version, group_info, local_tree, cs).await?;
+
+    let local_nodes = ExportedTree::new_borrowed(&local_tree.nodes);
+
+    // An embedded ratchet tree is optional, but when present it must match ours exactly.
+    if let Some(ext) = group_info.extensions.get_as::<RatchetTreeExt>()? {
+        if ext.tree_data != local_nodes {
+            return Err(MlsError::InvalidGroupInfo);
+        }
+    }
+
+    if group_info.group_context != self_state.context
+        || group_info.confirmation_tag != self_state.confirmation_tag
+    {
+        return Err(MlsError::InvalidGroupInfo);
+    }
+
+    Ok(())
+}
+
+/// Validates a `GroupInfo` for a new joiner and returns the imported, validated ratchet tree.
+///
+/// The tree is taken from the `RatchetTreeExt` extension when present, otherwise from the
+/// `tree` argument. The tree is then processed in this order:
+/// 1. Imported and checked with `TreeValidator`.
+/// 2. When the `by_ref_proposal` feature is enabled, any external senders listed in the group
+///    context extensions are verified.
+/// 3. The common version / cipher-suite / signature checks are run against the imported tree.
+///
+/// # Errors
+///
+/// - Returns `MlsError::RatchetTreeNotFound` when neither source provides a tree.
+/// - Propagates errors from tree import, tree validation, external-sender verification (as
+///   `MlsError::IdentityProviderError`) and the common checks.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn validate_group_info_joiner<C, I>(
+    msg_version: ProtocolVersion,
+    group_info: &GroupInfo,
+    tree: Option<ExportedTree<'_>>,
+    id_provider: &I,
+    cs: &C,
+) -> Result<TreeKemPublic, MlsError>
+where
+    C: CipherSuiteProvider,
+    I: IdentityProvider,
+{
+    let context = &group_info.group_context;
+
+    // A ratchet tree embedded in the GroupInfo takes precedence over the `tree` argument.
+    let exported = if let Some(ext) = group_info.extensions.get_as::<RatchetTreeExt>()? {
+        ext.tree_data
+    } else {
+        tree.ok_or(MlsError::RatchetTreeNotFound)?
+    };
+
+    let mut public_tree =
+        TreeKemPublic::import_node_data(exported.into(), id_provider, &context.extensions).await?;
+
+    // Verify the integrity of the ratchet tree.
+    TreeValidator::new(cs, context, id_provider)
+        .validate(&mut public_tree)
+        .await?;
+
+    #[cfg(feature = "by_ref_proposal")]
+    if let Some(ext_senders) = context.extensions.get_as::<ExternalSendersExt>()? {
+        // TODO: should joiners verify the group against the current time? `None` is passed
+        // here (presumably the timestamp argument — confirm).
+        ext_senders
+            .verify_all(id_provider, None, &context.extensions)
+            .await
+            .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+    }
+
+    validate_group_info_common(msg_version, group_info, &public_tree, cs).await?;
+
+    Ok(public_tree)
+}
+
+/// Resolves the leaf index credited with a commit sent by `sender`.
+///
+/// - Member senders map to their own leaf index.
+/// - New-member commits use `provisional_state.external_init_index`.
+/// - External senders and new-member proposals are rejected. These variants only exist with
+///   the `by_ref_proposal` feature.
+///
+/// # Errors
+///
+/// - `MlsError::ExternalSenderCannotCommit` for an external sender.
+/// - `MlsError::ExpectedAddProposalForNewMemberProposal` for a new-member proposal sender.
+/// - `MlsError::ExternalCommitMissingExternalInit` when a new-member commit has no external
+///   init index.
+pub(crate) fn commit_sender(
+    sender: &Sender,
+    provisional_state: &ProvisionalState,
+) -> Result<LeafIndex, MlsError> {
+    match *sender {
+        Sender::Member(member) => Ok(LeafIndex(member)),
+        Sender::NewMemberCommit => provisional_state
+            .external_init_index
+            .ok_or(MlsError::ExternalCommitMissingExternalInit),
+        #[cfg(feature = "by_ref_proposal")]
+        Sender::External(_) => Err(MlsError::ExternalSenderCannotCommit),
+        #[cfg(feature = "by_ref_proposal")]
+        Sender::NewMemberProposal => Err(MlsError::ExpectedAddProposalForNewMemberProposal),
+    }
+}
+
+/// Computes the transcript hashes that follow `content`.
+///
+/// Returns `(interim, confirmed)`:
+/// - `confirmed` is derived from `prev_interim_transcript_hash` and `content`;
+/// - `interim` is derived from `confirmed` and the content's confirmation tag.
+///
+/// # Errors
+///
+/// - Returns `MlsError::InvalidConfirmationTag` if `content` carries no confirmation tag. This
+///   is checked only after the confirmed hash has been computed.
+/// - Propagates errors from either hash computation.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(super) async fn transcript_hashes<P: CipherSuiteProvider>(
+    cipher_suite_provider: &P,
+    prev_interim_transcript_hash: &InterimTranscriptHash,
+    content: &AuthenticatedContent,
+) -> Result<(InterimTranscriptHash, ConfirmedTranscriptHash), MlsError> {
+    let confirmed = ConfirmedTranscriptHash::create(
+        cipher_suite_provider,
+        prev_interim_transcript_hash,
+        content,
+    )
+    .await?;
+
+    let tag = match &content.auth.confirmation_tag {
+        Some(tag) => tag,
+        None => return Err(MlsError::InvalidConfirmationTag),
+    };
+
+    let interim = InterimTranscriptHash::create(cipher_suite_provider, &confirmed, tag).await?;
+
+    Ok((interim, confirmed))
+}
+
+/// Returns the first entry of `secrets` whose `new_member` reference has stored key package
+/// data in `key_package_repo`, paired with that data decoded as a `KeyPackageGeneration`.
+///
+/// # Errors
+///
+/// - Repository failures are returned as `MlsError::KeyPackageRepoError`.
+/// - Decoding failures from `KeyPackageGeneration::from_storage` are propagated.
+/// - `MlsError::WelcomeKeyPackageNotFound` is returned when no entry matches.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn find_key_package_generation<'a, K: KeyPackageStorage>(
+    key_package_repo: &K,
+    secrets: &'a [EncryptedGroupSecrets],
+) -> Result<(&'a EncryptedGroupSecrets, KeyPackageGeneration), MlsError> {
+    for secret in secrets {
+        let stored = key_package_repo
+            .get(&secret.new_member)
+            .await
+            .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;
+
+        // Entries with no stored data are skipped; the first hit is decoded and returned.
+        if let Some(data) = stored {
+            let generation = KeyPackageGeneration::from_storage(secret.new_member.to_vec(), data)?;
+            return Ok((secret, generation));
+        }
+    }
+
+    Err(MlsError::WelcomeKeyPackageNotFound)
+}
+
+/// Asks `crypto` for a provider implementing `cipher_suite`.
+///
+/// # Errors
+///
+/// Returns `MlsError::UnsupportedCipherSuite(cipher_suite)` when `crypto` offers no provider
+/// for that suite.
+pub(crate) fn cipher_suite_provider<P>(
+    crypto: P,
+    cipher_suite: CipherSuite,
+) -> Result<P::CipherSuiteProvider, MlsError>
+where
+    P: CryptoProvider,
+{
+    match crypto.cipher_suite_provider(cipher_suite) {
+        Some(provider) => Ok(provider),
+        None => Err(MlsError::UnsupportedCipherSuite(cipher_suite)),
+    }
+}