diff --git a/crates/crypto/src/crypto/stream.rs b/crates/crypto/src/crypto/stream.rs index 83e8cef15..7e363bd3d 100644 --- a/crates/crypto/src/crypto/stream.rs +++ b/crates/crypto/src/crypto/stream.rs @@ -1,7 +1,7 @@ //! This module contains the crate's STREAM implementation, and wrappers that allow us to support multiple AEADs. #![allow(clippy::use_self)] // I think: https://github.com/rust-lang/rust-clippy/issues/3909 -use std::io::{Cursor, Read, Write}; +use std::io::Cursor; use crate::{ primitives::{AEAD_TAG_SIZE, BLOCK_SIZE, KEY_LEN}, @@ -13,6 +13,7 @@ use aead::{ }; use aes_gcm::Aes256Gcm; use chacha20poly1305::XChaCha20Poly1305; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; /// These are all possible algorithms that can be used for encryption and decryption #[derive(Clone, Copy, Eq, PartialEq, Hash)] @@ -105,14 +106,19 @@ impl StreamEncryption { /// It requires a reader, a writer, and any AAD to go with it. /// /// The AAD will be authenticated with each block of data. - pub fn encrypt_streams(mut self, mut reader: R, mut writer: W, aad: &[u8]) -> Result<()> + pub async fn encrypt_streams( + mut self, + mut reader: R, + mut writer: W, + aad: &[u8], + ) -> Result<()> where - R: Read, - W: Write, + R: AsyncReadExt + Unpin + Send, + W: AsyncWriteExt + Unpin + Send, { let mut read_buffer = vec![0u8; BLOCK_SIZE].into_boxed_slice(); loop { - let read_count = reader.read(&mut read_buffer)?; + let read_count = reader.read(&mut read_buffer).await?; if read_count == BLOCK_SIZE { let payload = Payload { aad, @@ -121,7 +127,7 @@ impl StreamEncryption { let encrypted_data = self.encrypt_next(payload).map_err(|_| Error::Encrypt)?; - writer.write_all(&encrypted_data)?; + writer.write_all(&encrypted_data).await?; } else { // we use `..read_count` in order to only use the read data, and not zeroes also let payload = Payload { @@ -130,13 +136,13 @@ impl StreamEncryption { }; let encrypted_data = self.encrypt_last(payload).map_err(|_| Error::Encrypt)?; - writer.write_all(&encrypted_data)?; + writer.write_all(&encrypted_data).await?; break; } } - writer.flush()?; + writer.flush().await?; Ok(()) } @@ -145,7 +151,7 @@ impl StreamEncryption { /// /// It is just a thin wrapper around `encrypt_streams()`, but reduces the amount of code needed elsewhere. #[allow(unused_mut)] - pub fn encrypt_bytes( + pub async fn encrypt_bytes( key: Protected<[u8; KEY_LEN]>, nonce: &[u8], algorithm: Algorithm, @@ -157,6 +163,7 @@ impl StreamEncryption { encryptor .encrypt_streams(bytes, &mut writer, aad) + .await .map_or_else(Err, |_| Ok(writer.into_inner())) } } @@ -218,15 +225,20 @@ impl StreamDecryption { /// It requires a reader, a writer, and any AAD that was used. /// /// The AAD will be authenticated with each block of data - if the AAD doesn't match what was used during encryption, an error will be returned. - pub fn decrypt_streams(mut self, mut reader: R, mut writer: W, aad: &[u8]) -> Result<()> + pub async fn decrypt_streams( + mut self, + mut reader: R, + mut writer: W, + aad: &[u8], + ) -> Result<()> where - R: Read, - W: Write, + R: AsyncReadExt + Unpin + Send, + W: AsyncWriteExt + Unpin + Send, { let mut read_buffer = vec![0u8; BLOCK_SIZE + AEAD_TAG_SIZE].into_boxed_slice(); loop { - let read_count = reader.read(&mut read_buffer)?; + let read_count = reader.read(&mut read_buffer).await?; if read_count == (BLOCK_SIZE + AEAD_TAG_SIZE) { let payload = Payload { aad, @@ -235,7 +247,7 @@ impl StreamDecryption { let decrypted_data = self.decrypt_next(payload).map_err(|_| Error::Decrypt)?; - writer.write_all(&decrypted_data)?; + writer.write_all(&decrypted_data).await?; } else { let payload = Payload { aad, @@ -243,13 +255,13 @@ impl StreamDecryption { }; let decrypted_data = self.decrypt_last(payload).map_err(|_| Error::Decrypt)?; - writer.write_all(&decrypted_data)?; + writer.write_all(&decrypted_data).await?; break; } } - writer.flush()?; + writer.flush().await?; Ok(()) } @@ -258,7 +270,7 @@ impl StreamDecryption { /// /// It is just a thin wrapper around `decrypt_streams()`, but reduces the amount of code needed elsewhere. #[allow(unused_mut)] - pub fn decrypt_bytes( + pub async fn decrypt_bytes( key: Protected<[u8; KEY_LEN]>, nonce: &[u8], algorithm: Algorithm, @@ -270,6 +282,7 @@ impl StreamDecryption { decryptor .decrypt_streams(bytes, &mut writer, aad) + .await .map_or_else(Err, |_| Ok(Protected::new(writer.into_inner()))) } }