Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crate/cli/src/actions/findex_server/tests/findex/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::{
};

pub(crate) fn findex_number_of_threads() -> Option<usize> {
std::env::var("GITHUB_ACTIONS").is_ok().then_some(1)
std::env::var("GITHUB_ACTIONS").map(|_| 1).ok()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while 100% equivalent the old syntax was more verbose and hence easier to understand for me but feel free to close this if you prefer the new one

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a concurrent modification I kept through the rebase. I guess which version is more readable depends on personal taste. I for example feels the other version to be less straightforward since it involves more Rust-specific helper functions (I don't know by heart the signature of is_ok(), while map is standard).

I can revert this change if you prefer.

}

#[tokio::test]
Expand Down
21 changes: 9 additions & 12 deletions crate/findex_client/src/kms/encryption_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,13 @@ impl<
}

/// Compute multiple HMAC on given memory addresses.
pub(crate) async fn hmac(
pub(crate) async fn batch_permute<'a>(
&self,
addresses: Vec<Memory::Address>,
addresses: impl Iterator<Item = &'a Memory::Address>,
) -> ClientResult<Vec<Memory::Address>> {
trace!("hmac: Computing HMAC on addresses: {:?}", addresses);
let tokens = self
.kms_client
.message(self.build_mac_message_request(&addresses)?)
.message(self.build_mac_message_request(addresses)?)
.await?
.extract_items_data()?
.into_iter()
Expand All @@ -93,29 +92,27 @@ impl<
}

/// Bulk encrypts the given words using AES-XTS-512 and the given memory addresses as tweak.
pub(crate) async fn encrypt(
pub(crate) async fn batch_encrypt<'a>(
&self,
words: &[[u8; WORD_LENGTH]],
tokens: &[Memory::Address],
bindings: impl Iterator<Item = (&'a Memory::Address, &'a [u8; WORD_LENGTH])>,
) -> ClientResult<Vec<[u8; WORD_LENGTH]>> {
Self::extract_words(
&self
.kms_client
.message(self.build_encrypt_message_request(words, tokens)?)
.message(self.build_encrypt_message_request(bindings)?)
.await?,
)
}

/// Decrypts these ciphertexts using the given addresses as tweak.
pub(crate) async fn decrypt(
pub(crate) async fn batch_decrypt<'a>(
&self,
words: &[[u8; WORD_LENGTH]],
tokens: &[Memory::Address],
bindings: impl Iterator<Item = (&'a Memory::Address, &'a [u8; WORD_LENGTH])>,
) -> ClientResult<Vec<[u8; WORD_LENGTH]>> {
Self::extract_words(
&self
.kms_client
.message(self.build_decrypt_message_request(words, tokens)?)
.message(self.build_decrypt_message_request(bindings)?)
.await?,
)
}
Expand Down
189 changes: 91 additions & 98 deletions crate/findex_client/src/kms/memory_adt.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::iter::once;

use cosmian_findex::{ADDRESS_LENGTH, Address, MemoryADT};
use tracing::trace;

Expand All @@ -18,131 +20,122 @@ impl<
guard: (Self::Address, Option<Self::Word>),
bindings: Vec<(Self::Address, Self::Word)>,
) -> Result<Option<Self::Word>, Self::Error> {
trace!("guarded_write: guard: {:?}", guard);
let (address, optional_word) = guard;

// Split bindings into two vectors
let (mut bindings, mut bindings_words): (Vec<_>, Vec<_>) = bindings.into_iter().unzip();
trace!("guarded_write: bindings_addresses: {bindings:?}");
trace!("guarded_write: bindings_words: {bindings_words:?}");

// Compute HMAC of all addresses together (including the guard address)
bindings.push(address); // size: n+1
let mut tokens = self.hmac(bindings).await?;
trace!("guarded_write: tokens: {tokens:?}");

// Put apart the last token
let token = tokens
.pop()
.ok_or_else(|| ClientError::Default("No token found".to_owned()))?;

let (ciphertexts_and_tokens, old) = if let Some(word) = optional_word {
// Zip words and tokens
bindings_words.push(word); // size: n+1
tokens.push(token); // size: n+1

// Bulk Encrypt
let mut ciphertexts = self.encrypt(&bindings_words, &tokens).await?;
trace!("guarded_write: ciphertexts: {ciphertexts:?}");

// Pop the old value
let old = ciphertexts
.pop()
.ok_or_else(|| ClientError::Default("No ciphertext found".to_owned()))?;

// Zip ciphertexts and tokens
(ciphertexts.into_iter().zip(tokens), Some(old))
} else {
// Bulk Encrypt
let ciphertexts = self.encrypt(&bindings_words, &tokens).await?;
trace!("guarded_write: ciphertexts: {ciphertexts:?}");

// Zip ciphertexts and tokens
(ciphertexts.into_iter().zip(tokens), None)
};
// Cryptographic operations being delegated to the KMS, it is better to
// perform them in batch. Since permuted addresses are used as tweak in
// the AES-XTS encryption of the words, two batches are required.

//
// Send bindings to server
let cur = self
.mem
.guarded_write(
(token, old),
ciphertexts_and_tokens
.into_iter()
.map(|(w, a)| (a, w))
.collect(),
trace!("guarded_write: {guard:?}, {bindings:?}");

let permuted_addresses = self
.batch_permute(bindings.iter().map(|(a, _)| a).chain([&guard.0]))
.await?;

let encrypted_words = self
.batch_encrypt(
permuted_addresses
.iter()
.zip(bindings.iter().map(|(_, w)| w).chain(guard.1.iter())),
)
.await?;

let encrypted_guard = (
*permuted_addresses.get(bindings.len()).ok_or_else(|| {
ClientError::Default("no permuted guard address found".to_owned())
})?,
encrypted_words.get(bindings.len()).copied(),
);

let encrypted_bindings = permuted_addresses
.into_iter()
.zip(encrypted_words)
.take(bindings.len())
.collect::<Vec<_>>();

let permuted_ag = encrypted_guard.0;

// Perform the actual call to the memory.
let encrypted_wg_cur = self
.mem
.guarded_write(encrypted_guard, encrypted_bindings)
.await
.map_err(|e| ClientError::Default(format!("Memory error: {e}")))?;

//
// Decrypt the current value (if any)
let res = match cur {
let wg_cur = match encrypted_wg_cur {
Some(ctx) => Some(
*self
.decrypt(&[ctx], &[token])
.batch_decrypt(once((&permuted_ag, &ctx)))
.await?
.first()
.ok_or_else(|| ClientError::Default("No plaintext found".to_owned()))?,
),
None => None,
};
trace!("guarded_write: res: {res:?}");

Ok(res)
trace!("guarded_write: current guard word: {wg_cur:?}");

Ok(wg_cur)
}

async fn batch_read(
&self,
addresses: Vec<Self::Address>,
) -> Result<Vec<Option<Self::Word>>, Self::Error> {
trace!("batch_read: Addresses: {:?}", addresses);
trace!("batch_read: addresses: {:?}", addresses);

// Compute HMAC of all addresses
let tokens = self.hmac(addresses).await?;
trace!("batch_read: tokens: {:?}", tokens);
let permuted_addresses = self.batch_permute(addresses.iter()).await?;

// Read encrypted values server-side
let ciphertexts = self
let encrypted_words = self
.mem
.batch_read(tokens.clone())
.batch_read(permuted_addresses.clone())
.await
.map_err(|e| ClientError::Default(format!("Memory error: {e}")))?;
trace!("batch_read: ciphertexts: {ciphertexts:?}");

// Track the positions of None values and bulk ciphertexts and tokens
let (stripped_ciphertexts, stripped_tokens, none_positions): (Vec<_>, Vec<_>, Vec<_>) =
ciphertexts
.into_iter()
.zip(tokens.into_iter())
.enumerate()
.fold(
(vec![], vec![], vec![]),
|(mut ctxs, mut ts, mut ns), (i, (c, t))| {
match c {
Some(cipher) => {
ctxs.push(cipher);
ts.push(t);
}
None => ns.push(i),
}
(ctxs, ts, ns)
},
);

// Recover plaintext-words
let words = self
.decrypt(&stripped_ciphertexts, &stripped_tokens)

if permuted_addresses.len() != encrypted_words.len() {
return Err(ClientError::Default(format!(
"incorrect number of words: expected {}, but {} were given",
permuted_addresses.len(),
encrypted_words.len()
)));
}

// None values need to be filtered out to compose with batch_decrypt.
// However, their positions shall not be lost.
let some_encrypted_words = encrypted_words
.into_iter()
.enumerate()
.filter_map(|(i, w)| w.map(|w| (i, w)))
.collect::<Vec<_>>();

let some_words = self
.batch_decrypt(
// Since indexes are produced using encrypted_words and the
// above check guarantees its length is equal to the length of
// permuted_addresses, the following indexing is guaranteed to
// be in range.
#[allow(clippy::indexing_slicing)]
some_encrypted_words
.iter()
.map(|(i, w)| (&permuted_addresses[*i], w)),
)
.await?;
trace!("batch_read: words: {:?}", words);

let mut res = words.into_iter().map(Some).collect::<Vec<_>>();
for i in none_positions {
res.insert(i, None);
// Replace the None values in the list of decrypted words at the same
// position as in the list of encrypted words.
let mut pos = some_encrypted_words.into_iter().map(|(i, _)| i).peekable();
let mut words = Vec::with_capacity(addresses.len());
let mut some_words = some_words.into_iter();
for i in 0..addresses.len() {
if Some(&i) == pos.peek() {
pos.next();
words.push(some_words.next());
} else {
words.push(None);
}
}
trace!("batch_read: res: {:?}", res);

Ok(res)
trace!("batch_read: words: {:?}", words);

Ok(words)
}
}

Expand Down Expand Up @@ -227,8 +220,8 @@ mod tests {

handles.push(task::spawn(async move {
for _ in 0..1_000 {
let ctx = layer.encrypt(&[ptx], &[tok]).await?.remove(0);
let res = layer.decrypt(&[ctx], &[tok]).await?.remove(0);
let ctx = layer.batch_encrypt(once((&tok, &ptx))).await?.remove(0);
let res = layer.batch_decrypt(once((&tok, &ctx))).await?.remove(0);
assert_eq!(ptx, res);
assert_eq!(ptx.len(), res.len());
}
Expand Down
27 changes: 10 additions & 17 deletions crate/findex_client/src/kms/requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,11 @@ impl<
}
}

pub(crate) fn build_mac_message_request(
pub(crate) fn build_mac_message_request<'a>(
&self,
addresses: &[Memory::Address],
addresses: impl Iterator<Item = &'a Memory::Address>,
) -> ClientResult<Message> {
let items = addresses
.iter()
.map(|address| {
MessageBatchItem::new(Operation::Mac(self.build_mac_request(address.to_vec())))
})
Expand All @@ -73,15 +72,12 @@ impl<
)?)
}

pub(crate) fn build_encrypt_message_request(
pub(crate) fn build_encrypt_message_request<'a>(
&self,
words: &[[u8; WORD_LENGTH]],
tokens: &[Memory::Address],
bindings: impl Iterator<Item = (&'a Memory::Address, &'a [u8; WORD_LENGTH])>,
) -> ClientResult<Message> {
let items = words
.iter()
.zip(tokens)
.map(|(word, address)| {
let items = bindings
.map(|(address, word)| {
self.build_encrypt_request(word.to_vec(), address.to_vec())
.map(|encrypt_request| {
MessageBatchItem::new(Operation::Encrypt(encrypt_request))
Expand All @@ -105,15 +101,12 @@ impl<
}
}

pub(crate) fn build_decrypt_message_request(
pub(crate) fn build_decrypt_message_request<'a>(
&self,
words: &[[u8; WORD_LENGTH]],
tokens: &[Memory::Address],
bindings: impl Iterator<Item = (&'a Memory::Address, &'a [u8; WORD_LENGTH])>,
) -> ClientResult<Message> {
let items = words
.iter()
.zip(tokens)
.map(|(word, address)| {
let items = bindings
.map(|(address, word)| {
MessageBatchItem::new(Operation::Decrypt(
self.build_decrypt_request(word.to_vec(), address.to_vec()),
))
Expand Down