Remove Olm Session cache from the individual crypto store implementations

This commit is contained in:
Damir Jelić
2024-07-12 12:47:13 +02:00
parent c83fa3a532
commit 96b615ba8e
14 changed files with 106 additions and 339 deletions

View File

@@ -338,10 +338,6 @@ async fn save_changes(
processed_steps += 1;
listener(processed_steps, total_steps);
// The Sessions were created with incorrect device keys, so clear the cache
// so that they'll get recreated with correct ones.
store.clear_caches().await;
Ok(())
}

View File

@@ -1528,17 +1528,6 @@ impl OlmMachine {
}
.into()
}
/// Clear any in-memory caches because they may be out of sync with the
/// underlying data store.
///
/// The crypto store layer is caching olm sessions for a given device.
/// When used in a multi-process context this cache will get outdated.
/// If the machine is used by another process, the cache must be
/// invalidating when the main process is resumed.
pub async fn clear_crypto_cache(&self) {
self.inner.clear_crypto_cache().await
}
}
impl OlmMachine {

View File

@@ -256,7 +256,6 @@ impl BaseClient {
// Recreate the `OlmMachine` and wipe the in-memory cache in the store
// because we suspect it has stale data.
self.crypto_store.clear_caches().await;
let olm_machine = OlmMachine::with_store(
&session_meta.user_id,
&session_meta.device_id,

View File

@@ -30,7 +30,6 @@ use ruma::{
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::Mutex;
use tracing::{instrument, trace, warn};
use vodozemac::{olm::SessionConfig, Curve25519PublicKey, Ed25519PublicKey};
@@ -339,13 +338,7 @@ impl Device {
)
}
/// Get the Olm sessions that belong to this device.
pub(crate) async fn get_sessions(&self) -> StoreResult<Option<Arc<Mutex<Vec<Session>>>>> {
let Some(k) = self.curve25519_key() else { return Ok(None) };
self.verification_machine.store.get_sessions(&k.to_base64()).await
}
#[cfg(test)]
/// Get the most recently created session that belongs to this device.
pub(crate) async fn get_most_recent_session(&self) -> OlmResult<Option<Session>> {
self.inner.get_most_recent_session(self.verification_machine.store.inner()).await
}
@@ -693,9 +686,7 @@ impl DeviceData {
store: &CryptoStoreWrapper,
) -> OlmResult<Option<Session>> {
if let Some(sender_key) = self.curve25519_key() {
if let Some(s) = store.get_sessions(&sender_key.to_base64()).await? {
let mut sessions = s.lock().await;
if let Some(mut sessions) = store.get_sessions(&sender_key.to_base64()).await? {
sessions.sort_by_key(|s| s.creation_time);
Ok(sessions.last().cloned())

View File

@@ -2300,13 +2300,6 @@ impl OlmMachine {
let account = cache.account().await?;
Ok(account.uploaded_key_count())
}
/// Clear any in-memory caches because they may be out of sync with the
/// underlying data store.
pub async fn clear_crypto_cache(&self) {
let crypto_store = self.store().crypto_store();
crypto_store.as_ref().clear_caches().await;
}
}
fn sender_data_to_verification_state(
@@ -2937,7 +2930,7 @@ pub(crate) mod tests {
.unwrap()
.unwrap();
assert!(!session.lock().await.is_empty())
assert!(!session.is_empty())
}
#[async_test]
@@ -2976,11 +2969,13 @@ pub(crate) mod tests {
.await;
}
let mut changes = Changes::default();
// Since the sessions are created quickly in succession and our timestamps have
// a resolution in seconds, it's very likely that we're going to end up
// with the same timestamps, so we manually masage them to be 10s apart.
let session_id = {
let sessions = alice_machine
let mut sessions = alice_machine
.store()
.get_sessions(&bob_machine.identity_keys().curve25519.to_base64())
.await
@@ -2988,7 +2983,6 @@ pub(crate) mod tests {
.unwrap();
let mut use_time = SystemTime::now();
let mut sessions = sessions.lock().await;
let mut session_id = None;
@@ -3001,11 +2995,14 @@ pub(crate) mod tests {
use_time += Duration::from_secs(10);
session_id = Some(session.session_id().to_owned());
changes.sessions.push(session.clone());
}
session_id.unwrap()
};
alice_machine.store().save_changes(changes).await.unwrap();
let newest_session = device.get_most_recent_session().await.unwrap().unwrap();
assert_eq!(

View File

@@ -1250,11 +1250,9 @@ impl Account {
let mut errors_by_olm_session = Vec::new();
if let Some(sessions) = existing_sessions {
let sessions = &mut *sessions.lock().await;
// Try to decrypt the message using each Session we share with the
// given curve25519 sender key.
for session in sessions.iter_mut() {
for mut session in sessions {
match session.decrypt(message).await {
Ok(p) => {
// success!
@@ -1282,7 +1280,7 @@ impl Account {
OlmMessage::PreKey(prekey_message) => {
// First try to decrypt using an existing session.
if let Some(sessions) = existing_sessions {
for session in sessions.lock().await.iter_mut() {
for mut session in sessions {
if prekey_message.session_id() != session.session_id() {
// wrong session
continue;

View File

@@ -102,41 +102,32 @@ impl SessionManager {
&self,
sender: &UserId,
curve_key: Curve25519PublicKey,
) -> StoreResult<()> {
) -> OlmResult<()> {
if let Some(device) = self.store.get_device_from_curve_key(sender, curve_key).await? {
let sessions = device.get_sessions().await?;
if let Some(session) = device.get_most_recent_session().await? {
info!(sender_key = ?curve_key, "Marking session to be unwedged");
if let Some(sessions) = sessions {
let mut sessions = sessions.lock().await;
sessions.sort_by_key(|s| s.creation_time);
let creation_time = Duration::from_secs(session.creation_time.get().into());
let now = Duration::from_secs(SecondsSinceUnixEpoch::now().get().into());
let session = sessions.first();
let should_unwedge = now
.checked_sub(creation_time)
.map(|elapsed| elapsed > Self::UNWEDGING_INTERVAL)
.unwrap_or(true);
if let Some(session) = session {
info!(sender_key = ?curve_key, "Marking session to be unwedged");
let creation_time = Duration::from_secs(session.creation_time.get().into());
let now = Duration::from_secs(SecondsSinceUnixEpoch::now().get().into());
let should_unwedge = now
.checked_sub(creation_time)
.map(|elapsed| elapsed > Self::UNWEDGING_INTERVAL)
.unwrap_or(true);
if should_unwedge {
self.users_for_key_claim
.write()
.unwrap()
.entry(device.user_id().to_owned())
.or_default()
.insert(device.device_id().into());
self.wedged_devices
.write()
.unwrap()
.entry(device.user_id().to_owned())
.or_default()
.insert(device.device_id().into());
}
if should_unwedge {
self.users_for_key_claim
.write()
.unwrap()
.entry(device.user_id().to_owned())
.or_default()
.insert(device.device_id().into());
self.wedged_devices
.write()
.unwrap()
.entry(device.user_id().to_owned())
.or_default()
.insert(device.device_id().into());
}
}
}
@@ -253,11 +244,8 @@ impl SessionManager {
} else if let Some(sender_key) = device.curve25519_key() {
let sessions = self.store.get_sessions(&sender_key.to_base64()).await?;
let is_missing = if let Some(sessions) = sessions {
sessions.lock().await.is_empty()
} else {
true
};
let is_missing =
if let Some(sessions) = sessions { sessions.is_empty() } else { true };
let is_timed_out = self.is_user_timed_out(&user_id, &device_id);

View File

@@ -217,7 +217,11 @@ macro_rules! cryptostore_integration_tests {
.await
.expect("Can't save account");
let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
let changes = Changes {
sessions: vec![session.clone()],
devices: DeviceChanges { new: vec![DeviceData::from_account(&account)], ..Default::default() },
..Default::default()
};
store.save_changes(changes).await.unwrap();
@@ -226,9 +230,9 @@ macro_rules! cryptostore_integration_tests {
.await
.expect("Can't load sessions")
.unwrap();
let loaded_session = sessions.lock().await.get(0).cloned().unwrap();
let loaded_session = sessions.get(0).cloned().expect("We should find the session in the store.");
assert_eq!(&session, &loaded_session);
assert_eq!(&session, &loaded_session, "The loaded session should be the same one we put into the store.");
}
#[async_test]
@@ -263,8 +267,7 @@ macro_rules! cryptostore_integration_tests {
store.save_changes(changes).await.unwrap();
let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap();
let sessions_lock = sessions.lock().await;
let session = &sessions_lock[0];
let session = &sessions[0];
assert_eq!(session_id, session.session_id());
@@ -279,8 +282,7 @@ macro_rules! cryptostore_integration_tests {
assert_eq!(account, loaded_account);
let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap();
let sessions_lock = sessions.lock().await;
let session = &sessions_lock[0];
let session = &sessions[0];
assert_eq!(session_id, session.session_id());
}

View File

@@ -15,7 +15,7 @@
use std::{
collections::{hash_map::Entry, BTreeMap, HashMap, HashSet},
convert::Infallible,
sync::{Arc, RwLock as StdRwLock},
sync::RwLock as StdRwLock,
time::{Duration, Instant},
};
@@ -24,11 +24,11 @@ use ruma::{
events::secret::request::SecretName, DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId,
OwnedUserId, RoomId, TransactionId, UserId,
};
use tokio::sync::{Mutex, RwLock};
use tokio::sync::RwLock;
use tracing::warn;
use super::{
caches::{DeviceStore, GroupSessionStore, SessionStore},
caches::{DeviceStore, GroupSessionStore},
Account, BackupKeys, Changes, CryptoStore, InboundGroupSession, PendingChanges, RoomKeyCounts,
RoomSettings, Session,
};
@@ -69,7 +69,7 @@ impl BackupVersion {
#[derive(Debug)]
pub struct MemoryStore {
account: StdRwLock<Option<Account>>,
sessions: SessionStore,
sessions: StdRwLock<BTreeMap<String, Vec<Session>>>,
inbound_group_sessions: GroupSessionStore,
/// Map room id -> session id -> backup order number
@@ -99,7 +99,7 @@ impl Default for MemoryStore {
fn default() -> Self {
MemoryStore {
account: Default::default(),
sessions: SessionStore::new(),
sessions: Default::default(),
inbound_group_sessions: GroupSessionStore::new(),
inbound_group_sessions_backed_up_to: Default::default(),
outbound_group_sessions: Default::default(),
@@ -139,9 +139,17 @@ impl MemoryStore {
}
}
async fn save_sessions(&self, sessions: Vec<Session>) {
fn save_sessions(&self, sessions: Vec<Session>) {
let mut session_store = self.sessions.write().unwrap();
for session in sessions {
let _ = self.sessions.add(session.clone()).await;
let entry = session_store.entry(session.sender_key().to_base64()).or_default();
if let Some(old_entry) = entry.iter_mut().find(|entry| &session == *entry) {
*old_entry = session;
} else {
entry.push(session);
}
}
}
@@ -190,14 +198,6 @@ type Result<T> = std::result::Result<T, Infallible>;
impl CryptoStore for MemoryStore {
type Error = Infallible;
async fn clear_caches(&self) {
// no-op: it makes no sense to delete fields here as we would forget our
// identity, etc Effectively we have no caches as the fields
// *are* the underlying store. Calling this method only makes
// sense if there is some other layer (e.g disk) persistence
// happening.
}
async fn load_account(&self) -> Result<Option<Account>> {
Ok(self.account.read().unwrap().as_ref().map(|acc| acc.deep_clone()))
}
@@ -219,7 +219,7 @@ impl CryptoStore for MemoryStore {
}
async fn save_changes(&self, changes: Changes) -> Result<()> {
self.save_sessions(changes.sessions).await;
self.save_sessions(changes.sessions);
self.save_inbound_group_sessions(changes.inbound_group_sessions, None).await?;
self.save_outbound_group_sessions(changes.outbound_group_sessions);
self.save_private_identity(changes.private_identity);
@@ -327,8 +327,8 @@ impl CryptoStore for MemoryStore {
Ok(())
}
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
Ok(self.sessions.get(sender_key))
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
Ok(self.sessions.read().unwrap().get(sender_key).cloned())
}
async fn get_inbound_group_session(
@@ -638,10 +638,9 @@ mod tests {
assert!(store.load_account().await.unwrap().is_none());
store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
store.save_sessions(vec![session.clone()]).await;
store.save_sessions(vec![session.clone()]);
let sessions = store.get_sessions(&session.sender_key.to_base64()).await.unwrap().unwrap();
let sessions = sessions.lock().await;
let loaded_session = &sessions[0];
@@ -1161,10 +1160,6 @@ mod integration_tests {
impl CryptoStore for PersistentMemoryStore {
type Error = <MemoryStore as CryptoStore>::Error;
async fn clear_caches(&self) {
self.0.clear_caches().await
}
async fn load_account(&self) -> Result<Option<Account>, Self::Error> {
self.0.load_account().await
}
@@ -1192,7 +1187,7 @@ mod integration_tests {
async fn get_sessions(
&self,
sender_key: &str,
) -> Result<Option<Arc<tokio::sync::Mutex<Vec<Session>>>>, Self::Error> {
) -> Result<Option<Vec<Session>>, Self::Error> {
self.0.get_sessions(sender_key).await
}

View File

@@ -345,7 +345,6 @@ impl<'a> SyncedKeyQueryManager<'a> {
#[derive(Debug)]
pub(crate) struct StoreCache {
store: Arc<CryptoStoreWrapper>,
tracked_users: StdRwLock<BTreeSet<OwnedUserId>>,
loaded_tracked_users: RwLock<bool>,
account: Mutex<Option<Account>>,

View File

@@ -19,7 +19,6 @@ use matrix_sdk_common::AsyncTraitDeps;
use ruma::{
events::secret::request::SecretName, DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId,
};
use tokio::sync::Mutex;
use super::{
BackupKeys, Changes, CryptoStoreError, PendingChanges, Result, RoomKeyCounts, RoomSettings,
@@ -85,10 +84,7 @@ pub trait CryptoStore: AsyncTraitDeps {
/// # Arguments
///
/// * `sender_key` - The sender key that was used to establish the sessions.
async fn get_sessions(
&self,
sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Session>>>>, Self::Error>;
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>, Self::Error>;
/// Get the inbound group session from our store.
///
@@ -324,13 +320,6 @@ pub trait CryptoStore: AsyncTraitDeps {
/// Load the next-batch token for a to-device query, if any.
async fn next_batch_token(&self) -> Result<Option<String>, Self::Error>;
/// Clear any in-memory caches because they may be out of sync with the
/// underlying data store.
///
/// If the store does not have any underlying persistence (e.g in-memory
/// store) then this should be a no-op.
async fn clear_caches(&self);
}
#[repr(transparent)]
@@ -348,10 +337,6 @@ impl<T: fmt::Debug> fmt::Debug for EraseCryptoStoreError<T> {
impl<T: CryptoStore> CryptoStore for EraseCryptoStoreError<T> {
type Error = CryptoStoreError;
async fn clear_caches(&self) {
self.0.clear_caches().await
}
async fn load_account(&self) -> Result<Option<Account>> {
self.0.load_account().await.map_err(Into::into)
}
@@ -376,7 +361,7 @@ impl<T: CryptoStore> CryptoStore for EraseCryptoStoreError<T> {
self.0.save_inbound_group_sessions(sessions, backed_up_to_version).await.map_err(Into::into)
}
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
self.0.get_sessions(sender_key).await.map_err(Into::into)
}

View File

@@ -52,7 +52,7 @@ use tracing::{error, info, trace, warn};
use crate::{
error::SignatureError,
gossiping::{GossipMachine, GossipRequest},
olm::{PrivateCrossSigningIdentity, Session, StaticAccountData},
olm::{PrivateCrossSigningIdentity, StaticAccountData},
store::{Changes, CryptoStoreWrapper},
types::Signatures,
CryptoStoreError, DeviceData, LocalTrust, OutgoingVerificationRequest, OwnUserIdentityData,
@@ -166,13 +166,6 @@ impl VerificationStore {
self.inner.get_user_devices(user_id).await
}
pub async fn get_sessions(
&self,
sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Session>>>>, CryptoStoreError> {
self.inner.get_sessions(sender_key).await
}
/// Get the signatures that have signed our own device.
pub async fn device_signatures(&self) -> Result<Option<Signatures>, CryptoStoreError> {
Ok(self

View File

@@ -27,8 +27,8 @@ use matrix_sdk_crypto::{
PrivateCrossSigningIdentity, Session, StaticAccountData,
},
store::{
caches::SessionStore, BackupKeys, Changes, CryptoStore, CryptoStoreError, PendingChanges,
RoomKeyCounts, RoomSettings,
BackupKeys, Changes, CryptoStore, CryptoStoreError, PendingChanges, RoomKeyCounts,
RoomSettings,
},
types::events::room_key_withheld::RoomKeyWithheldEvent,
vodozemac::base64_encode,
@@ -112,7 +112,6 @@ pub struct IndexeddbCryptoStore {
pub(crate) inner: IdbDatabase,
serializer: IndexeddbSerializer,
session_cache: SessionStore,
save_changes_lock: Arc<Mutex<()>>,
}
@@ -261,11 +260,9 @@ impl IndexeddbCryptoStore {
let serializer = IndexeddbSerializer::new(store_cipher);
debug!("IndexedDbCryptoStore: opening main store {name}");
let db = open_and_upgrade_db(&name, &serializer).await?;
let session_cache = SessionStore::new();
Ok(Self {
name,
session_cache,
inner: db,
serializer,
static_account: RwLock::new(None),
@@ -462,10 +459,7 @@ impl IndexeddbCryptoStore {
/// Returns a tuple where the first item is a `PendingIndexeddbChanges`
/// struct, and the second item is a boolean indicating whether the session
/// cache should be cleared.
async fn prepare_for_transaction(
&self,
changes: &Changes,
) -> Result<(PendingIndexeddbChanges, bool)> {
async fn prepare_for_transaction(&self, changes: &Changes) -> Result<PendingIndexeddbChanges> {
let mut indexeddb_changes = PendingIndexeddbChanges::new();
let private_identity_pickle =
@@ -553,17 +547,7 @@ impl IndexeddbCryptoStore {
let mut device_store = indexeddb_changes.get(keys::DEVICES);
let account_info = self.get_static_account();
let mut clear_caches = false;
for device in device_changes.new.iter().chain(&device_changes.changed) {
// If our own device key changes, we need to clear the session
// cache because the sessions contain a copy of our device key, and
// we want the sessions to use the new version.
if account_info.as_ref().is_some_and(|info| {
info.user_id == device.user_id() && info.device_id == device.device_id()
}) {
clear_caches = true;
}
let key =
self.serializer.encode_key(keys::DEVICES, (device.user_id(), device.device_id()));
let device = self.serializer.serialize_value(&device)?;
@@ -646,7 +630,7 @@ impl IndexeddbCryptoStore {
}
}
Ok((indexeddb_changes, clear_caches))
Ok(indexeddb_changes)
}
}
@@ -726,7 +710,7 @@ impl_crypto_store! {
// TODO: #2000 should make this lock go away, or change its shape.
let _guard = self.save_changes_lock.lock().await;
let (indexeddb_changes, clear_caches) = self.prepare_for_transaction(&changes).await?;
let indexeddb_changes = self.prepare_for_transaction(&changes).await?;
let stores = indexeddb_changes.touched_stores();
@@ -742,17 +726,6 @@ impl_crypto_store! {
tx.await.into_result()?;
if clear_caches {
self.clear_caches().await;
} else {
// All good, let's update our caches:indexeddb.
// We only do this if clear_caches is false, because the sessions may
// have been created using old device_keys.
for session in changes.sessions {
self.session_cache.add(session).await;
}
}
Ok(())
}
@@ -894,15 +867,14 @@ impl_crypto_store! {
}
}
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
if self.session_cache.get(sender_key).is_none() {
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
let device_keys = self.get_own_device()
.await?
.as_device_keys()
.clone();
let range = self.serializer.encode_to_range(keys::SESSION, sender_key)?;
let sessions: Vec<Session> = self
let range = self.serializer.encode_to_range(keys::SESSION, sender_key)?;
let sessions: Vec<Session> = self
.inner
.transaction_on_one_with_mode(keys::SESSION, IdbTransactionMode::Readonly)?
.object_store(keys::SESSION)?
@@ -918,10 +890,11 @@ impl_crypto_store! {
}))
.collect::<Result<Vec<Session>>>()?;
self.session_cache.set_for_sender(sender_key, sessions);
if sessions.is_empty() {
Ok(None)
} else {
Ok(Some(sessions))
}
Ok(self.session_cache.get(sender_key))
}
async fn get_inbound_group_session(
@@ -1375,13 +1348,6 @@ impl_crypto_store! {
}
}
}
#[allow(clippy::unused_async)] // Mandated by trait on wasm.
async fn clear_caches(&self) {
self.session_cache.clear()
// We don't need to clear `static_account` as it only contains immutable data
// therefore cannot get out of sync with the underlying store.
}
}
impl Drop for IndexeddbCryptoStore {
@@ -1764,44 +1730,6 @@ mod tests {
.expect("Can't create store without passphrase"),
}
}
#[async_test]
async fn cache_cleared_after_device_update() {
let store = get_store("cache_cleared_after_device_update", None, true).await;
// Given we created a session and saved it in the store
let (account, session) = cryptostore_integration_tests::get_account_and_session().await;
let sender_key = session.sender_key.to_base64();
store
.save_pending_changes(PendingChanges { account: Some(account.deep_clone()) })
.await
.expect("Can't save account");
let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
store.save_changes(changes).await.unwrap();
store.session_cache.get(&sender_key).expect("We should have a session");
// When we save a new version of our device keys
store
.save_changes(Changes {
devices: DeviceChanges {
new: vec![DeviceData::from_account(&account)],
..Default::default()
},
..Default::default()
})
.await
.unwrap();
// Then the session is no longer in the cache
assert!(
store.session_cache.get(&sender_key).is_none(),
"Session should not be in the cache!"
);
}
cryptostore_integration_tests!();
}
#[cfg(all(test, target_arch = "wasm32"))]

View File

@@ -27,10 +27,7 @@ use matrix_sdk_crypto::{
InboundGroupSession, OutboundGroupSession, PickledInboundGroupSession,
PrivateCrossSigningIdentity, Session, StaticAccountData,
},
store::{
caches::SessionStore, BackupKeys, Changes, CryptoStore, PendingChanges, RoomKeyCounts,
RoomSettings,
},
store::{BackupKeys, Changes, CryptoStore, PendingChanges, RoomKeyCounts, RoomSettings},
types::events::room_key_withheld::RoomKeyWithheldEvent,
Account, DeviceData, GossipRequest, GossippedSecret, SecretInfo, TrackedUser, UserIdentityData,
};
@@ -63,7 +60,6 @@ pub struct SqliteCryptoStore {
// DB values cached in memory
static_account: Arc<RwLock<Option<StaticAccountData>>>,
session_cache: SessionStore,
save_changes_lock: Arc<Mutex<()>>,
}
@@ -112,7 +108,6 @@ impl SqliteCryptoStore {
path: None,
pool,
static_account: Arc::new(RwLock::new(None)),
session_cache: SessionStore::new(),
save_changes_lock: Default::default(),
})
}
@@ -681,13 +676,6 @@ impl SqliteObjectCryptoStoreExt for deadpool_sqlite::Object {}
impl CryptoStore for SqliteCryptoStore {
type Error = Error;
async fn clear_caches(&self) {
self.session_cache.clear()
// We don't need to clear `static_account` as it only contains immutable
// data therefore cannot get out of sync with the underlying
// store.
}
async fn load_account(&self) -> Result<Option<Account>> {
let conn = self.acquire().await?;
if let Some(pickle) = conn.get_kv("account").await? {
@@ -754,13 +742,12 @@ impl CryptoStore for SqliteCryptoStore {
if let Some(i) = changes.private_identity { Some(i.pickle().await) } else { None };
let mut session_changes = Vec::new();
for session in changes.sessions {
let session_id = self.encode_key("session", session.session_id());
let sender_key = self.encode_key("session", session.sender_key().to_base64());
let pickle = session.pickle().await;
session_changes.push((session_id, sender_key, pickle));
self.session_cache.add(session).await;
}
let mut inbound_session_changes = Vec::new();
@@ -779,11 +766,9 @@ impl CryptoStore for SqliteCryptoStore {
}
let this = self.clone();
let clear_caches = self
.acquire()
self.acquire()
.await?
.with_transaction(move |txn| {
let mut clear_caches = false;
if let Some(pickled_private_identity) = &pickled_private_identity {
let serialized_private_identity =
this.serialize_value(pickled_private_identity)?;
@@ -805,16 +790,7 @@ impl CryptoStore for SqliteCryptoStore {
txn.set_kv("backup_version_v1", &serialized_backup_version)?;
}
let account_info = this.get_static_account();
for device in changes.devices.new.iter().chain(&changes.devices.changed) {
// If our own device key changes, we need to clear the
// session cache because the sessions contain a copy of our
// device key.
if account_info.clone().is_some_and(|info| {
info.user_id == device.user_id() && info.device_id == device.device_id()
}) {
clear_caches = true;
}
let user_id = this.encode_key("device", device.user_id().as_bytes());
let device_id = this.encode_key("device", device.device_id().as_bytes());
let data = this.serialize_value(&device)?;
@@ -885,14 +861,10 @@ impl CryptoStore for SqliteCryptoStore {
txn.set_secret(&secret_name, &value)?;
}
Ok::<_, Error>(clear_caches)
Ok::<_, Error>(())
})
.await?;
if clear_caches {
self.clear_caches().await;
}
Ok(())
}
@@ -918,27 +890,26 @@ impl CryptoStore for SqliteCryptoStore {
self.save_changes(Changes { inbound_group_sessions: sessions, ..Changes::default() }).await
}
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
if self.session_cache.get(sender_key).is_none() {
let device_keys = self.get_own_device().await?.as_device_keys().clone();
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
let device_keys = self.get_own_device().await?.as_device_keys().clone();
let sessions = self
.acquire()
.await?
.get_sessions_for_sender_key(self.encode_key("session", sender_key.as_bytes()))
.await?
.into_iter()
.map(|bytes| {
let pickle = self.deserialize_value(&bytes)?;
Session::from_pickle(device_keys.clone(), pickle)
.map_err(|_| Error::AccountUnset)
})
.collect::<Result<_>>()?;
let sessions: Vec<_> = self
.acquire()
.await?
.get_sessions_for_sender_key(self.encode_key("session", sender_key.as_bytes()))
.await?
.into_iter()
.map(|bytes| {
let pickle = self.deserialize_value(&bytes)?;
Session::from_pickle(device_keys.clone(), pickle).map_err(|_| Error::AccountUnset)
})
.collect::<Result<_>>()?;
self.session_cache.set_for_sender(sender_key, sessions);
if sessions.is_empty() {
Ok(None)
} else {
Ok(Some(sessions))
}
Ok(self.session_cache.get(sender_key))
}
#[instrument(skip(self))]
@@ -1122,7 +1093,11 @@ impl CryptoStore for SqliteCryptoStore {
async fn get_own_device(&self) -> Result<DeviceData> {
let account_info = self.get_static_account().ok_or(Error::AccountUnset)?;
Ok(self.get_device(&account_info.user_id, &account_info.device_id).await?.unwrap())
Ok(self
.get_device(&account_info.user_id, &account_info.device_id)
.await?
.expect("We should be able to find our own device."))
}
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
@@ -1451,12 +1426,7 @@ mod tests {
#[cfg(test)]
mod encrypted_tests {
use matrix_sdk_crypto::{
cryptostore_integration_tests, cryptostore_integration_tests_time,
store::{Changes, CryptoStore as _, DeviceChanges, PendingChanges},
DeviceData,
};
use matrix_sdk_test::async_test;
use matrix_sdk_crypto::{cryptostore_integration_tests, cryptostore_integration_tests_time};
use once_cell::sync::Lazy;
use tempfile::{tempdir, TempDir};
use tokio::fs;
@@ -1482,69 +1452,6 @@ mod encrypted_tests {
.expect("Can't create a passphrase protected store")
}
#[async_test]
async fn cache_cleared() {
let store = get_store("cache_cleared", None, true).await;
// Given we created a session and saved it in the store
let (account, session) = cryptostore_integration_tests::get_account_and_session().await;
let sender_key = session.sender_key.to_base64();
store
.save_pending_changes(PendingChanges { account: Some(account.deep_clone()) })
.await
.expect("Can't save account");
let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
store.save_changes(changes).await.unwrap();
store.session_cache.get(&sender_key).expect("We should have a session");
// When we clear the caches
store.clear_caches().await;
// Then the session is no longer in the cache
assert!(
store.session_cache.get(&sender_key).is_none(),
"Session should not be in the cache!"
);
}
#[async_test]
async fn cache_cleared_after_device_update() {
let store = get_store("cache_cleared_after_device_update", None, true).await;
// Given we created a session and saved it in the store
let (account, session) = cryptostore_integration_tests::get_account_and_session().await;
let sender_key = session.sender_key.to_base64();
store
.save_pending_changes(PendingChanges { account: Some(account.deep_clone()) })
.await
.expect("Can't save account");
let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
store.save_changes(changes).await.unwrap();
store.session_cache.get(&sender_key).expect("We should have a session");
// When we save a new version of our device keys
store
.save_changes(Changes {
devices: DeviceChanges {
new: vec![DeviceData::from_account(&account)],
..Default::default()
},
..Default::default()
})
.await
.unwrap();
// Then the session is no longer in the cache
assert!(
store.session_cache.get(&sender_key).is_none(),
"Session should not be in the cache!"
);
}
cryptostore_integration_tests!();
cryptostore_integration_tests_time!();
}