| // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| // Copyright by contributors to this project. |
| // SPDX-License-Identifier: (Apache-2.0 OR MIT) |
| |
| use crate::{ |
| client::MlsError, |
| client_config::ClientConfig, |
| group::{ |
| key_schedule::KeySchedule, CommitGeneration, ConfirmationTag, Group, GroupContext, |
| GroupState, InterimTranscriptHash, ReInitProposal, TreeKemPublic, |
| }, |
| tree_kem::TreeKemPrivate, |
| }; |
| |
| #[cfg(feature = "by_ref_proposal")] |
| use crate::{ |
| crypto::{HpkePublicKey, HpkeSecretKey}, |
| group::ProposalRef, |
| }; |
| |
| #[cfg(feature = "by_ref_proposal")] |
| use super::proposal_cache::{CachedProposal, ProposalCache}; |
| |
| use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; |
| |
| use mls_rs_core::crypto::SignatureSecretKey; |
| #[cfg(feature = "tree_index")] |
| use mls_rs_core::identity::IdentityProvider; |
| |
| #[cfg(all(feature = "std", feature = "by_ref_proposal"))] |
| use std::collections::HashMap; |
| |
| #[cfg(all(feature = "by_ref_proposal", not(feature = "std")))] |
| use alloc::vec::Vec; |
| |
| use super::{cipher_suite_provider, epoch::EpochSecrets, state_repo::GroupStateRepository}; |
| |
| #[derive(Debug, PartialEq, Clone, MlsEncode, MlsDecode, MlsSize)] |
| #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] |
| pub(crate) struct Snapshot { |
| version: u16, |
| pub(crate) state: RawGroupState, |
| private_tree: TreeKemPrivate, |
| epoch_secrets: EpochSecrets, |
| key_schedule: KeySchedule, |
| #[cfg(all(feature = "std", feature = "by_ref_proposal"))] |
| pending_updates: HashMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>, |
| #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))] |
| pending_updates: Vec<(HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>))>, |
| pending_commit: Option<CommitGeneration>, |
| signer: SignatureSecretKey, |
| } |
| |
| #[derive(Debug, MlsEncode, MlsDecode, MlsSize, PartialEq, Clone)] |
| #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] |
| pub(crate) struct RawGroupState { |
| pub(crate) context: GroupContext, |
| #[cfg(all(feature = "std", feature = "by_ref_proposal"))] |
| pub(crate) proposals: HashMap<ProposalRef, CachedProposal>, |
| #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))] |
| pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>, |
| pub(crate) public_tree: TreeKemPublic, |
| pub(crate) interim_transcript_hash: InterimTranscriptHash, |
| pub(crate) pending_reinit: Option<ReInitProposal>, |
| pub(crate) confirmation_tag: ConfirmationTag, |
| } |
| |
| impl RawGroupState { |
| pub(crate) fn export(state: &GroupState) -> Self { |
| #[cfg(feature = "tree_index")] |
| let public_tree = state.public_tree.clone(); |
| |
| #[cfg(not(feature = "tree_index"))] |
| let public_tree = { |
| let mut tree = TreeKemPublic::new(); |
| tree.nodes = state.public_tree.nodes.clone(); |
| tree |
| }; |
| |
| Self { |
| context: state.context.clone(), |
| #[cfg(feature = "by_ref_proposal")] |
| proposals: state.proposals.proposals.clone(), |
| public_tree, |
| interim_transcript_hash: state.interim_transcript_hash.clone(), |
| pending_reinit: state.pending_reinit.clone(), |
| confirmation_tag: state.confirmation_tag.clone(), |
| } |
| } |
| |
| #[cfg(feature = "tree_index")] |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub(crate) async fn import<C>(self, identity_provider: &C) -> Result<GroupState, MlsError> |
| where |
| C: IdentityProvider, |
| { |
| let context = self.context; |
| |
| #[cfg(feature = "by_ref_proposal")] |
| let proposals = ProposalCache::import( |
| context.protocol_version, |
| context.group_id.clone(), |
| self.proposals, |
| ); |
| |
| let mut public_tree = self.public_tree; |
| |
| public_tree |
| .initialize_index_if_necessary(identity_provider, &context.extensions) |
| .await?; |
| |
| Ok(GroupState { |
| #[cfg(feature = "by_ref_proposal")] |
| proposals, |
| context, |
| public_tree, |
| interim_transcript_hash: self.interim_transcript_hash, |
| pending_reinit: self.pending_reinit, |
| confirmation_tag: self.confirmation_tag, |
| }) |
| } |
| |
| #[cfg(not(feature = "tree_index"))] |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub(crate) async fn import(self) -> Result<GroupState, MlsError> { |
| let context = self.context; |
| |
| #[cfg(feature = "by_ref_proposal")] |
| let proposals = ProposalCache::import( |
| context.protocol_version, |
| context.group_id.clone(), |
| self.proposals, |
| ); |
| |
| Ok(GroupState { |
| #[cfg(feature = "by_ref_proposal")] |
| proposals, |
| context, |
| public_tree: self.public_tree, |
| interim_transcript_hash: self.interim_transcript_hash, |
| pending_reinit: self.pending_reinit, |
| confirmation_tag: self.confirmation_tag, |
| }) |
| } |
| } |
| |
| impl<C> Group<C> |
| where |
| C: ClientConfig + Clone, |
| { |
| /// Write the current state of the group to the |
| /// [`GroupStorageProvider`](crate::GroupStateStorage) |
| /// that is currently in use by the group. |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub async fn write_to_storage(&mut self) -> Result<(), MlsError> { |
| self.state_repo.write_to_storage(self.snapshot()).await |
| } |
| |
| pub(crate) fn snapshot(&self) -> Snapshot { |
| Snapshot { |
| state: RawGroupState::export(&self.state), |
| private_tree: self.private_tree.clone(), |
| key_schedule: self.key_schedule.clone(), |
| #[cfg(feature = "by_ref_proposal")] |
| pending_updates: self.pending_updates.clone(), |
| pending_commit: self.pending_commit.clone(), |
| epoch_secrets: self.epoch_secrets.clone(), |
| version: 1, |
| signer: self.signer.clone(), |
| } |
| } |
| |
| #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] |
| pub(crate) async fn from_snapshot(config: C, snapshot: Snapshot) -> Result<Self, MlsError> { |
| let cipher_suite_provider = cipher_suite_provider( |
| config.crypto_provider(), |
| snapshot.state.context.cipher_suite, |
| )?; |
| |
| #[cfg(feature = "tree_index")] |
| let identity_provider = config.identity_provider(); |
| |
| let state_repo = GroupStateRepository::new( |
| #[cfg(feature = "prior_epoch")] |
| snapshot.state.context.group_id.clone(), |
| config.group_state_storage(), |
| config.key_package_repo(), |
| None, |
| )?; |
| |
| Ok(Group { |
| config, |
| state: snapshot |
| .state |
| .import( |
| #[cfg(feature = "tree_index")] |
| &identity_provider, |
| ) |
| .await?, |
| private_tree: snapshot.private_tree, |
| key_schedule: snapshot.key_schedule, |
| #[cfg(feature = "by_ref_proposal")] |
| pending_updates: snapshot.pending_updates, |
| pending_commit: snapshot.pending_commit, |
| #[cfg(test)] |
| commit_modifiers: Default::default(), |
| epoch_secrets: snapshot.epoch_secrets, |
| state_repo, |
| cipher_suite_provider, |
| #[cfg(feature = "psk")] |
| previous_psk: None, |
| signer: snapshot.signer, |
| }) |
| } |
| } |
| |
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec;

    use crate::{
        cipher_suite::CipherSuite,
        crypto::test_utils::test_cipher_suite_provider,
        group::{
            confirmation_tag::ConfirmationTag, epoch::test_utils::get_test_epoch_secrets,
            key_schedule::test_utils::get_test_key_schedule, test_utils::get_test_group_context,
            transcript_hash::InterimTranscriptHash,
        },
        tree_kem::{node::LeafIndex, TreeKemPrivate},
    };

    use super::{RawGroupState, Snapshot};

    /// Assembles a [`Snapshot`] for tests from the test helpers of the
    /// individual components.
    ///
    /// The group context is built for `epoch_id` and `cipher_suite`; the
    /// tree, proposals and pending updates are defaults, the interim
    /// transcript hash and signer are empty, and no commit or re-init is
    /// pending.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn get_test_snapshot(cipher_suite: CipherSuite, epoch_id: u64) -> Snapshot {
        let context = get_test_group_context(epoch_id, cipher_suite).await;

        let cs_provider = test_cipher_suite_provider(cipher_suite);
        let confirmation_tag = ConfirmationTag::empty(&cs_provider).await;

        let state = RawGroupState {
            context,
            #[cfg(feature = "by_ref_proposal")]
            proposals: Default::default(),
            public_tree: Default::default(),
            interim_transcript_hash: InterimTranscriptHash::from(vec![]),
            pending_reinit: None,
            confirmation_tag,
        };

        Snapshot {
            version: 1,
            state,
            private_tree: TreeKemPrivate::new(LeafIndex(0)),
            epoch_secrets: get_test_epoch_secrets(cipher_suite),
            key_schedule: get_test_key_schedule(cipher_suite),
            #[cfg(feature = "by_ref_proposal")]
            pending_updates: Default::default(),
            pending_commit: None,
            signer: vec![].into(),
        }
    }
}
| |
#[cfg(test)]
mod tests {
    use alloc::vec;

    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        group::{
            test_utils::{test_group, TestGroup},
            Group,
        },
    };

    /// Snapshots `group`, rebuilds a group from that snapshot with a clone of
    /// the same configuration, and asserts that the restored group state
    /// equals the original (plus the tree internals when `tree_index` is on).
    ///
    /// This round trip goes through `Group::snapshot` / `Group::from_snapshot`
    /// only; no byte or JSON encoding is involved.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn snapshot_restore(group: TestGroup) {
        let snapshot = group.group.snapshot();

        let group_restored = Group::from_snapshot(group.group.config.clone(), snapshot)
            .await
            .unwrap();

        assert!(Group::equal_group_state(&group.group, &group_restored));

        #[cfg(feature = "tree_index")]
        assert!(group_restored
            .state
            .public_tree
            .equal_internals(&group.group.state.public_tree));
    }

    // Despite the `_json` suffix, this test exercises the in-memory
    // snapshot/restore round trip of a group that has a pending commit.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn snapshot_with_pending_commit_can_be_serialized_to_json() {
        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
        group.group.commit(vec![]).await.unwrap();

        snapshot_restore(group).await
    }

    // Despite the `_json` suffix, this test exercises the in-memory
    // snapshot/restore round trip of a group that has a pending update.
    #[cfg(feature = "by_ref_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn snapshot_with_pending_updates_can_be_serialized_to_json() {
        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;

        // Creating the update proposal will add it to pending updates
        let update_proposal = group.update_proposal().await;

        // This will insert the proposal into the internal proposal cache.
        // The result is unwrapped so that a failing setup step fails the test
        // instead of letting it pass without a cached proposal.
        let _proposal_message = group
            .group
            .proposal_message(update_proposal, vec![])
            .await
            .unwrap();

        snapshot_restore(group).await
    }

    // Despite the `_json` suffix, this test exercises the in-memory
    // snapshot/restore round trip of a freshly created group.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn snapshot_can_be_serialized_to_json_with_internals() {
        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;

        snapshot_restore(group).await
    }

    // The only test here that actually goes through JSON: a test snapshot is
    // serialized with `serde_json` and must deserialize to an equal value.
    #[cfg(feature = "serde")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn serde() {
        let snapshot = super::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, 5).await;
        let json = serde_json::to_string_pretty(&snapshot).unwrap();
        let recovered = serde_json::from_str(&json).unwrap();
        assert_eq!(snapshot, recovered);
    }
}