From e921b1b614c8839e5c7bcd3906004f08f28fd889 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20BR=C3=89ZOT?= Date: Fri, 25 Apr 2025 12:18:02 +0200 Subject: [PATCH 1/5] refacto MemoryADT implementation for KmsEncryptionLayer --- .../findex_server/tests/findex/basic.rs | 2 +- .../findex_client/src/kms/encryption_layer.rs | 21 +-- crate/findex_client/src/kms/memory_adt.rs | 176 ++++++++---------- crate/findex_client/src/kms/requests.rs | 27 +-- 4 files changed, 98 insertions(+), 128 deletions(-) diff --git a/crate/cli/src/actions/findex_server/tests/findex/basic.rs b/crate/cli/src/actions/findex_server/tests/findex/basic.rs index f626058b..439a34ec 100644 --- a/crate/cli/src/actions/findex_server/tests/findex/basic.rs +++ b/crate/cli/src/actions/findex_server/tests/findex/basic.rs @@ -31,7 +31,7 @@ use crate::{ }; pub(crate) fn findex_number_of_threads() -> Option { - std::env::var("GITHUB_ACTIONS").is_ok().then_some(1) + std::env::var("GITHUB_ACTIONS").map(|_| 1).ok() } #[tokio::test] diff --git a/crate/findex_client/src/kms/encryption_layer.rs b/crate/findex_client/src/kms/encryption_layer.rs index a8ae80cd..ab972071 100644 --- a/crate/findex_client/src/kms/encryption_layer.rs +++ b/crate/findex_client/src/kms/encryption_layer.rs @@ -64,14 +64,13 @@ impl< } /// Compute multiple HMAC on given memory addresses. - pub(crate) async fn hmac( + pub(crate) async fn batch_permute<'a>( &self, - addresses: Vec, + addresses: impl Iterator, ) -> ClientResult> { - trace!("hmac: Computing HMAC on addresses: {:?}", addresses); let tokens = self .kms_client - .message(self.build_mac_message_request(&addresses)?) + .message(self.build_mac_message_request(addresses)?) .await? .extract_items_data()? .into_iter() @@ -93,29 +92,27 @@ impl< } /// Bulk encrypts the given words using AES-XTS-512 and the given memory addresses as tweak. - pub(crate) async fn encrypt( + pub(crate) async fn batch_encrypt<'a>( &self, - words: &[[u8; WORD_LENGTH]], - tokens: &[Memory::Address], + bindings: impl Iterator, ) -> ClientResult> { Self::extract_words( &self .kms_client - .message(self.build_encrypt_message_request(words, tokens)?) + .message(self.build_encrypt_message_request(bindings)?) .await?, ) } /// Decrypts these ciphertexts using the given addresses as tweak. - pub(crate) async fn decrypt( + pub(crate) async fn batch_decrypt<'a>( &self, - words: &[[u8; WORD_LENGTH]], - tokens: &[Memory::Address], + bindings: impl Iterator, ) -> ClientResult> { Self::extract_words( &self .kms_client - .message(self.build_decrypt_message_request(words, tokens)?) + .message(self.build_decrypt_message_request(bindings)?) .await?, ) } diff --git a/crate/findex_client/src/kms/memory_adt.rs b/crate/findex_client/src/kms/memory_adt.rs index 45233733..441a7857 100644 --- a/crate/findex_client/src/kms/memory_adt.rs +++ b/crate/findex_client/src/kms/memory_adt.rs @@ -1,3 +1,5 @@ +use std::iter::once; + use cosmian_findex::{ADDRESS_LENGTH, Address, MemoryADT}; use tracing::trace; @@ -18,131 +20,109 @@ impl< guard: (Self::Address, Option), bindings: Vec<(Self::Address, Self::Word)>, ) -> Result, Self::Error> { - trace!("guarded_write: guard: {:?}", guard); - let (address, optional_word) = guard; - - // Split bindings into two vectors - let (mut bindings, mut bindings_words): (Vec<_>, Vec<_>) = bindings.into_iter().unzip(); - trace!("guarded_write: bindings_addresses: {bindings:?}"); - trace!("guarded_write: bindings_words: {bindings_words:?}"); - - // Compute HMAC of all addresses together (including the guard address) - bindings.push(address); // size: n+1 - let mut tokens = self.hmac(bindings).await?; - trace!("guarded_write: tokens: {tokens:?}"); - - // Put apart the last token - let token = tokens - .pop() - .ok_or_else(|| ClientError::Default("No token found".to_owned()))?; - - let (ciphertexts_and_tokens, old) = if let Some(word) = optional_word { - // Zip words and tokens - bindings_words.push(word); // size: n+1 - tokens.push(token); // size: n+1 - - // Bulk Encrypt - let mut ciphertexts = self.encrypt(&bindings_words, &tokens).await?; - trace!("guarded_write: ciphertexts: {ciphertexts:?}"); - - // Pop the old value - let old = ciphertexts - .pop() - .ok_or_else(|| ClientError::Default("No ciphertext found".to_owned()))?; - - // Zip ciphertexts and tokens - (ciphertexts.into_iter().zip(tokens), Some(old)) - } else { - // Bulk Encrypt - let ciphertexts = self.encrypt(&bindings_words, &tokens).await?; - trace!("guarded_write: ciphertexts: {ciphertexts:?}"); - - // Zip ciphertexts and tokens - (ciphertexts.into_iter().zip(tokens), None) - }; + // Cryptographic operations being delegated to the KMS, it is better to + // perform them in batch. Since permuted addresses are used as tweak in + // the AES-XTS encryption of the words, two batches are required. - // - // Send bindings to server - let cur = self - .mem - .guarded_write( - (token, old), - ciphertexts_and_tokens - .into_iter() - .map(|(w, a)| (a, w)) - .collect(), + trace!("guarded_write: {guard:?}, {bindings:?}"); + + let permuted_addresses = self + .batch_permute(bindings.iter().map(|(a, _)| a).chain([&guard.0])) + .await?; + + let encrypted_words = self + .batch_encrypt( + permuted_addresses + .iter() + .zip(bindings.iter().map(|(_, w)| w).chain(guard.1.iter())), ) + .await?; + + let encrypted_guard = ( + *permuted_addresses.get(bindings.len()).ok_or_else(|| { + ClientError::Default("no permuted guard address found".to_owned()) + })?, + encrypted_words.get(bindings.len()).cloned(), + ); + + let encrypted_bindings = permuted_addresses + .into_iter() + .zip(encrypted_words) + .take(bindings.len()) + .collect::>(); + + let permuted_ag = encrypted_guard.0.clone(); + + // Perform the actual call to the memory. + let encrypted_wg_cur = self + .mem + .guarded_write(encrypted_guard, encrypted_bindings) .await .map_err(|e| ClientError::Default(format!("Memory error: {e}")))?; - // - // Decrypt the current value (if any) - let res = match cur { + let wg_cur = match encrypted_wg_cur { Some(ctx) => Some( *self - .decrypt(&[ctx], &[token]) + .batch_decrypt(once((&permuted_ag, &ctx))) .await? .first() .ok_or_else(|| ClientError::Default("No plaintext found".to_owned()))?, ), None => None, }; - trace!("guarded_write: res: {res:?}"); - Ok(res) + trace!("guarded_write: current guard word: {wg_cur:?}"); + + Ok(wg_cur) } async fn batch_read( &self, addresses: Vec, ) -> Result>, Self::Error> { - trace!("batch_read: Addresses: {:?}", addresses); + trace!("batch_read: addresses: {:?}", addresses); - // Compute HMAC of all addresses - let tokens = self.hmac(addresses).await?; - trace!("batch_read: tokens: {:?}", tokens); + let permuted_addresses = self.batch_permute(addresses.iter()).await?; - // Read encrypted values server-side - let ciphertexts = self + let encrypted_words = self .mem - .batch_read(tokens.clone()) + .batch_read(permuted_addresses.clone()) .await .map_err(|e| ClientError::Default(format!("Memory error: {e}")))?; - trace!("batch_read: ciphertexts: {ciphertexts:?}"); - - // Track the positions of None values and bulk ciphertexts and tokens - let (stripped_ciphertexts, stripped_tokens, none_positions): (Vec<_>, Vec<_>, Vec<_>) = - ciphertexts - .into_iter() - .zip(tokens.into_iter()) - .enumerate() - .fold( - (vec![], vec![], vec![]), - |(mut ctxs, mut ts, mut ns), (i, (c, t))| { - match c { - Some(cipher) => { - ctxs.push(cipher); - ts.push(t); - } - None => ns.push(i), - } - (ctxs, ts, ns) - }, - ); - - // Recover plaintext-words - let words = self - .decrypt(&stripped_ciphertexts, &stripped_tokens) + + // None values need to be filtered out to compose with batch_decrypt. + // However, their positions shall not be lost. + let some_encrypted_words = encrypted_words + .into_iter() + .enumerate() + .filter_map(|(i, w)| w.map(|w| (i, w))) + .collect::>(); + + let some_words = self + .batch_decrypt( + some_encrypted_words + .iter() + .map(|(i, w)| (&permuted_addresses[*i], w)), + ) .await?; - trace!("batch_read: words: {:?}", words); - let mut res = words.into_iter().map(Some).collect::>(); - for i in none_positions { - res.insert(i, None); + // Replace the None values in the list of decrypted words at the same + // position as in the list of encrypted words. + let mut pos = some_encrypted_words.into_iter().map(|(i, _)| i).peekable(); + let mut words = Vec::with_capacity(addresses.len()); + let mut some_words = some_words.into_iter(); + for i in 0..addresses.len() { + if Some(&i) == pos.peek() { + pos.next(); + words.push(some_words.next()); + } else { + words.push(None); + } } - trace!("batch_read: res: {:?}", res); - Ok(res) + trace!("batch_read: words: {:?}", words); + + Ok(words) } } @@ -227,8 +207,8 @@ mod tests { handles.push(task::spawn(async move { for _ in 0..1_000 { - let ctx = layer.encrypt(&[ptx], &[tok]).await?.remove(0); - let res = layer.decrypt(&[ctx], &[tok]).await?.remove(0); + let ctx = layer.batch_encrypt(once((&tok, &ptx))).await?.remove(0); + let res = layer.batch_decrypt(once((&tok, &ctx))).await?.remove(0); assert_eq!(ptx, res); assert_eq!(ptx.len(), res.len()); } diff --git a/crate/findex_client/src/kms/requests.rs b/crate/findex_client/src/kms/requests.rs index 11f3dc02..23cb67ea 100644 --- a/crate/findex_client/src/kms/requests.rs +++ b/crate/findex_client/src/kms/requests.rs @@ -45,12 +45,11 @@ impl< } } - pub(crate) fn build_mac_message_request( + pub(crate) fn build_mac_message_request<'a>( &self, - addresses: &[Memory::Address], + addresses: impl Iterator, ) -> ClientResult { let items = addresses - .iter() .map(|address| { MessageBatchItem::new(Operation::Mac(self.build_mac_request(address.to_vec()))) }) @@ -73,15 +72,12 @@ impl< )?) } - pub(crate) fn build_encrypt_message_request( + pub(crate) fn build_encrypt_message_request<'a>( &self, - words: &[[u8; WORD_LENGTH]], - tokens: &[Memory::Address], + bindings: impl Iterator, ) -> ClientResult { - let items = words - .iter() - .zip(tokens) - .map(|(word, address)| { + let items = bindings + .map(|(address, word)| { self.build_encrypt_request(word.to_vec(), address.to_vec()) .map(|encrypt_request| { MessageBatchItem::new(Operation::Encrypt(encrypt_request)) @@ -105,15 +101,12 @@ impl< } } - pub(crate) fn build_decrypt_message_request( + pub(crate) fn build_decrypt_message_request<'a>( &self, - words: &[[u8; WORD_LENGTH]], - tokens: &[Memory::Address], + bindings: impl Iterator, ) -> ClientResult { - let items = words - .iter() - .zip(tokens) - .map(|(word, address)| { + let items = bindings + .map(|(address, word)| { MessageBatchItem::new(Operation::Decrypt( self.build_decrypt_request(word.to_vec(), address.to_vec()), )) From e05a59f86800c2300d35ed7571ca81a350cd2cfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20BR=C3=89ZOT?= Date: Fri, 25 Apr 2025 12:35:28 +0200 Subject: [PATCH 2/5] fix clippy lints --- crate/findex_client/src/kms/memory_adt.rs | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/crate/findex_client/src/kms/memory_adt.rs b/crate/findex_client/src/kms/memory_adt.rs index 441a7857..082267d1 100644 --- a/crate/findex_client/src/kms/memory_adt.rs +++ b/crate/findex_client/src/kms/memory_adt.rs @@ -42,7 +42,7 @@ impl< *permuted_addresses.get(bindings.len()).ok_or_else(|| { ClientError::Default("no permuted guard address found".to_owned()) })?, - encrypted_words.get(bindings.len()).cloned(), + encrypted_words.get(bindings.len()).copied(), ); let encrypted_bindings = permuted_addresses @@ -51,7 +51,7 @@ impl< .take(bindings.len()) .collect::>(); - let permuted_ag = encrypted_guard.0.clone(); + let permuted_ag = encrypted_guard.0; // Perform the actual call to the memory. let encrypted_wg_cur = self @@ -90,6 +90,12 @@ impl< .await .map_err(|e| ClientError::Default(format!("Memory error: {e}")))?; + if permuted_addresses.len() < encrypted_words.len() { + return Err(ClientError::Default(format!( + "there can be no more words than addresses" + ))); + } + // None values need to be filtered out to compose with batch_decrypt. // However, their positions shall not be lost. let some_encrypted_words = encrypted_words @@ -100,6 +106,11 @@ impl< let some_words = self .batch_decrypt( + // Since indexes are produced using encrypted_words and the + // above check guarantees its length is not greater than the + // length of permuted_addresses, the following indexing is + // guaranteed to be in range. + #[allow(clippy::index_slicing)] some_encrypted_words .iter() .map(|(i, w)| (&permuted_addresses[*i], w)), From 4e7d71c400d0fe68b64b195d8594515bd5af1481 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20BR=C3=89ZOT?= Date: Fri, 25 Apr 2025 12:41:09 +0200 Subject: [PATCH 3/5] fix clippy lints --- crate/findex_client/src/kms/memory_adt.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crate/findex_client/src/kms/memory_adt.rs b/crate/findex_client/src/kms/memory_adt.rs index 082267d1..6313d42d 100644 --- a/crate/findex_client/src/kms/memory_adt.rs +++ b/crate/findex_client/src/kms/memory_adt.rs @@ -91,9 +91,9 @@ impl< .map_err(|e| ClientError::Default(format!("Memory error: {e}")))?; if permuted_addresses.len() < encrypted_words.len() { - return Err(ClientError::Default(format!( - "there can be no more words than addresses" - ))); + return Err(ClientError::Default( + "there can be no more words than addresses".to_string(), + )); } // None values need to be filtered out to compose with batch_decrypt. @@ -110,7 +110,7 @@ impl< // above check guarantees its length is not greater than the // length of permuted_addresses, the following indexing is // guaranteed to be in range. - #[allow(clippy::index_slicing)] + #[allow(clippy::indexing_slicing)] some_encrypted_words .iter() .map(|(i, w)| (&permuted_addresses[*i], w)), From 993aec00b446d555d884da9125593df9e63c63f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20BR=C3=89ZOT?= Date: Fri, 25 Apr 2025 12:45:13 +0200 Subject: [PATCH 4/5] better error message --- crate/findex_client/src/kms/memory_adt.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/crate/findex_client/src/kms/memory_adt.rs b/crate/findex_client/src/kms/memory_adt.rs index 6313d42d..8ec26bec 100644 --- a/crate/findex_client/src/kms/memory_adt.rs +++ b/crate/findex_client/src/kms/memory_adt.rs @@ -90,10 +90,12 @@ impl< .await .map_err(|e| ClientError::Default(format!("Memory error: {e}")))?; - if permuted_addresses.len() < encrypted_words.len() { - return Err(ClientError::Default( - "there can be no more words than addresses".to_string(), - )); + if permuted_addresses.len() != encrypted_words.len() { + return Err(ClientError::Default(format!( + "incorrect number of words: expected {}, but {} were given", + permuted_addresses.len(), + encrypted_words.len() + ))); } // None values need to be filtered out to compose with batch_decrypt. From c6a93f317759fcd3de922bd805671bfa3da4f418 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20BR=C3=89ZOT?= Date: Fri, 25 Apr 2025 12:46:27 +0200 Subject: [PATCH 5/5] update comment --- crate/findex_client/src/kms/memory_adt.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crate/findex_client/src/kms/memory_adt.rs b/crate/findex_client/src/kms/memory_adt.rs index 8ec26bec..29f43ae2 100644 --- a/crate/findex_client/src/kms/memory_adt.rs +++ b/crate/findex_client/src/kms/memory_adt.rs @@ -109,9 +109,9 @@ impl< let some_words = self .batch_decrypt( // Since indexes are produced using encrypted_words and the - // above check guarantees its length is not greater than the - // length of permuted_addresses, the following indexing is - // guaranteed to be in range. + // above check guarantees its length is equal to the length of + // permuted_addresses, the following indexing is guaranteed to + // be in range. #[allow(clippy::indexing_slicing)] some_encrypted_words .iter()