improvement(crypto): Use a type for the result of the room key import method

This commit is contained in:
Damir Jelić
2021-11-11 16:57:31 +01:00
parent 6a66a67454
commit 33a43af20e
4 changed files with 48 additions and 22 deletions

View File

@@ -243,7 +243,7 @@ mod test {
use ruma::room_id;
use super::{decode, decrypt_helper, decrypt_key_export, encrypt_helper, encrypt_key_export};
use crate::{error::OlmResult, machine::test::get_prepared_machine};
use crate::{error::OlmResult, machine::test::get_prepared_machine, RoomKeyImportResult};
const PASSPHRASE: &str = "1234";
@@ -311,7 +311,10 @@ mod test {
let decrypted = decrypt_key_export(Cursor::new(encrypted), "1234").unwrap();
assert_eq!(export, decrypted);
assert_eq!(machine.import_keys(decrypted, |_, _| {}).await.unwrap(), (0, 1));
assert_eq!(
machine.import_keys(decrypted, |_, _| {}).await.unwrap(),
RoomKeyImportResult::new(0, 1)
);
}
#[async_test]
@@ -322,16 +325,22 @@ mod test {
let export = vec![session.export_at_index(10).await];
assert_eq!(machine.import_keys(export.clone(), |_, _| {}).await?, (1, 1));
assert_eq!(machine.import_keys(export, |_, _| {}).await?, (0, 1));
assert_eq!(
machine.import_keys(export.clone(), |_, _| {}).await?,
RoomKeyImportResult::new(1, 1)
);
assert_eq!(machine.import_keys(export, |_, _| {}).await?, RoomKeyImportResult::new(0, 1));
let better_export = vec![session.export().await];
assert_eq!(machine.import_keys(better_export, |_, _| {}).await?, (1, 1));
assert_eq!(
machine.import_keys(better_export, |_, _| {}).await?,
RoomKeyImportResult::new(1, 1)
);
let another_session = machine.create_inbound_session(&room_id).await?;
let export = vec![another_session.export_at_index(10).await];
assert_eq!(machine.import_keys(export, |_, _| {}).await?, (1, 1));
assert_eq!(machine.import_keys(export, |_, _| {}).await?, RoomKeyImportResult::new(1, 1));
Ok(())
}

View File

@@ -37,6 +37,21 @@ pub mod store;
mod utilities;
mod verification;
/// Return type for the room key importing.
#[derive(Debug, Clone, PartialEq)]
pub struct RoomKeyImportResult {
/// The number of room keys that were imported.
pub imported_count: usize,
/// The total number of room keys that were found in the export.
pub total_count: usize,
}
impl RoomKeyImportResult {
pub(crate) fn new(imported_count: usize, total_count: usize) -> Self {
Self { imported_count, total_count }
}
}
pub use error::{MegolmError, OlmError, SignatureError};
pub use file_encryption::{
decrypt_key_export, encrypt_key_export, AttachmentDecryptor, AttachmentEncryptor,

View File

@@ -69,7 +69,7 @@ use crate::{
SecretImportError, Store,
},
verification::{Verification, VerificationMachine, VerificationRequest},
CrossSigningKeyExport, ToDeviceRequest,
CrossSigningKeyExport, RoomKeyImportResult, ToDeviceRequest,
};
/// State machine implementation of the Olm/Megolm encryption protocol used for
@@ -1297,7 +1297,7 @@ impl OlmMachine {
&self,
exported_keys: Vec<ExportedRoomKey>,
progress_listener: impl Fn(usize, usize),
) -> StoreResult<(usize, usize)> {
) -> StoreResult<RoomKeyImportResult> {
type SessionIdToIndexMap = BTreeMap<Arc<str>, u32>;
#[derive(Debug)]
@@ -1338,7 +1338,7 @@ impl OlmMachine {
),
};
let total_sessions = exported_keys.len();
let total_count = exported_keys.len();
for (i, key) in exported_keys.into_iter().enumerate() {
let session = InboundGroupSession::from_export(key)?;
@@ -1350,18 +1350,18 @@ impl OlmMachine {
sessions.push(session)
}
progress_listener(i, total_sessions)
progress_listener(i, total_count)
}
let num_sessions = sessions.len();
let imported_count = sessions.len();
let changes = Changes { inbound_group_sessions: sessions, ..Default::default() };
self.store.save_changes(changes).await?;
info!("Successfully imported {} inbound group sessions", num_sessions);
info!(total_count, imported_count, "Successfully imported room keys");
Ok((num_sessions, total_sessions))
Ok(RoomKeyImportResult::new(imported_count, total_count))
}
/// Export the keys that match the given predicate.

View File

@@ -255,7 +255,7 @@ use std::{
};
use futures_util::stream::{self, StreamExt};
pub use matrix_sdk_base::crypto::{MediaEncryptionInfo, LocalTrust};
pub use matrix_sdk_base::crypto::{MediaEncryptionInfo, LocalTrust, RoomKeyImportResult};
use matrix_sdk_base::{
crypto::{
store::CryptoStoreError, CrossSigningStatus, OutgoingRequest, RoomMessageRequest,
@@ -646,21 +646,23 @@ impl Client {
/// # use futures::executor::block_on;
/// # use url::Url;
/// # block_on(async {
/// # let homeserver = Url::parse("http://localhost:8080").unwrap();
/// # let mut client = Client::new(homeserver).unwrap();
/// # let homeserver = Url::parse("http://localhost:8080")?;
/// # let mut client = Client::new(homeserver)?;
/// let path = PathBuf::from("/home/example/e2e-keys.txt");
/// client
/// .import_keys(path, "secret-passphrase")
/// .await
/// .expect("Can't import keys");
/// # });
/// let result = client.import_keys(path, "secret-passphrase").await?;
///
/// println!(
/// "Imported {} room keys out of {}",
/// result.imported_count, result.total_count
/// );
/// # anyhow::Result::<()>::Ok(()) });
/// ```
#[cfg(all(feature = "encryption", not(target_arch = "wasm32")))]
pub async fn import_keys(
&self,
path: PathBuf,
passphrase: &str,
) -> StdResult<(usize, usize), RoomKeyImportError> {
) -> StdResult<RoomKeyImportResult, RoomKeyImportError> {
let olm = self.base_client.olm_machine().await.ok_or(RoomKeyImportError::StoreClosed)?;
let passphrase = zeroize::Zeroizing::new(passphrase.to_owned());