async stream encryption and decryption

This commit is contained in:
brxken128
2023-01-24 12:49:54 +00:00
parent 320e6a834f
commit 48622104f5

View File

@@ -1,7 +1,7 @@
//! This module contains the crate's STREAM implementation, and wrappers that allow us to support multiple AEADs.
#![allow(clippy::use_self)] // I think: https://github.com/rust-lang/rust-clippy/issues/3909
use std::io::{Cursor, Read, Write};
use std::io::Cursor;
use crate::{
primitives::{AEAD_TAG_SIZE, BLOCK_SIZE, KEY_LEN},
@@ -13,6 +13,7 @@ use aead::{
};
use aes_gcm::Aes256Gcm;
use chacha20poly1305::XChaCha20Poly1305;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// These are all possible algorithms that can be used for encryption and decryption
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
@@ -105,14 +106,19 @@ impl StreamEncryption {
/// It requires a reader, a writer, and any AAD to go with it.
///
/// The AAD will be authenticated with each block of data.
pub fn encrypt_streams<R, W>(mut self, mut reader: R, mut writer: W, aad: &[u8]) -> Result<()>
pub async fn encrypt_streams<R, W>(
mut self,
mut reader: R,
mut writer: W,
aad: &[u8],
) -> Result<()>
where
R: Read,
W: Write,
R: AsyncReadExt + Unpin + Send,
W: AsyncWriteExt + Unpin + Send,
{
let mut read_buffer = vec![0u8; BLOCK_SIZE].into_boxed_slice();
loop {
let read_count = reader.read(&mut read_buffer)?;
let read_count = reader.read(&mut read_buffer).await?;
if read_count == BLOCK_SIZE {
let payload = Payload {
aad,
@@ -121,7 +127,7 @@ impl StreamEncryption {
let encrypted_data = self.encrypt_next(payload).map_err(|_| Error::Encrypt)?;
writer.write_all(&encrypted_data)?;
writer.write_all(&encrypted_data).await?;
} else {
// we use `..read_count` in order to only use the read data, and not zeroes also
let payload = Payload {
@@ -130,13 +136,13 @@ impl StreamEncryption {
};
let encrypted_data = self.encrypt_last(payload).map_err(|_| Error::Encrypt)?;
writer.write_all(&encrypted_data)?;
writer.write_all(&encrypted_data).await?;
break;
}
}
writer.flush()?;
writer.flush().await?;
Ok(())
}
@@ -145,7 +151,7 @@ impl StreamEncryption {
///
/// It is just a thin wrapper around `encrypt_streams()`, but reduces the amount of code needed elsewhere.
#[allow(unused_mut)]
pub fn encrypt_bytes(
pub async fn encrypt_bytes(
key: Protected<[u8; KEY_LEN]>,
nonce: &[u8],
algorithm: Algorithm,
@@ -157,6 +163,7 @@ impl StreamEncryption {
encryptor
.encrypt_streams(bytes, &mut writer, aad)
.await
.map_or_else(Err, |_| Ok(writer.into_inner()))
}
}
@@ -218,15 +225,20 @@ impl StreamDecryption {
/// It requires a reader, a writer, and any AAD that was used.
///
/// The AAD will be authenticated with each block of data - if the AAD doesn't match what was used during encryption, an error will be returned.
pub fn decrypt_streams<R, W>(mut self, mut reader: R, mut writer: W, aad: &[u8]) -> Result<()>
pub async fn decrypt_streams<R, W>(
mut self,
mut reader: R,
mut writer: W,
aad: &[u8],
) -> Result<()>
where
R: Read,
W: Write,
R: AsyncReadExt + Unpin + Send,
W: AsyncWriteExt + Unpin + Send,
{
let mut read_buffer = vec![0u8; BLOCK_SIZE + AEAD_TAG_SIZE].into_boxed_slice();
loop {
let read_count = reader.read(&mut read_buffer)?;
let read_count = reader.read(&mut read_buffer).await?;
if read_count == (BLOCK_SIZE + AEAD_TAG_SIZE) {
let payload = Payload {
aad,
@@ -235,7 +247,7 @@ impl StreamDecryption {
let decrypted_data = self.decrypt_next(payload).map_err(|_| Error::Decrypt)?;
writer.write_all(&decrypted_data)?;
writer.write_all(&decrypted_data).await?;
} else {
let payload = Payload {
aad,
@@ -243,13 +255,13 @@ impl StreamDecryption {
};
let decrypted_data = self.decrypt_last(payload).map_err(|_| Error::Decrypt)?;
writer.write_all(&decrypted_data)?;
writer.write_all(&decrypted_data).await?;
break;
}
}
writer.flush()?;
writer.flush().await?;
Ok(())
}
@@ -258,7 +270,7 @@ impl StreamDecryption {
///
/// It is just a thin wrapper around `decrypt_streams()`, but reduces the amount of code needed elsewhere.
#[allow(unused_mut)]
pub fn decrypt_bytes(
pub async fn decrypt_bytes(
key: Protected<[u8; KEY_LEN]>,
nonce: &[u8],
algorithm: Algorithm,
@@ -270,6 +282,7 @@ impl StreamDecryption {
decryptor
.decrypt_streams(bytes, &mut writer, aad)
.await
.map_or_else(Err, |_| Ok(Protected::new(writer.into_inner())))
}
}