Merge pull request #3253 from matrix-org/andybalaam/add-backup_version-arg

crypto: Add a backup_version argument to group session backup methods
This commit is contained in:
Andy Balaam
2024-04-09 16:44:50 +01:00
committed by GitHub
7 changed files with 126 additions and 44 deletions

View File

@@ -8,6 +8,12 @@ Changed:
Breaking changes:
- Add a `backup_version` argument to `CryptoStore`'s
`inbound_group_sessions_for_backup`,
`mark_inbound_group_sessions_as_backed_up` and
`inbound_group_session_counts` methods.
([#3253](https://github.com/matrix-org/matrix-rust-sdk/pull/3253))
- Rename the `OlmMachine::invalidate_group_session` method to
`OlmMachine::discard_room_key`

View File

@@ -397,7 +397,8 @@ impl BackupMachine {
/// Get the number of backed up room keys and the total number of room keys.
pub async fn room_key_counts(&self) -> Result<RoomKeyCounts, CryptoStoreError> {
self.store.inbound_group_session_counts().await
let backup_version = self.backup_key.read().await.as_ref().and_then(|k| k.backup_version());
self.store.inbound_group_session_counts(backup_version.as_deref()).await
}
/// Disable and reset our backup state.
@@ -476,7 +477,12 @@ impl BackupMachine {
trace!(request_id = ?r.request_id, keys = ?r.sessions, "Marking room keys as backed up");
self.store.mark_inbound_group_sessions_as_backed_up(&room_and_session_ids).await?;
self.store
.mark_inbound_group_sessions_as_backed_up(
&r.request.version,
&room_and_session_ids,
)
.await?;
trace!(
request_id = ?r.request_id,
@@ -514,7 +520,7 @@ impl BackupMachine {
};
let sessions =
self.store.inbound_group_sessions_for_backup(Self::BACKUP_BATCH_SIZE).await?;
self.store.inbound_group_sessions_for_backup(&version, Self::BACKUP_BATCH_SIZE).await?;
if sessions.is_empty() {
trace!(?backup_key, "No room keys need to be backed up");
@@ -619,6 +625,7 @@ mod tests {
use ruma::{device_id, room_id, user_id, CanonicalJsonValue, DeviceId, RoomId, UserId};
use serde_json::json;
use super::BackupMachine;
use crate::{
olm::BackedUpRoomKey, store::BackupDecryptionKey, types::RoomKeyBackupInfo, OlmError,
OlmMachine,
@@ -658,7 +665,10 @@ mod tests {
async fn backup_flow(machine: OlmMachine) -> Result<(), OlmError> {
let backup_machine = machine.backup_machine();
let counts = backup_machine.store.inbound_group_session_counts().await?;
let backup_version = current_backup_version(backup_machine).await;
let counts =
backup_machine.store.inbound_group_session_counts(backup_version.as_deref()).await?;
assert_eq!(counts.total, 0, "Initially no keys exist");
assert_eq!(counts.backed_up, 0, "Initially no backed up keys exist");
@@ -666,7 +676,8 @@ mod tests {
machine.create_outbound_group_session_with_defaults_test_helper(room_id()).await?;
machine.create_outbound_group_session_with_defaults_test_helper(room_id2()).await?;
let counts = backup_machine.store.inbound_group_session_counts().await?;
let counts =
backup_machine.store.inbound_group_session_counts(backup_version.as_deref()).await?;
assert_eq!(counts.total, 2, "Two room keys need to exist in the store");
assert_eq!(counts.backed_up, 0, "No room keys have been backed up yet");
@@ -685,8 +696,10 @@ mod tests {
);
backup_machine.mark_request_as_sent(&request_id).await?;
let backup_version = current_backup_version(backup_machine).await;
let counts = backup_machine.store.inbound_group_session_counts().await?;
let counts =
backup_machine.store.inbound_group_session_counts(backup_version.as_deref()).await?;
assert_eq!(counts.total, 2);
assert_eq!(counts.backed_up, 2, "All room keys have been backed up");
@@ -696,8 +709,10 @@ mod tests {
);
backup_machine.disable_backup().await?;
let backup_version = current_backup_version(backup_machine).await;
let counts = backup_machine.store.inbound_group_session_counts().await?;
let counts =
backup_machine.store.inbound_group_session_counts(backup_version.as_deref()).await?;
assert_eq!(counts.total, 2);
assert_eq!(
counts.backed_up, 0,
@@ -707,6 +722,10 @@ mod tests {
Ok(())
}
async fn current_backup_version(backup_machine: &BackupMachine) -> Option<String> {
backup_machine.backup_key.read().await.as_ref().and_then(|k| k.backup_version())
}
#[async_test]
async fn memory_store_backups() -> Result<(), OlmError> {
let machine = OlmMachine::new(alice_id(), alice_device_id()).await;

View File

@@ -323,10 +323,10 @@ macro_rules! cryptostore_integration_tests {
.unwrap();
assert_eq!(session, loaded_session);
assert_eq!(store.get_inbound_group_sessions().await.unwrap().len(), 1);
assert_eq!(store.inbound_group_session_counts().await.unwrap().total, 1);
assert_eq!(store.inbound_group_session_counts().await.unwrap().backed_up, 0);
assert_eq!(store.inbound_group_session_counts(None).await.unwrap().total, 1);
assert_eq!(store.inbound_group_session_counts(None).await.unwrap().backed_up, 0);
let to_back_up = store.inbound_group_sessions_for_backup(1).await.unwrap();
let to_back_up = store.inbound_group_sessions_for_backup("bkpver", 1).await.unwrap();
assert_eq!(to_back_up, vec![session])
}
@@ -342,14 +342,10 @@ macro_rules! cryptostore_integration_tests {
}
let changes = Changes { inbound_group_sessions: sessions.clone(), ..Default::default() };
store.save_changes(changes).await.expect("Can't save group session");
assert_eq!(store.inbound_group_sessions_for_backup(100).await.unwrap().len(), 10);
fn session_info(session: &InboundGroupSession) -> (&RoomId, &str) {
(&session.room_id(), &session.session_id())
}
assert_eq!(store.inbound_group_sessions_for_backup("bkpver", 100).await.unwrap().len(), 10);
// When I mark some as backed up
store.mark_inbound_group_sessions_as_backed_up(&[
store.mark_inbound_group_sessions_as_backed_up("bkpver", &[
session_info(&sessions[1]),
session_info(&sessions[3]),
session_info(&sessions[5]),
@@ -358,7 +354,7 @@ macro_rules! cryptostore_integration_tests {
]).await.expect("Failed to mark sessions as backed up");
// And ask which still need backing up
let to_back_up = store.inbound_group_sessions_for_backup(10).await.unwrap();
let to_back_up = store.inbound_group_sessions_for_backup("bkpver", 10).await.unwrap();
let needs_backing_up = |i: usize| to_back_up.iter().any(|s| s.session_id() == sessions[i].session_id());
// Then the sessions we said were backed up no longer need backing up
@@ -381,27 +377,35 @@ macro_rules! cryptostore_integration_tests {
async fn reset_inbound_group_session_for_backup() {
let (account, store) =
get_loaded_store("reset_inbound_group_session_for_backup").await;
assert_eq!(store.inbound_group_session_counts().await.unwrap().total, 0);
assert_eq!(store.inbound_group_session_counts(None).await.unwrap().total, 0);
let room_id = &room_id!("!test:localhost");
let (_, session) = account.create_group_session_pair_with_defaults(room_id).await;
session.mark_as_backed_up();
let changes =
Changes { inbound_group_sessions: vec![session.clone()], ..Default::default() };
store.save_changes(changes).await.expect("Can't save group session");
assert_eq!(store.inbound_group_session_counts().await.unwrap().total, 1);
assert_eq!(store.inbound_group_session_counts().await.unwrap().backed_up, 1);
// Given we have backed up our session
store
.mark_inbound_group_sessions_as_backed_up("bkpver1", &[session_info(&session)])
.await
.expect("Failed to mark_inbound_group_sessions_as_backed_up.");
let to_back_up = store.inbound_group_sessions_for_backup(1).await.unwrap();
assert_eq!(store.inbound_group_session_counts(Some("bkpver1")).await.unwrap().total, 1);
assert_eq!(store.inbound_group_session_counts(Some("bkpver1")).await.unwrap().backed_up, 1);
// Sanity: before resetting, we have nothing to back up
let to_back_up = store.inbound_group_sessions_for_backup("bkpver1", 1).await.unwrap();
assert_eq!(to_back_up, vec![]);
// When we reset the backup
store.reset_backup_state().await.unwrap();
let to_back_up = store.inbound_group_sessions_for_backup(1).await.unwrap();
// Then after resetting, even if we supply the same backup version number, we need
// to back up the session
let to_back_up = store.inbound_group_sessions_for_backup("bkpver1", 1).await.unwrap();
assert_eq!(to_back_up, vec![session]);
}
@@ -438,7 +442,7 @@ macro_rules! cryptostore_integration_tests {
loaded_session.export().await;
assert_eq!(store.get_inbound_group_sessions().await.unwrap().len(), 1);
assert_eq!(store.inbound_group_session_counts().await.unwrap().total, 1);
assert_eq!(store.inbound_group_session_counts(None).await.unwrap().total, 1);
}
#[async_test]
@@ -960,6 +964,10 @@ macro_rules! cryptostore_integration_tests {
let loaded_2 = store.get_custom_value("B").await.unwrap();
assert_eq!(None, loaded_2);
}
fn session_info(session: &InboundGroupSession) -> (&RoomId, &str) {
(&session.room_id(), &session.session_id())
}
}
};
}

View File

@@ -270,7 +270,10 @@ impl CryptoStore for MemoryStore {
Ok(self.inbound_group_sessions.get_all())
}
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
async fn inbound_group_session_counts(
&self,
_backup_version: Option<&str>,
) -> Result<RoomKeyCounts> {
let backed_up =
self.get_inbound_group_sessions().await?.into_iter().filter(|s| s.backed_up()).count();
@@ -279,6 +282,7 @@ impl CryptoStore for MemoryStore {
async fn inbound_group_sessions_for_backup(
&self,
_backup_version: &str,
limit: usize,
) -> Result<Vec<InboundGroupSession>> {
Ok(self
@@ -292,6 +296,7 @@ impl CryptoStore for MemoryStore {
async fn mark_inbound_group_sessions_as_backed_up(
&self,
_backup_version: &str,
room_and_session_ids: &[(&RoomId, &str)],
) -> Result<()> {
for (room_id, session_id) in room_and_session_ids {
@@ -758,22 +763,29 @@ mod integration_tests {
self.0.get_inbound_group_sessions().await
}
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts, Self::Error> {
self.0.inbound_group_session_counts().await
async fn inbound_group_session_counts(
&self,
backup_version: Option<&str>,
) -> Result<RoomKeyCounts, Self::Error> {
self.0.inbound_group_session_counts(backup_version).await
}
async fn inbound_group_sessions_for_backup(
&self,
backup_version: &str,
limit: usize,
) -> Result<Vec<InboundGroupSession>, Self::Error> {
self.0.inbound_group_sessions_for_backup(limit).await
self.0.inbound_group_sessions_for_backup(backup_version, limit).await
}
async fn mark_inbound_group_sessions_as_backed_up(
&self,
backup_version: &str,
room_and_session_ids: &[(&RoomId, &str)],
) -> Result<(), Self::Error> {
self.0.mark_inbound_group_sessions_as_backed_up(room_and_session_ids).await
self.0
.mark_inbound_group_sessions_as_backed_up(backup_version, room_and_session_ids)
.await
}
async fn reset_backup_state(&self) -> Result<(), Self::Error> {

View File

@@ -105,22 +105,43 @@ pub trait CryptoStore: AsyncTraitDeps {
/// Get the number inbound group sessions we have and how many of them are
/// backed up.
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts, Self::Error>;
async fn inbound_group_session_counts(
&self,
backup_version: Option<&str>,
) -> Result<RoomKeyCounts, Self::Error>;
/// Get all the inbound group sessions we have not backed up yet.
/// Return a batch of ['InboundGroupSession'] ("room keys") that have not
/// yet been backed up in the supplied backup version.
///
/// The size of the returned `Vec` is <= `limit`.
///
/// Note: some implementations ignore `backup_version` and assume the
/// current backup version, which is normally the same.
async fn inbound_group_sessions_for_backup(
&self,
backup_version: &str,
limit: usize,
) -> Result<Vec<InboundGroupSession>, Self::Error>;
/// Mark the inbound group sessions with the supplied room and session IDs
/// as backed up
/// Store the fact that the supplied sessions were backed up into the backup
/// with version `backup_version`.
///
/// Note: some implementations ignore `backup_version` and assume the
/// current backup version, which is normally the same.
async fn mark_inbound_group_sessions_as_backed_up(
&self,
backup_version: &str,
room_and_session_ids: &[(&RoomId, &str)],
) -> Result<(), Self::Error>;
/// Reset the backup state of all the stored inbound group sessions.
///
/// Note: this is mostly implemented by stores that ignore the
/// `backup_version` argument on `inbound_group_sessions_for_backup` and
/// `mark_inbound_group_sessions_as_backed_up`. Implementations that
/// pay attention to the supplied backup version probably don't need to
/// update their storage when the current backup version changes, so have
/// empty implementations of this method.
async fn reset_backup_state(&self) -> Result<(), Self::Error>;
/// Get the backup keys we have stored.
@@ -331,23 +352,27 @@ impl<T: CryptoStore> CryptoStore for EraseCryptoStoreError<T> {
self.0.get_inbound_group_sessions().await.map_err(Into::into)
}
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
self.0.inbound_group_session_counts().await.map_err(Into::into)
async fn inbound_group_session_counts(
&self,
backup_version: Option<&str>,
) -> Result<RoomKeyCounts> {
self.0.inbound_group_session_counts(backup_version).await.map_err(Into::into)
}
async fn inbound_group_sessions_for_backup(
&self,
backup_version: &str,
limit: usize,
) -> Result<Vec<InboundGroupSession>> {
self.0.inbound_group_sessions_for_backup(limit).await.map_err(Into::into)
self.0.inbound_group_sessions_for_backup(backup_version, limit).await.map_err(Into::into)
}
async fn mark_inbound_group_sessions_as_backed_up(
&self,
backup_version: &str,
room_and_session_ids: &[(&RoomId, &str)],
) -> Result<()> {
self.0
.mark_inbound_group_sessions_as_backed_up(room_and_session_ids)
.mark_inbound_group_sessions_as_backed_up(backup_version, room_and_session_ids)
.await
.map_err(Into::into)
}

View File

@@ -873,7 +873,7 @@ impl_crypto_store! {
).await
}
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
async fn inbound_group_session_counts(&self, _backup_version: Option<&str>) -> Result<RoomKeyCounts> {
let tx = self
.inner
.transaction_on_one_with_mode(
@@ -889,6 +889,7 @@ impl_crypto_store! {
async fn inbound_group_sessions_for_backup(
&self,
_backup_version: &str,
limit: usize,
) -> Result<Vec<InboundGroupSession>> {
let tx = self
@@ -933,7 +934,10 @@ impl_crypto_store! {
Ok(result)
}
async fn mark_inbound_group_sessions_as_backed_up(&self, room_and_session_ids: &[(&RoomId, &str)]) -> Result<()> {
async fn mark_inbound_group_sessions_as_backed_up(&self,
_backup_version: &str,
room_and_session_ids: &[(&RoomId, &str)]
) -> Result<()> {
let tx = self
.inner
.transaction_on_one_with_mode(

View File

@@ -477,7 +477,10 @@ trait SqliteObjectCryptoStoreExt: SqliteObjectExt {
.await?)
}
async fn get_inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
async fn get_inbound_group_session_counts(
&self,
_backup_version: Option<&str>,
) -> Result<RoomKeyCounts> {
let total = self
.query_row("SELECT count(*) FROM inbound_group_session", (), |row| row.get(0))
.await?;
@@ -940,12 +943,16 @@ impl CryptoStore for SqliteCryptoStore {
.collect()
}
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
Ok(self.acquire().await?.get_inbound_group_session_counts().await?)
async fn inbound_group_session_counts(
&self,
backup_version: Option<&str>,
) -> Result<RoomKeyCounts> {
Ok(self.acquire().await?.get_inbound_group_session_counts(backup_version).await?)
}
async fn inbound_group_sessions_for_backup(
&self,
_backup_version: &str,
limit: usize,
) -> Result<Vec<InboundGroupSession>> {
self.acquire()
@@ -962,6 +969,7 @@ impl CryptoStore for SqliteCryptoStore {
async fn mark_inbound_group_sessions_as_backed_up(
&self,
_backup_version: &str,
session_ids: &[(&RoomId, &str)],
) -> Result<()> {
Ok(self