mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-05-19 06:04:31 -04:00
improvement(crypto): Use a type for the result of the room key import method
This commit is contained in:
@@ -243,7 +243,7 @@ mod test {
|
||||
use ruma::room_id;
|
||||
|
||||
use super::{decode, decrypt_helper, decrypt_key_export, encrypt_helper, encrypt_key_export};
|
||||
use crate::{error::OlmResult, machine::test::get_prepared_machine};
|
||||
use crate::{error::OlmResult, machine::test::get_prepared_machine, RoomKeyImportResult};
|
||||
|
||||
const PASSPHRASE: &str = "1234";
|
||||
|
||||
@@ -311,7 +311,10 @@ mod test {
|
||||
let decrypted = decrypt_key_export(Cursor::new(encrypted), "1234").unwrap();
|
||||
|
||||
assert_eq!(export, decrypted);
|
||||
assert_eq!(machine.import_keys(decrypted, |_, _| {}).await.unwrap(), (0, 1));
|
||||
assert_eq!(
|
||||
machine.import_keys(decrypted, |_, _| {}).await.unwrap(),
|
||||
RoomKeyImportResult::new(0, 1)
|
||||
);
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
@@ -322,16 +325,22 @@ mod test {
|
||||
|
||||
let export = vec![session.export_at_index(10).await];
|
||||
|
||||
assert_eq!(machine.import_keys(export.clone(), |_, _| {}).await?, (1, 1));
|
||||
assert_eq!(machine.import_keys(export, |_, _| {}).await?, (0, 1));
|
||||
assert_eq!(
|
||||
machine.import_keys(export.clone(), |_, _| {}).await?,
|
||||
RoomKeyImportResult::new(1, 1)
|
||||
);
|
||||
assert_eq!(machine.import_keys(export, |_, _| {}).await?, RoomKeyImportResult::new(0, 1));
|
||||
|
||||
let better_export = vec![session.export().await];
|
||||
|
||||
assert_eq!(machine.import_keys(better_export, |_, _| {}).await?, (1, 1));
|
||||
assert_eq!(
|
||||
machine.import_keys(better_export, |_, _| {}).await?,
|
||||
RoomKeyImportResult::new(1, 1)
|
||||
);
|
||||
|
||||
let another_session = machine.create_inbound_session(&room_id).await?;
|
||||
let export = vec![another_session.export_at_index(10).await];
|
||||
assert_eq!(machine.import_keys(export, |_, _| {}).await?, (1, 1));
|
||||
assert_eq!(machine.import_keys(export, |_, _| {}).await?, RoomKeyImportResult::new(1, 1));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -37,6 +37,21 @@ pub mod store;
|
||||
mod utilities;
|
||||
mod verification;
|
||||
|
||||
/// Return type for the room key importing.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct RoomKeyImportResult {
|
||||
/// The number of room keys that were imported.
|
||||
pub imported_count: usize,
|
||||
/// The total number of room keys that were found in the export.
|
||||
pub total_count: usize,
|
||||
}
|
||||
|
||||
impl RoomKeyImportResult {
|
||||
pub(crate) fn new(imported_count: usize, total_count: usize) -> Self {
|
||||
Self { imported_count, total_count }
|
||||
}
|
||||
}
|
||||
|
||||
pub use error::{MegolmError, OlmError, SignatureError};
|
||||
pub use file_encryption::{
|
||||
decrypt_key_export, encrypt_key_export, AttachmentDecryptor, AttachmentEncryptor,
|
||||
|
||||
@@ -69,7 +69,7 @@ use crate::{
|
||||
SecretImportError, Store,
|
||||
},
|
||||
verification::{Verification, VerificationMachine, VerificationRequest},
|
||||
CrossSigningKeyExport, ToDeviceRequest,
|
||||
CrossSigningKeyExport, RoomKeyImportResult, ToDeviceRequest,
|
||||
};
|
||||
|
||||
/// State machine implementation of the Olm/Megolm encryption protocol used for
|
||||
@@ -1297,7 +1297,7 @@ impl OlmMachine {
|
||||
&self,
|
||||
exported_keys: Vec<ExportedRoomKey>,
|
||||
progress_listener: impl Fn(usize, usize),
|
||||
) -> StoreResult<(usize, usize)> {
|
||||
) -> StoreResult<RoomKeyImportResult> {
|
||||
type SessionIdToIndexMap = BTreeMap<Arc<str>, u32>;
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -1338,7 +1338,7 @@ impl OlmMachine {
|
||||
),
|
||||
};
|
||||
|
||||
let total_sessions = exported_keys.len();
|
||||
let total_count = exported_keys.len();
|
||||
|
||||
for (i, key) in exported_keys.into_iter().enumerate() {
|
||||
let session = InboundGroupSession::from_export(key)?;
|
||||
@@ -1350,18 +1350,18 @@ impl OlmMachine {
|
||||
sessions.push(session)
|
||||
}
|
||||
|
||||
progress_listener(i, total_sessions)
|
||||
progress_listener(i, total_count)
|
||||
}
|
||||
|
||||
let num_sessions = sessions.len();
|
||||
let imported_count = sessions.len();
|
||||
|
||||
let changes = Changes { inbound_group_sessions: sessions, ..Default::default() };
|
||||
|
||||
self.store.save_changes(changes).await?;
|
||||
|
||||
info!("Successfully imported {} inbound group sessions", num_sessions);
|
||||
info!(total_count, imported_count, "Successfully imported room keys");
|
||||
|
||||
Ok((num_sessions, total_sessions))
|
||||
Ok(RoomKeyImportResult::new(imported_count, total_count))
|
||||
}
|
||||
|
||||
/// Export the keys that match the given predicate.
|
||||
|
||||
@@ -255,7 +255,7 @@ use std::{
|
||||
};
|
||||
|
||||
use futures_util::stream::{self, StreamExt};
|
||||
pub use matrix_sdk_base::crypto::{MediaEncryptionInfo, LocalTrust};
|
||||
pub use matrix_sdk_base::crypto::{MediaEncryptionInfo, LocalTrust, RoomKeyImportResult};
|
||||
use matrix_sdk_base::{
|
||||
crypto::{
|
||||
store::CryptoStoreError, CrossSigningStatus, OutgoingRequest, RoomMessageRequest,
|
||||
@@ -646,21 +646,23 @@ impl Client {
|
||||
/// # use futures::executor::block_on;
|
||||
/// # use url::Url;
|
||||
/// # block_on(async {
|
||||
/// # let homeserver = Url::parse("http://localhost:8080").unwrap();
|
||||
/// # let mut client = Client::new(homeserver).unwrap();
|
||||
/// # let homeserver = Url::parse("http://localhost:8080")?;
|
||||
/// # let mut client = Client::new(homeserver)?;
|
||||
/// let path = PathBuf::from("/home/example/e2e-keys.txt");
|
||||
/// client
|
||||
/// .import_keys(path, "secret-passphrase")
|
||||
/// .await
|
||||
/// .expect("Can't import keys");
|
||||
/// # });
|
||||
/// let result = client.import_keys(path, "secret-passphrase").await?;
|
||||
///
|
||||
/// println!(
|
||||
/// "Imported {} room keys out of {}",
|
||||
/// result.imported_count, result.total_count
|
||||
/// );
|
||||
/// # anyhow::Result::<()>::Ok(()) });
|
||||
/// ```
|
||||
#[cfg(all(feature = "encryption", not(target_arch = "wasm32")))]
|
||||
pub async fn import_keys(
|
||||
&self,
|
||||
path: PathBuf,
|
||||
passphrase: &str,
|
||||
) -> StdResult<(usize, usize), RoomKeyImportError> {
|
||||
) -> StdResult<RoomKeyImportResult, RoomKeyImportError> {
|
||||
let olm = self.base_client.olm_machine().await.ok_or(RoomKeyImportError::StoreClosed)?;
|
||||
let passphrase = zeroize::Zeroizing::new(passphrase.to_owned());
|
||||
|
||||
|
||||
Reference in New Issue
Block a user