mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-05-16 20:49:05 -04:00
Merge pull request #3253 from matrix-org/andybalaam/add-backup_version-arg
crypto: Add a backup_version argument to group session backup methods
This commit is contained in:
@@ -8,6 +8,12 @@ Changed:
|
||||
|
||||
Breaking changes:
|
||||
|
||||
- Add a `backup_version` argument to `CryptoStore`'s
|
||||
`inbound_group_sessions_for_backup`,
|
||||
`mark_inbound_group_sessions_as_backed_up` and
|
||||
`inbound_group_session_counts` methods.
|
||||
([#3253](https://github.com/matrix-org/matrix-rust-sdk/pull/3253))
|
||||
|
||||
- Rename the `OlmMachine::invalidate_group_session` method to
|
||||
`OlmMachine::discard_room_key`
|
||||
|
||||
|
||||
@@ -397,7 +397,8 @@ impl BackupMachine {
|
||||
|
||||
/// Get the number of backed up room keys and the total number of room keys.
|
||||
pub async fn room_key_counts(&self) -> Result<RoomKeyCounts, CryptoStoreError> {
|
||||
self.store.inbound_group_session_counts().await
|
||||
let backup_version = self.backup_key.read().await.as_ref().and_then(|k| k.backup_version());
|
||||
self.store.inbound_group_session_counts(backup_version.as_deref()).await
|
||||
}
|
||||
|
||||
/// Disable and reset our backup state.
|
||||
@@ -476,7 +477,12 @@ impl BackupMachine {
|
||||
|
||||
trace!(request_id = ?r.request_id, keys = ?r.sessions, "Marking room keys as backed up");
|
||||
|
||||
self.store.mark_inbound_group_sessions_as_backed_up(&room_and_session_ids).await?;
|
||||
self.store
|
||||
.mark_inbound_group_sessions_as_backed_up(
|
||||
&r.request.version,
|
||||
&room_and_session_ids,
|
||||
)
|
||||
.await?;
|
||||
|
||||
trace!(
|
||||
request_id = ?r.request_id,
|
||||
@@ -514,7 +520,7 @@ impl BackupMachine {
|
||||
};
|
||||
|
||||
let sessions =
|
||||
self.store.inbound_group_sessions_for_backup(Self::BACKUP_BATCH_SIZE).await?;
|
||||
self.store.inbound_group_sessions_for_backup(&version, Self::BACKUP_BATCH_SIZE).await?;
|
||||
|
||||
if sessions.is_empty() {
|
||||
trace!(?backup_key, "No room keys need to be backed up");
|
||||
@@ -619,6 +625,7 @@ mod tests {
|
||||
use ruma::{device_id, room_id, user_id, CanonicalJsonValue, DeviceId, RoomId, UserId};
|
||||
use serde_json::json;
|
||||
|
||||
use super::BackupMachine;
|
||||
use crate::{
|
||||
olm::BackedUpRoomKey, store::BackupDecryptionKey, types::RoomKeyBackupInfo, OlmError,
|
||||
OlmMachine,
|
||||
@@ -658,7 +665,10 @@ mod tests {
|
||||
|
||||
async fn backup_flow(machine: OlmMachine) -> Result<(), OlmError> {
|
||||
let backup_machine = machine.backup_machine();
|
||||
let counts = backup_machine.store.inbound_group_session_counts().await?;
|
||||
let backup_version = current_backup_version(backup_machine).await;
|
||||
|
||||
let counts =
|
||||
backup_machine.store.inbound_group_session_counts(backup_version.as_deref()).await?;
|
||||
|
||||
assert_eq!(counts.total, 0, "Initially no keys exist");
|
||||
assert_eq!(counts.backed_up, 0, "Initially no backed up keys exist");
|
||||
@@ -666,7 +676,8 @@ mod tests {
|
||||
machine.create_outbound_group_session_with_defaults_test_helper(room_id()).await?;
|
||||
machine.create_outbound_group_session_with_defaults_test_helper(room_id2()).await?;
|
||||
|
||||
let counts = backup_machine.store.inbound_group_session_counts().await?;
|
||||
let counts =
|
||||
backup_machine.store.inbound_group_session_counts(backup_version.as_deref()).await?;
|
||||
assert_eq!(counts.total, 2, "Two room keys need to exist in the store");
|
||||
assert_eq!(counts.backed_up, 0, "No room keys have been backed up yet");
|
||||
|
||||
@@ -685,8 +696,10 @@ mod tests {
|
||||
);
|
||||
|
||||
backup_machine.mark_request_as_sent(&request_id).await?;
|
||||
let backup_version = current_backup_version(backup_machine).await;
|
||||
|
||||
let counts = backup_machine.store.inbound_group_session_counts().await?;
|
||||
let counts =
|
||||
backup_machine.store.inbound_group_session_counts(backup_version.as_deref()).await?;
|
||||
assert_eq!(counts.total, 2);
|
||||
assert_eq!(counts.backed_up, 2, "All room keys have been backed up");
|
||||
|
||||
@@ -696,8 +709,10 @@ mod tests {
|
||||
);
|
||||
|
||||
backup_machine.disable_backup().await?;
|
||||
let backup_version = current_backup_version(backup_machine).await;
|
||||
|
||||
let counts = backup_machine.store.inbound_group_session_counts().await?;
|
||||
let counts =
|
||||
backup_machine.store.inbound_group_session_counts(backup_version.as_deref()).await?;
|
||||
assert_eq!(counts.total, 2);
|
||||
assert_eq!(
|
||||
counts.backed_up, 0,
|
||||
@@ -707,6 +722,10 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn current_backup_version(backup_machine: &BackupMachine) -> Option<String> {
|
||||
backup_machine.backup_key.read().await.as_ref().and_then(|k| k.backup_version())
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn memory_store_backups() -> Result<(), OlmError> {
|
||||
let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
|
||||
|
||||
@@ -323,10 +323,10 @@ macro_rules! cryptostore_integration_tests {
|
||||
.unwrap();
|
||||
assert_eq!(session, loaded_session);
|
||||
assert_eq!(store.get_inbound_group_sessions().await.unwrap().len(), 1);
|
||||
assert_eq!(store.inbound_group_session_counts().await.unwrap().total, 1);
|
||||
assert_eq!(store.inbound_group_session_counts().await.unwrap().backed_up, 0);
|
||||
assert_eq!(store.inbound_group_session_counts(None).await.unwrap().total, 1);
|
||||
assert_eq!(store.inbound_group_session_counts(None).await.unwrap().backed_up, 0);
|
||||
|
||||
let to_back_up = store.inbound_group_sessions_for_backup(1).await.unwrap();
|
||||
let to_back_up = store.inbound_group_sessions_for_backup("bkpver", 1).await.unwrap();
|
||||
assert_eq!(to_back_up, vec![session])
|
||||
}
|
||||
|
||||
@@ -342,14 +342,10 @@ macro_rules! cryptostore_integration_tests {
|
||||
}
|
||||
let changes = Changes { inbound_group_sessions: sessions.clone(), ..Default::default() };
|
||||
store.save_changes(changes).await.expect("Can't save group session");
|
||||
assert_eq!(store.inbound_group_sessions_for_backup(100).await.unwrap().len(), 10);
|
||||
|
||||
fn session_info(session: &InboundGroupSession) -> (&RoomId, &str) {
|
||||
(&session.room_id(), &session.session_id())
|
||||
}
|
||||
assert_eq!(store.inbound_group_sessions_for_backup("bkpver", 100).await.unwrap().len(), 10);
|
||||
|
||||
// When I mark some as backed up
|
||||
store.mark_inbound_group_sessions_as_backed_up(&[
|
||||
store.mark_inbound_group_sessions_as_backed_up("bkpver", &[
|
||||
session_info(&sessions[1]),
|
||||
session_info(&sessions[3]),
|
||||
session_info(&sessions[5]),
|
||||
@@ -358,7 +354,7 @@ macro_rules! cryptostore_integration_tests {
|
||||
]).await.expect("Failed to mark sessions as backed up");
|
||||
|
||||
// And ask which still need backing up
|
||||
let to_back_up = store.inbound_group_sessions_for_backup(10).await.unwrap();
|
||||
let to_back_up = store.inbound_group_sessions_for_backup("bkpver", 10).await.unwrap();
|
||||
let needs_backing_up = |i: usize| to_back_up.iter().any(|s| s.session_id() == sessions[i].session_id());
|
||||
|
||||
// Then the sessions we said were backed up no longer need backing up
|
||||
@@ -381,27 +377,35 @@ macro_rules! cryptostore_integration_tests {
|
||||
async fn reset_inbound_group_session_for_backup() {
|
||||
let (account, store) =
|
||||
get_loaded_store("reset_inbound_group_session_for_backup").await;
|
||||
assert_eq!(store.inbound_group_session_counts().await.unwrap().total, 0);
|
||||
assert_eq!(store.inbound_group_session_counts(None).await.unwrap().total, 0);
|
||||
|
||||
let room_id = &room_id!("!test:localhost");
|
||||
let (_, session) = account.create_group_session_pair_with_defaults(room_id).await;
|
||||
|
||||
session.mark_as_backed_up();
|
||||
|
||||
let changes =
|
||||
Changes { inbound_group_sessions: vec![session.clone()], ..Default::default() };
|
||||
|
||||
store.save_changes(changes).await.expect("Can't save group session");
|
||||
|
||||
assert_eq!(store.inbound_group_session_counts().await.unwrap().total, 1);
|
||||
assert_eq!(store.inbound_group_session_counts().await.unwrap().backed_up, 1);
|
||||
// Given we have backed up our session
|
||||
store
|
||||
.mark_inbound_group_sessions_as_backed_up("bkpver1", &[session_info(&session)])
|
||||
.await
|
||||
.expect("Failed to mark_inbound_group_sessions_as_backed_up.");
|
||||
|
||||
let to_back_up = store.inbound_group_sessions_for_backup(1).await.unwrap();
|
||||
assert_eq!(store.inbound_group_session_counts(Some("bkpver1")).await.unwrap().total, 1);
|
||||
assert_eq!(store.inbound_group_session_counts(Some("bkpver1")).await.unwrap().backed_up, 1);
|
||||
|
||||
// Sanity: before resetting, we have nothing to back up
|
||||
let to_back_up = store.inbound_group_sessions_for_backup("bkpver1", 1).await.unwrap();
|
||||
assert_eq!(to_back_up, vec![]);
|
||||
|
||||
// When we reset the backup
|
||||
store.reset_backup_state().await.unwrap();
|
||||
|
||||
let to_back_up = store.inbound_group_sessions_for_backup(1).await.unwrap();
|
||||
// Then after resetting, even if we supply the same backup version number, we need
|
||||
// to back up the session
|
||||
let to_back_up = store.inbound_group_sessions_for_backup("bkpver1", 1).await.unwrap();
|
||||
assert_eq!(to_back_up, vec![session]);
|
||||
}
|
||||
|
||||
@@ -438,7 +442,7 @@ macro_rules! cryptostore_integration_tests {
|
||||
loaded_session.export().await;
|
||||
|
||||
assert_eq!(store.get_inbound_group_sessions().await.unwrap().len(), 1);
|
||||
assert_eq!(store.inbound_group_session_counts().await.unwrap().total, 1);
|
||||
assert_eq!(store.inbound_group_session_counts(None).await.unwrap().total, 1);
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
@@ -960,6 +964,10 @@ macro_rules! cryptostore_integration_tests {
|
||||
let loaded_2 = store.get_custom_value("B").await.unwrap();
|
||||
assert_eq!(None, loaded_2);
|
||||
}
|
||||
|
||||
fn session_info(session: &InboundGroupSession) -> (&RoomId, &str) {
|
||||
(&session.room_id(), &session.session_id())
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -270,7 +270,10 @@ impl CryptoStore for MemoryStore {
|
||||
Ok(self.inbound_group_sessions.get_all())
|
||||
}
|
||||
|
||||
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
|
||||
async fn inbound_group_session_counts(
|
||||
&self,
|
||||
_backup_version: Option<&str>,
|
||||
) -> Result<RoomKeyCounts> {
|
||||
let backed_up =
|
||||
self.get_inbound_group_sessions().await?.into_iter().filter(|s| s.backed_up()).count();
|
||||
|
||||
@@ -279,6 +282,7 @@ impl CryptoStore for MemoryStore {
|
||||
|
||||
async fn inbound_group_sessions_for_backup(
|
||||
&self,
|
||||
_backup_version: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<InboundGroupSession>> {
|
||||
Ok(self
|
||||
@@ -292,6 +296,7 @@ impl CryptoStore for MemoryStore {
|
||||
|
||||
async fn mark_inbound_group_sessions_as_backed_up(
|
||||
&self,
|
||||
_backup_version: &str,
|
||||
room_and_session_ids: &[(&RoomId, &str)],
|
||||
) -> Result<()> {
|
||||
for (room_id, session_id) in room_and_session_ids {
|
||||
@@ -758,22 +763,29 @@ mod integration_tests {
|
||||
self.0.get_inbound_group_sessions().await
|
||||
}
|
||||
|
||||
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts, Self::Error> {
|
||||
self.0.inbound_group_session_counts().await
|
||||
async fn inbound_group_session_counts(
|
||||
&self,
|
||||
backup_version: Option<&str>,
|
||||
) -> Result<RoomKeyCounts, Self::Error> {
|
||||
self.0.inbound_group_session_counts(backup_version).await
|
||||
}
|
||||
|
||||
async fn inbound_group_sessions_for_backup(
|
||||
&self,
|
||||
backup_version: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<InboundGroupSession>, Self::Error> {
|
||||
self.0.inbound_group_sessions_for_backup(limit).await
|
||||
self.0.inbound_group_sessions_for_backup(backup_version, limit).await
|
||||
}
|
||||
|
||||
async fn mark_inbound_group_sessions_as_backed_up(
|
||||
&self,
|
||||
backup_version: &str,
|
||||
room_and_session_ids: &[(&RoomId, &str)],
|
||||
) -> Result<(), Self::Error> {
|
||||
self.0.mark_inbound_group_sessions_as_backed_up(room_and_session_ids).await
|
||||
self.0
|
||||
.mark_inbound_group_sessions_as_backed_up(backup_version, room_and_session_ids)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn reset_backup_state(&self) -> Result<(), Self::Error> {
|
||||
|
||||
@@ -105,22 +105,43 @@ pub trait CryptoStore: AsyncTraitDeps {
|
||||
|
||||
/// Get the number inbound group sessions we have and how many of them are
|
||||
/// backed up.
|
||||
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts, Self::Error>;
|
||||
async fn inbound_group_session_counts(
|
||||
&self,
|
||||
backup_version: Option<&str>,
|
||||
) -> Result<RoomKeyCounts, Self::Error>;
|
||||
|
||||
/// Get all the inbound group sessions we have not backed up yet.
|
||||
/// Return a batch of ['InboundGroupSession'] ("room keys") that have not
|
||||
/// yet been backed up in the supplied backup version.
|
||||
///
|
||||
/// The size of the returned `Vec` is <= `limit`.
|
||||
///
|
||||
/// Note: some implementations ignore `backup_version` and assume the
|
||||
/// current backup version, which is normally the same.
|
||||
async fn inbound_group_sessions_for_backup(
|
||||
&self,
|
||||
backup_version: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<InboundGroupSession>, Self::Error>;
|
||||
|
||||
/// Mark the inbound group sessions with the supplied room and session IDs
|
||||
/// as backed up
|
||||
/// Store the fact that the supplied sessions were backed up into the backup
|
||||
/// with version `backup_version`.
|
||||
///
|
||||
/// Note: some implementations ignore `backup_version` and assume the
|
||||
/// current backup version, which is normally the same.
|
||||
async fn mark_inbound_group_sessions_as_backed_up(
|
||||
&self,
|
||||
backup_version: &str,
|
||||
room_and_session_ids: &[(&RoomId, &str)],
|
||||
) -> Result<(), Self::Error>;
|
||||
|
||||
/// Reset the backup state of all the stored inbound group sessions.
|
||||
///
|
||||
/// Note: this is mostly implemented by stores that ignore the
|
||||
/// `backup_version` argument on `inbound_group_sessions_for_backup` and
|
||||
/// `mark_inbound_group_sessions_as_backed_up`. Implementations that
|
||||
/// pay attention to the supplied backup version probably don't need to
|
||||
/// update their storage when the current backup version changes, so have
|
||||
/// empty implementations of this method.
|
||||
async fn reset_backup_state(&self) -> Result<(), Self::Error>;
|
||||
|
||||
/// Get the backup keys we have stored.
|
||||
@@ -331,23 +352,27 @@ impl<T: CryptoStore> CryptoStore for EraseCryptoStoreError<T> {
|
||||
self.0.get_inbound_group_sessions().await.map_err(Into::into)
|
||||
}
|
||||
|
||||
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
|
||||
self.0.inbound_group_session_counts().await.map_err(Into::into)
|
||||
async fn inbound_group_session_counts(
|
||||
&self,
|
||||
backup_version: Option<&str>,
|
||||
) -> Result<RoomKeyCounts> {
|
||||
self.0.inbound_group_session_counts(backup_version).await.map_err(Into::into)
|
||||
}
|
||||
|
||||
async fn inbound_group_sessions_for_backup(
|
||||
&self,
|
||||
backup_version: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<InboundGroupSession>> {
|
||||
self.0.inbound_group_sessions_for_backup(limit).await.map_err(Into::into)
|
||||
self.0.inbound_group_sessions_for_backup(backup_version, limit).await.map_err(Into::into)
|
||||
}
|
||||
|
||||
async fn mark_inbound_group_sessions_as_backed_up(
|
||||
&self,
|
||||
backup_version: &str,
|
||||
room_and_session_ids: &[(&RoomId, &str)],
|
||||
) -> Result<()> {
|
||||
self.0
|
||||
.mark_inbound_group_sessions_as_backed_up(room_and_session_ids)
|
||||
.mark_inbound_group_sessions_as_backed_up(backup_version, room_and_session_ids)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
@@ -873,7 +873,7 @@ impl_crypto_store! {
|
||||
).await
|
||||
}
|
||||
|
||||
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
|
||||
async fn inbound_group_session_counts(&self, _backup_version: Option<&str>) -> Result<RoomKeyCounts> {
|
||||
let tx = self
|
||||
.inner
|
||||
.transaction_on_one_with_mode(
|
||||
@@ -889,6 +889,7 @@ impl_crypto_store! {
|
||||
|
||||
async fn inbound_group_sessions_for_backup(
|
||||
&self,
|
||||
_backup_version: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<InboundGroupSession>> {
|
||||
let tx = self
|
||||
@@ -933,7 +934,10 @@ impl_crypto_store! {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn mark_inbound_group_sessions_as_backed_up(&self, room_and_session_ids: &[(&RoomId, &str)]) -> Result<()> {
|
||||
async fn mark_inbound_group_sessions_as_backed_up(&self,
|
||||
_backup_version: &str,
|
||||
room_and_session_ids: &[(&RoomId, &str)]
|
||||
) -> Result<()> {
|
||||
let tx = self
|
||||
.inner
|
||||
.transaction_on_one_with_mode(
|
||||
|
||||
@@ -477,7 +477,10 @@ trait SqliteObjectCryptoStoreExt: SqliteObjectExt {
|
||||
.await?)
|
||||
}
|
||||
|
||||
async fn get_inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
|
||||
async fn get_inbound_group_session_counts(
|
||||
&self,
|
||||
_backup_version: Option<&str>,
|
||||
) -> Result<RoomKeyCounts> {
|
||||
let total = self
|
||||
.query_row("SELECT count(*) FROM inbound_group_session", (), |row| row.get(0))
|
||||
.await?;
|
||||
@@ -940,12 +943,16 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
|
||||
Ok(self.acquire().await?.get_inbound_group_session_counts().await?)
|
||||
async fn inbound_group_session_counts(
|
||||
&self,
|
||||
backup_version: Option<&str>,
|
||||
) -> Result<RoomKeyCounts> {
|
||||
Ok(self.acquire().await?.get_inbound_group_session_counts(backup_version).await?)
|
||||
}
|
||||
|
||||
async fn inbound_group_sessions_for_backup(
|
||||
&self,
|
||||
_backup_version: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<InboundGroupSession>> {
|
||||
self.acquire()
|
||||
@@ -962,6 +969,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
|
||||
async fn mark_inbound_group_sessions_as_backed_up(
|
||||
&self,
|
||||
_backup_version: &str,
|
||||
session_ids: &[(&RoomId, &str)],
|
||||
) -> Result<()> {
|
||||
Ok(self
|
||||
|
||||
Reference in New Issue
Block a user