Merge pull request #4558 from matrix-org/valere/memory_store_consistent_with_other_stores

refactor(crypto): Make memory store behave more like other stores
This commit is contained in:
Valere
2025-01-28 15:27:03 +01:00
committed by GitHub
3 changed files with 189 additions and 141 deletions

View File

@@ -1115,6 +1115,7 @@ mod tests {
use crate::{
gossiping::KeyForwardDecision,
olm::OutboundGroupSession,
store::{CryptoStore, DeviceChanges},
types::requests::AnyOutgoingRequest,
types::{
events::{
@@ -1177,13 +1178,12 @@ mod tests {
let user_id = user_id.to_owned();
let device_id = DeviceId::new();
let account = Account::with_device_id(&user_id, &device_id);
let store = Arc::new(CryptoStoreWrapper::new(&user_id, &device_id, MemoryStore::new()));
let store = Arc::new(store_with_account_helper(&user_id, &device_id).await);
let static_data = store.load_account().await.unwrap().unwrap().static_data;
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id())));
let verification =
VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone());
let store = Store::new(account.static_data().clone(), identity, store, verification);
store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
VerificationMachine::new(static_data.clone(), identity.clone(), store.clone());
let store = Store::new(static_data, identity, store, verification);
let session_cache = GroupSessionCache::new(store.clone());
let identity_manager = IdentityManager::new(store.clone());
@@ -1191,6 +1191,28 @@ mod tests {
GossipMachine::new(store, identity_manager, session_cache, Default::default())
}
#[cfg(feature = "automatic-room-key-forwarding")]
async fn store_with_account_helper(
user_id: &UserId,
device_id: &DeviceId,
) -> CryptoStoreWrapper {
// Properly create the store by first saving the own device and then the account
// data.
let account = Account::with_device_id(user_id, device_id);
let device = DeviceData::from_account(&account);
device.set_trust_state(LocalTrust::Verified);
let changes = Changes {
devices: DeviceChanges { new: vec![device], ..Default::default() },
..Default::default()
};
let mem_store = MemoryStore::new();
mem_store.save_changes(changes).await.unwrap();
mem_store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
CryptoStoreWrapper::new(user_id, device_id, mem_store)
}
async fn get_machine_test_helper() -> GossipMachine {
let user_id = alice_id().to_owned();
let account = Account::with_device_id(&user_id, alice_device_id());

View File

@@ -27,15 +27,12 @@ use std::{
};
use matrix_sdk_common::locks::RwLock as StdRwLock;
use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, RwLock};
use tracing::{field::display, instrument, trace, Span};
use crate::{
identities::DeviceData,
olm::{InboundGroupSession, Session},
};
use crate::{identities::DeviceData, olm::Session};
/// In-memory store for Olm Sessions.
#[derive(Debug, Default, Clone)]
@@ -86,52 +83,6 @@ impl SessionStore {
}
}
#[derive(Debug, Default)]
/// In-memory store that holds inbound group sessions.
pub struct GroupSessionStore {
entries: StdRwLock<BTreeMap<OwnedRoomId, HashMap<String, InboundGroupSession>>>,
}
impl GroupSessionStore {
/// Create a new empty store.
pub fn new() -> Self {
Self::default()
}
/// Add an inbound group session to the store.
///
/// Returns true if the session was added, false if the session was
/// already in the store.
pub fn add(&self, session: InboundGroupSession) -> bool {
self.entries
.write()
.entry(session.room_id().to_owned())
.or_default()
.insert(session.session_id().to_owned(), session)
.is_none()
}
/// Get all the group sessions the store knows about.
pub fn get_all(&self) -> Vec<InboundGroupSession> {
self.entries.read().values().flat_map(HashMap::values).cloned().collect()
}
/// Get the number of `InboundGroupSession`s we have.
pub fn count(&self) -> usize {
self.entries.read().values().map(HashMap::len).sum()
}
/// Get a inbound group session from our store.
///
/// # Arguments
/// * `room_id` - The room id of the room that the session belongs to.
///
/// * `session_id` - The unique id of the session.
pub fn get(&self, room_id: &RoomId, session_id: &str) -> Option<InboundGroupSession> {
self.entries.read().get(room_id)?.get(session_id).cloned()
}
}
/// In-memory store holding the devices of users.
#[derive(Debug, Default)]
pub struct DeviceStore {
@@ -381,13 +332,10 @@ impl UsersForKeyQuery {
mod tests {
use matrix_sdk_test::async_test;
use proptest::prelude::*;
use ruma::room_id;
use vodozemac::{Curve25519PublicKey, Ed25519PublicKey};
use super::{DeviceStore, GroupSessionStore, SequenceNumber, SessionStore};
use super::{DeviceStore, SequenceNumber, SessionStore};
use crate::{
identities::device::testing::get_device,
olm::{tests::get_account_and_session_test_helper, InboundGroupSession, SenderData},
identities::device::testing::get_device, olm::tests::get_account_and_session_test_helper,
};
#[async_test]
@@ -422,37 +370,6 @@ mod tests {
assert_eq!(&session, loaded_session);
}
#[async_test]
async fn test_group_session_store() {
let (account, _) = get_account_and_session_test_helper();
let room_id = room_id!("!test:localhost");
let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw";
let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;
assert_eq!(0, outbound.message_index().await);
assert!(!outbound.shared());
outbound.mark_as_shared();
assert!(outbound.shared());
let inbound = InboundGroupSession::new(
Curve25519PublicKey::from_base64(curve_key).unwrap(),
Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE").unwrap(),
room_id,
&outbound.session_key().await,
SenderData::unknown(),
outbound.settings().algorithm.to_owned(),
None,
)
.unwrap();
let store = GroupSessionStore::new();
store.add(inbound.clone());
let loaded_session = store.get(room_id, outbound.session_id()).unwrap();
assert_eq!(inbound, loaded_session);
}
#[async_test]
async fn test_device_store() {
let device = get_device();

View File

@@ -15,6 +15,7 @@
use std::{
collections::{BTreeMap, HashMap, HashSet},
convert::Infallible,
sync::Arc,
};
use async_trait::async_trait;
@@ -25,19 +26,21 @@ use ruma::{
events::secret::request::SecretName, time::Instant, DeviceId, OwnedDeviceId, OwnedRoomId,
OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId,
};
use tokio::sync::RwLock;
use tokio::sync::{Mutex, RwLock};
use tracing::warn;
use vodozemac::Curve25519PublicKey;
use super::{
caches::{DeviceStore, GroupSessionStore},
Account, BackupKeys, Changes, CryptoStore, DehydratedDeviceKey, InboundGroupSession,
PendingChanges, RoomKeyCounts, RoomSettings, Session,
caches::DeviceStore, Account, BackupKeys, Changes, CryptoStore, DehydratedDeviceKey,
InboundGroupSession, PendingChanges, RoomKeyCounts, RoomSettings, Session,
};
use crate::{
gossiping::{GossipRequest, GossippedSecret, SecretInfo},
identities::{DeviceData, UserIdentityData},
olm::{OutboundGroupSession, PrivateCrossSigningIdentity, SenderDataType},
olm::{
OutboundGroupSession, PickledAccount, PickledInboundGroupSession, PickledSession,
PrivateCrossSigningIdentity, SenderDataType, StaticAccountData,
},
types::events::room_key_withheld::RoomKeyWithheldEvent,
TrackedUser,
};
@@ -70,9 +73,12 @@ impl BackupVersion {
/// An in-memory only store that will forget all the E2EE key once it's dropped.
#[derive(Debug)]
pub struct MemoryStore {
account: StdRwLock<Option<Account>>,
sessions: StdRwLock<BTreeMap<String, Vec<Session>>>,
inbound_group_sessions: GroupSessionStore,
static_account: Arc<StdRwLock<Option<StaticAccountData>>>,
account: StdRwLock<Option<String>>,
// Map of sender_key to map of session_id to serialized pickle
sessions: StdRwLock<BTreeMap<String, BTreeMap<String, String>>>,
inbound_group_sessions: StdRwLock<BTreeMap<OwnedRoomId, HashMap<String, String>>>,
/// Map room id -> session id -> backup order number
/// The latest backup in which this session is stored. Equivalent to
@@ -96,14 +102,16 @@ pub struct MemoryStore {
dehydrated_device_pickle_key: RwLock<Option<DehydratedDeviceKey>>,
next_batch_token: RwLock<Option<String>>,
room_settings: StdRwLock<HashMap<OwnedRoomId, RoomSettings>>,
save_changes_lock: Arc<Mutex<()>>,
}
impl Default for MemoryStore {
fn default() -> Self {
MemoryStore {
static_account: Default::default(),
account: Default::default(),
sessions: Default::default(),
inbound_group_sessions: GroupSessionStore::new(),
inbound_group_sessions: Default::default(),
inbound_group_sessions_backed_up_to: Default::default(),
outbound_group_sessions: Default::default(),
private_identity: Default::default(),
@@ -121,6 +129,7 @@ impl Default for MemoryStore {
secret_inbox: Default::default(),
next_batch_token: Default::default(),
room_settings: Default::default(),
save_changes_lock: Default::default(),
}
}
}
@@ -131,6 +140,10 @@ impl MemoryStore {
Self::default()
}
fn get_static_account(&self) -> Option<StaticAccountData> {
self.static_account.read().clone()
}
pub(crate) fn save_devices(&self, devices: Vec<DeviceData>) {
for device in devices {
let _ = self.devices.add(device);
@@ -143,17 +156,17 @@ impl MemoryStore {
}
}
fn save_sessions(&self, sessions: Vec<Session>) {
fn save_sessions(&self, sessions: Vec<(String, PickledSession)>) {
let mut session_store = self.sessions.write();
for session in sessions {
let entry = session_store.entry(session.sender_key().to_base64()).or_default();
for (session_id, pickle) in sessions {
let entry = session_store.entry(pickle.sender_key.to_base64()).or_default();
if let Some(old_entry) = entry.iter_mut().find(|entry| &session == *entry) {
*old_entry = session;
} else {
entry.push(session);
}
// insert or replace if exists
entry.insert(
session_id,
serde_json::to_string(&pickle).expect("Failed to serialize olm session"),
);
}
}
@@ -201,7 +214,21 @@ impl CryptoStore for MemoryStore {
type Error = Infallible;
async fn load_account(&self) -> Result<Option<Account>> {
Ok(self.account.read().as_ref().map(|acc| acc.deep_clone()))
let pickled_account: Option<PickledAccount> = self.account.read().as_ref().map(|acc| {
serde_json::from_str(acc)
.expect("Deserialization failed: invalid pickled account JSON format")
});
if let Some(pickle) = pickled_account {
let account =
Account::from_pickle(pickle).expect("From pickle failed: invalid pickle format");
*self.static_account.write() = Some(account.static_data().clone());
Ok(Some(account))
} else {
Ok(None)
}
}
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
@@ -213,15 +240,34 @@ impl CryptoStore for MemoryStore {
}
async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
if let Some(account) = changes.account {
*self.account.write() = Some(account);
}
let _guard = self.save_changes_lock.lock().await;
let pickled_account = if let Some(account) = changes.account {
*self.static_account.write() = Some(account.static_data().clone());
Some(account.pickle())
} else {
None
};
*self.account.write() = pickled_account.map(|pickle| {
serde_json::to_string(&pickle)
.expect("Serialization failed: invalid pickled account JSON format")
});
Ok(())
}
async fn save_changes(&self, changes: Changes) -> Result<()> {
self.save_sessions(changes.sessions);
let _guard = self.save_changes_lock.lock().await;
let mut pickled_session: Vec<(String, PickledSession)> = Vec::new();
for session in changes.sessions {
let session_id = session.session_id().to_owned();
let pickle = session.pickle().await;
pickled_session.push((session_id.clone(), pickle));
}
self.save_sessions(pickled_session);
self.save_inbound_group_sessions(changes.inbound_group_sessions, None).await?;
self.save_outbound_group_sessions(changes.outbound_group_sessions);
self.save_private_identity(changes.private_identity);
@@ -235,7 +281,8 @@ impl CryptoStore for MemoryStore {
for identity in changes.identities.new.into_iter().chain(changes.identities.changed) {
identities.insert(
identity.user_id().to_owned(),
serde_json::to_string(&identity).unwrap(),
serde_json::to_string(&identity)
.expect("UserIdentityData should always serialize to json"),
);
}
}
@@ -309,9 +356,6 @@ impl CryptoStore for MemoryStore {
sessions: Vec<InboundGroupSession>,
backed_up_to_version: Option<&str>,
) -> Result<()> {
let mut inbound_group_sessions_backed_up_to =
self.inbound_group_sessions_backed_up_to.write();
for session in sessions {
let room_id = session.room_id();
let session_id = session.session_id();
@@ -327,18 +371,43 @@ impl CryptoStore for MemoryStore {
}
if let Some(backup_version) = backed_up_to_version {
inbound_group_sessions_backed_up_to
self.inbound_group_sessions_backed_up_to
.write()
.entry(room_id.to_owned())
.or_default()
.insert(session_id.to_owned(), BackupVersion::from(backup_version));
}
self.inbound_group_sessions.add(session);
let pickle = session.pickle().await;
self.inbound_group_sessions
.write()
.entry(session.room_id().to_owned())
.or_default()
.insert(
session.session_id().to_owned(),
serde_json::to_string(&pickle)
.expect("Pickle pickle data should serialize to json"),
);
}
Ok(())
}
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
Ok(self.sessions.read().get(sender_key).cloned())
let device_keys = self.get_own_device().await?.as_device_keys().clone();
if let Some(pickles) = self.sessions.read().get(sender_key) {
let mut sessions: Vec<Session> = Vec::new();
for serialized_pickle in pickles.values() {
let pickle: PickledSession = serde_json::from_str(serialized_pickle.as_str())
.expect("Pickle pickle deserialization should work");
let session = Session::from_pickle(device_keys.clone(), pickle)
.expect("Expect from pickle to always work");
sessions.push(session);
}
Ok(Some(sessions))
} else {
Ok(None)
}
}
async fn get_inbound_group_session(
@@ -346,7 +415,18 @@ impl CryptoStore for MemoryStore {
room_id: &RoomId,
session_id: &str,
) -> Result<Option<InboundGroupSession>> {
Ok(self.inbound_group_sessions.get(room_id, session_id))
let pickle: Option<PickledInboundGroupSession> = self
.inbound_group_sessions
.read()
.get(room_id)
.and_then(|m| m.get(session_id))
.and_then(|ser| {
serde_json::from_str(ser).expect("Pickle pickle deserialization should work")
});
Ok(pickle.map(|p| {
InboundGroupSession::from_pickle(p).expect("Expect from pickle to always work")
}))
}
async fn get_withheld_info(
@@ -362,7 +442,18 @@ impl CryptoStore for MemoryStore {
}
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
Ok(self.inbound_group_sessions.get_all())
let inbounds = self
.inbound_group_sessions
.read()
.values()
.flat_map(HashMap::values)
.map(|ser| {
let pickle: PickledInboundGroupSession =
serde_json::from_str(ser).expect("Pickle deserialization should work");
InboundGroupSession::from_pickle(pickle).expect("Expect from pickle to always work")
})
.collect();
Ok(inbounds)
}
async fn inbound_group_session_counts(
@@ -383,7 +474,8 @@ impl CryptoStore for MemoryStore {
0
};
Ok(RoomKeyCounts { total: self.inbound_group_sessions.count(), backed_up })
let total = self.inbound_group_sessions.read().values().map(HashMap::len).sum();
Ok(RoomKeyCounts { total, backed_up })
}
async fn get_inbound_group_sessions_for_device_batch(
@@ -455,21 +547,26 @@ impl CryptoStore for MemoryStore {
backup_version: &str,
room_and_session_ids: &[(&RoomId, &str)],
) -> Result<()> {
let mut inbound_group_sessions_backed_up_to =
self.inbound_group_sessions_backed_up_to.write();
for &(room_id, session_id) in room_and_session_ids {
let session = self.inbound_group_sessions.get(room_id, session_id);
let session = self.get_inbound_group_session(room_id, session_id).await?;
if let Some(session) = session {
session.mark_as_backed_up();
inbound_group_sessions_backed_up_to
self.inbound_group_sessions_backed_up_to
.write()
.entry(room_id.to_owned())
.or_default()
.insert(session_id.to_owned(), BackupVersion::from(backup_version));
self.inbound_group_sessions.add(session);
// Save it back
let updated_pickle = session.pickle().await;
self.inbound_group_sessions.write().entry(room_id.to_owned()).or_default().insert(
session_id.to_owned(),
serde_json::to_string(&updated_pickle)
.expect("Pickle serialization should work"),
);
}
}
@@ -535,9 +632,13 @@ impl CryptoStore for MemoryStore {
}
async fn get_own_device(&self) -> Result<DeviceData> {
let account = self.load_account().await?.unwrap();
let account =
self.get_static_account().expect("Expect account to exist when getting own device");
Ok(self.devices.get(&account.user_id, &account.device_id).unwrap())
Ok(self
.devices
.get(&account.user_id, &account.device_id)
.expect("Invalid state: Should always have a own device"))
}
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
@@ -545,7 +646,8 @@ impl CryptoStore for MemoryStore {
match serialized {
None => Ok(None),
Some(serialized) => {
let id: UserIdentityData = serde_json::from_str(serialized.as_str()).unwrap();
let id: UserIdentityData = serde_json::from_str(serialized.as_str())
.expect("Only valid serialized identity are saved");
Ok(Some(id))
}
}
@@ -656,18 +758,31 @@ mod tests {
tests::get_account_and_session_test_helper, Account, InboundGroupSession,
OlmMessageHash, PrivateCrossSigningIdentity, SenderData,
},
store::{memorystore::MemoryStore, Changes, CryptoStore, PendingChanges},
store::{memorystore::MemoryStore, Changes, CryptoStore, DeviceChanges, PendingChanges},
DeviceData,
};
#[async_test]
async fn test_session_store() {
let (account, session) = get_account_and_session_test_helper();
let own_device = DeviceData::from_account(&account);
let store = MemoryStore::new();
assert!(store.load_account().await.unwrap().is_none());
store
.save_changes(Changes {
devices: DeviceChanges { new: vec![own_device], ..Default::default() },
..Default::default()
})
.await
.unwrap();
store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
store.save_sessions(vec![session.clone()]);
store
.save_changes(Changes { sessions: (vec![session.clone()]), ..Default::default() })
.await
.unwrap();
let sessions = store.get_sessions(&session.sender_key.to_base64()).await.unwrap().unwrap();
@@ -1154,12 +1269,6 @@ mod integration_tests {
}
}
impl MemoryStore {
fn get_static_account(&self) -> Option<StaticAccountData> {
self.account.read().as_ref().map(|acc| acc.static_data().clone())
}
}
/// Return a clone of the store for the test with the supplied name. Note:
/// dropping this store won't destroy its data, since
/// [PersistentMemoryStore] is a reference-counted smart pointer