// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use mls_rs::client_builder::Preferences;
6use mls_rs::group::{ReceivedMessage, StateUpdate};
7use mls_rs::{CipherSuite, ExtensionList, Group, MlsMessage, ProtocolVersion};
8
9use crate::test_client::{generate_client, TestClientConfig};
10
11#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
12pub struct TestCase {
13 pub cipher_suite: u16,
14
15 pub external_psks: Vec<TestExternalPsk>,
16 #[serde(with = "hex::serde")]
17 pub key_package: Vec<u8>,
18 #[serde(with = "hex::serde")]
19 pub signature_priv: Vec<u8>,
20 #[serde(with = "hex::serde")]
21 pub encryption_priv: Vec<u8>,
22 #[serde(with = "hex::serde")]
23 pub init_priv: Vec<u8>,
24
25 #[serde(with = "hex::serde")]
26 pub welcome: Vec<u8>,
27 pub ratchet_tree: Option<TestRatchetTree>,
28 #[serde(with = "hex::serde")]
29 pub initial_epoch_authenticator: Vec<u8>,
30
31 pub epochs: Vec<TestEpoch>,
32}
33
34#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
35pub struct TestExternalPsk {
36 #[serde(with = "hex::serde")]
37 pub psk_id: Vec<u8>,
38 #[serde(with = "hex::serde")]
39 pub psk: Vec<u8>,
40}
41
42#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
43pub struct TestEpoch {
44 pub proposals: Vec<TestMlsMessage>,
45 #[serde(with = "hex::serde")]
46 pub commit: Vec<u8>,
47 #[serde(with = "hex::serde")]
48 pub epoch_authenticator: Vec<u8>,
49}
50
51#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
52pub struct TestMlsMessage(#[serde(with = "hex::serde")] pub Vec<u8>);
53
54#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
55pub struct TestRatchetTree(#[serde(with = "hex::serde")] pub Vec<u8>);
56
57impl TestEpoch {
58 pub fn new(
59 proposals: Vec<MlsMessage>,
60 commit: &MlsMessage,
61 epoch_authenticator: Vec<u8>,
62 ) -> Self {
63 let proposals = proposals
64 .into_iter()
65 .map(|p| TestMlsMessage(p.to_bytes().unwrap()))
66 .collect();
67
68 Self {
69 proposals,
70 commit: commit.to_bytes().unwrap(),
71 epoch_authenticator,
72 }
73 }
74}
75
76#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
77pub async fn get_test_groups(
78 protocol_version: ProtocolVersion,
79 cipher_suite: CipherSuite,
80 num_participants: usize,
81 preferences: Preferences,
82) -> Vec<Group<TestClientConfig>> {
83 // Create the group with Alice as the group initiator
84 let creator = generate_client(cipher_suite, b"alice".to_vec(), preferences.clone());
85
86 let mut creator_group = creator
87 .client
88 .create_group_with_id(
89 protocol_version,
90 cipher_suite,
91 b"group".to_vec(),
92 creator.identity,
93 ExtensionList::default(),
94 )
95 .await
96 .unwrap();
97
98 // Generate random clients that will be members of the group
99 let receiver_clients = (0..num_participants - 1)
100 .map(|i| {
101 generate_client(
102 cipher_suite,
103 format!("bob{i}").into_bytes(),
104 preferences.clone(),
105 )
106 })
107 .collect::<Vec<_>>();
108
109 let mut receiver_keys = Vec::new();
110
111 for client in &receiver_clients {
112 let keys = client
113 .client
114 .generate_key_package_message(protocol_version, cipher_suite, client.identity.clone())
115 .await
116 .unwrap();
117
118 receiver_keys.push(keys);
119 }
120
121 // Add the generated clients to the group the creator made
122 let mut commit_builder = creator_group.commit_builder();
123
124 for key in &receiver_keys {
125 commit_builder = commit_builder.add_member(key.clone()).unwrap();
126 }
127
128 let welcome = commit_builder.build().await.unwrap().welcome_message;
129
130 // Creator can confirm the commit was processed by the server
131 #[cfg(feature = "state_update")]
132 {
133 let commit_description = creator_group.apply_pending_commit().await.unwrap();
134
135 assert!(commit_description.state_update.is_active());
136 assert_eq!(commit_description.state_update.new_epoch(), 1);
137 }
138
139 #[cfg(not(feature = "state_update"))]
140 creator_group.apply_pending_commit().await.unwrap();
141
142 for client in &receiver_clients {
143 let res = creator_group
144 .member_with_identity(client.identity.credential.as_basic().unwrap().identifier())
145 .await;
146
147 assert!(res.is_ok());
148 }
149
150 #[cfg(feature = "state_update")]
151 assert!(commit_description
152 .state_update
153 .roster_update()
154 .removed()
155 .is_empty());
156
157 // Export the tree for receivers
158 let tree_data = creator_group.export_tree().unwrap();
159
160 // All the receivers will be able to join the group
161 let mut receiver_groups = Vec::new();
162
163 for client in &receiver_clients {
164 let test_client = client
165 .client
166 .join_group(Some(&tree_data), welcome.clone().unwrap())
167 .await
168 .unwrap()
169 .0;
170
171 receiver_groups.push(test_client);
172 }
173
174 for one_receiver in &receiver_groups {
175 assert!(Group::equal_group_state(&creator_group, one_receiver));
176 }
177
178 receiver_groups.insert(0, creator_group);
179
180 receiver_groups
181}
182
183#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
184pub async fn all_process_commit_with_update(
185 groups: &mut [Group<TestClientConfig>],
186 commit: &MlsMessage,
187 sender: usize,
188) -> Vec<StateUpdate> {
189 let mut state_updates = Vec::new();
190
191 for g in groups {
192 let state_update = if sender != g.current_member_index() as usize {
193 let processed_msg = g.process_incoming_message(commit.clone()).await.unwrap();
194
195 match processed_msg {
196 ReceivedMessage::Commit(update) => update.state_update,
197 _ => panic!("Expected commit, got {processed_msg:?}"),
198 }
199 } else {
200 g.apply_pending_commit().await.unwrap().state_update
201 };
202
203 state_updates.push(state_update);
204 }
205
206 state_updates
207}
208
209#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
210pub async fn all_process_message(
211 groups: &mut [Group<TestClientConfig>],
212 message: &MlsMessage,
213 sender: usize,
214 is_commit: bool,
215) {
216 for group in groups {
217 if sender != group.current_member_index() as usize {
218 group
219 .process_incoming_message(message.clone())
220 .await
221 .unwrap();
222 } else if is_commit {
223 group.apply_pending_commit().await.unwrap();
224 }
225 }
226}
227
228#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
229pub async fn add_random_members(
230 first_id: usize,
231 num_added: usize,
232 committer: usize,
233 groups: &mut Vec<Group<TestClientConfig>>,
234 test_case: Option<&mut TestCase>,
235) {
236 let cipher_suite = groups[committer].cipher_suite();
237 let committer_index = groups[committer].current_member_index() as usize;
238
239 let mut key_packages = Vec::new();
240 let mut new_clients = Vec::new();
241
242 for i in 0..num_added {
243 let id = first_id + i;
244 let new_client = generate_client(
245 cipher_suite,
246 format!("dave-{id}").into(),
247 Preferences::default(),
248 );
249
250 let key_package = new_client
251 .client
252 .generate_key_package_message(
253 ProtocolVersion::MLS_10,
254 cipher_suite,
255 new_client.identity.clone(),
256 )
257 .await
258 .unwrap();
259
260 key_packages.push(key_package);
261 new_clients.push(new_client);
262 }
263
264 let committer_group = &mut groups[committer];
265 let mut commit = committer_group.commit_builder();
266
267 for key_package in key_packages {
268 commit = commit.add_member(key_package).unwrap();
269 }
270
271 let commit_output = commit.build().await.unwrap();
272
273 all_process_message(groups, &commit_output.commit_message, committer_index, true).await;
274
275 let auth = groups[committer].epoch_authenticator().unwrap().to_vec();
276 let epoch = TestEpoch::new(vec![], &commit_output.commit_message, auth);
277
278 if let Some(tc) = test_case {
279 tc.epochs.push(epoch)
280 };
281
282 let tree_data = groups[committer].export_tree().unwrap();
283
284 let mut new_groups = Vec::new();
285
286 for client in &new_clients {
287 let tree_data = tree_data.clone();
288 let commit = commit_output.welcome_message.clone().unwrap();
289
290 let client = client
291 .client
292 .join_group(Some(&tree_data.clone()), commit)
293 .await
294 .unwrap()
295 .0;
296
297 new_groups.push(client);
298 }
299
300 groups.append(&mut new_groups);
301}
302
303#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
304pub async fn remove_members(
305 removed_members: Vec<usize>,
306 committer: usize,
307 groups: &mut Vec<Group<TestClientConfig>>,
308 test_case: Option<&mut TestCase>,
309) {
310 let remove_indexes = removed_members
311 .iter()
312 .map(|removed| groups[*removed].current_member_index())
313 .collect::<Vec<u32>>();
314
315 let mut commit_builder = groups[committer].commit_builder();
316
317 for index in remove_indexes {
318 commit_builder = commit_builder.remove_member(index).unwrap();
319 }
320
321 let commit = commit_builder.build().await.unwrap().commit_message;
322 let committer_index = groups[committer].current_member_index() as usize;
323 all_process_message(groups, &commit, committer_index, true).await;
324
325 let auth = groups[committer].epoch_authenticator().unwrap().to_vec();
326 let epoch = TestEpoch::new(vec![], &commit, auth);
327
328 if let Some(tc) = test_case {
329 tc.epochs.push(epoch)
330 };
331
332 let mut index = 0;
333
334 groups.retain(|_| {
335 index += 1;
336 !(removed_members.contains(&(index - 1)))
337 });
338}