feat(withheld) Remember no_olm and implement in memory store

This commit is contained in:
valere
2023-02-16 17:25:08 +01:00
parent 99232f2340
commit 0154cccd4b
7 changed files with 175 additions and 14 deletions

View File

@@ -635,19 +635,26 @@ impl GroupSessionManager {
})
.collect();
no_olm.iter().for_each(|(u, list)| {
list.into_iter().for_each(|d| {
flat_with_code.push((u.to_owned(), d.to_owned(), WithheldCode::NoOlm))
});
});
for (u, list_devices) in no_olm {
for dev in list_devices {
let no_olm_sent =
self.store.is_no_olm_sent(u.to_owned(), dev.to_owned()).await.unwrap_or(false);
if !no_olm_sent {
flat_with_code.push((u.to_owned(), dev.to_owned(), WithheldCode::NoOlm));
changes
.no_olm_sent
.entry(u.to_owned())
.or_insert_with(Vec::new)
.push(dev.to_owned());
}
}
}
let withheld_to_device_request: Vec<Arc<ToDeviceRequest>> = flat_with_code
.chunks(Self::MAX_TO_DEVICE_MESSAGES)
.map(|chunk| {
let mut messages = BTreeMap::new();
// TODO do not resend code if already sent for that session (or target if
// no_olm)
chunk.into_iter().for_each(|(user_id, device_id, code)| {
let content = RoomKeyWithheldContent::create(
algorithm.to_owned(),
@@ -947,6 +954,77 @@ mod tests {
.map(|r| r.message_count())
.sum();
assert_eq!(withheld_count, 2);
requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).for_each(|r| {
println!("{:?}", r.messages);
})
}
#[async_test]
async fn test_no_olm_sent_once() {
let machine = machine().await;
let keys_claim = keys_claim_response();
let users = keys_claim.one_time_keys.keys().map(Deref::deref);
let first_room_id = room_id!("!test:localhost");
let requests = machine
.share_room_key(first_room_id, users, EncryptionSettings::default())
.await
.unwrap();
// there will be two no_olm
let withheld_count: usize = requests
.iter()
.filter(|r| r.event_type == "m.room_key.withheld".into())
.map(|r| {
let mut count = 0;
// count targets
for (_, message) in &r.messages {
message.iter().for_each(|(_, content)| {
let withheld: MegolmV1AesSha2WithheldContent =
content.deserialize_as::<MegolmV1AesSha2WithheldContent>().unwrap();
if withheld.code == WithheldCode::NoOlm {
count += 1;
}
})
}
count
})
.sum();
assert_eq!(withheld_count, 2);
// Use another session in an other room
let second_room_id = room_id!("!other:localhost");
let users = keys_claim.one_time_keys.keys().map(Deref::deref);
let requests = machine
.share_room_key(second_room_id, users, EncryptionSettings::default())
.await
.unwrap();
// there will be two no_olm
let withheld_count: usize = requests
.iter()
.filter(|r| r.event_type == "m.room_key.withheld".into())
.map(|r| {
let mut count = 0;
// count targets
for (_, message) in &r.messages {
message.iter().for_each(|(_, content)| {
let withheld: MegolmV1AesSha2WithheldContent =
content.deserialize_as::<MegolmV1AesSha2WithheldContent>().unwrap();
if withheld.code == WithheldCode::NoOlm {
count += 1;
}
})
}
count
})
.sum();
assert_eq!(withheld_count, 0);
// Help how do I simulate the creation of a new session for the device
// with no session now?
}
#[async_test]
@@ -1168,6 +1246,6 @@ mod tests {
withheld.code == WithheldCode::Blacklisted
});
assert!(has_blacklist)
assert!(has_blacklist);
}
}

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashMap, convert::Infallible, sync::Arc};
use std::{collections::HashMap, convert::Infallible, ops::Deref, sync::Arc};
use async_trait::async_trait;
use dashmap::{DashMap, DashSet};
@@ -51,6 +51,7 @@ pub struct MemoryStore {
identities: Arc<DashMap<OwnedUserId, ReadOnlyUserIdentities>>,
outgoing_key_requests: Arc<DashMap<OwnedTransactionId, GossipRequest>>,
key_requests_by_info: Arc<DashMap<String, OwnedTransactionId>>,
no_olm_sent: Arc<DashMap<OwnedUserId, DashSet<OwnedDeviceId>>>,
}
impl Default for MemoryStore {
@@ -63,6 +64,7 @@ impl Default for MemoryStore {
identities: Default::default(),
outgoing_key_requests: Default::default(),
key_requests_by_info: Default::default(),
no_olm_sent: Default::default(),
}
}
}
@@ -88,6 +90,11 @@ impl MemoryStore {
async fn save_sessions(&self, sessions: Vec<Session>) {
for session in sessions {
let _ = self.sessions.add(session.clone()).await;
// we have a new session for that device so we can forgot no_olm
self.no_olm_sent
.entry(session.user_id.deref().to_owned())
.or_default()
.remove(session.device_id.deref());
}
}
@@ -144,6 +151,13 @@ impl CryptoStore for MemoryStore {
self.key_requests_by_info.insert(info_string, id);
}
for (user_id, device_id_list) in changes.no_olm_sent {
self.no_olm_sent
.entry(user_id)
.or_insert_with(DashSet::new)
.extend(device_id_list.into_iter());
}
Ok(())
}
@@ -270,12 +284,27 @@ impl CryptoStore for MemoryStore {
async fn load_backup_keys(&self) -> Result<BackupKeys> {
Ok(BackupKeys::default())
}
async fn is_no_olm_sent(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result<bool> {
Ok(self.no_olm_sent.entry(user_id).or_default().contains(&device_id))
}
/* async fn forget_no_olm_sent(
&self,
user_id: OwnedUserId,
device_id: OwnedDeviceId,
) -> Result<()> {
self.no_olm_sent.entry(user_id).or_default().remove(&device_id);
Ok(())
}
*/
}
#[cfg(test)]
mod tests {
use matrix_sdk_test::async_test;
use ruma::room_id;
use ruma::{device_id, room_id, user_id};
use vodozemac::{Curve25519PublicKey, Ed25519PublicKey};
use crate::{
@@ -366,4 +395,32 @@ mod tests {
store.save_changes(changes).await.unwrap();
assert!(store.is_message_known(&hash).await.unwrap());
}
#[async_test]
async fn test_no_olm_withheld_store() {
let store = MemoryStore::new();
let user_id = user_id!("@alice:example.com");
let device_id_1 = device_id!("DEV0001");
let device_id_2 = device_id!("DEV0002");
let mut changes = Changes::default();
changes
.no_olm_sent
.entry(user_id.to_owned())
.or_insert_with(Vec::new)
.push(device_id_1.to_owned());
store.save_changes(changes).await.unwrap();
let is_withheld_for =
store.is_no_olm_sent(user_id.to_owned(), device_id_1.to_owned()).await.unwrap();
assert_eq!(true, is_withheld_for);
let is_withheld_for =
store.is_no_olm_sent(user_id.to_owned(), device_id_2.to_owned()).await.unwrap();
assert_eq!(false, is_withheld_for);
}
}

View File

@@ -39,7 +39,7 @@
//! [`CryptoStore`]: trait.Cryptostore.html
use std::{
collections::{HashMap, HashSet},
collections::{BTreeMap, HashMap, HashSet},
fmt::Debug,
ops::Deref,
sync::{atomic::AtomicBool, Arc},
@@ -117,6 +117,7 @@ pub struct Changes {
pub key_requests: Vec<GossipRequest>,
pub identities: IdentityChanges,
pub devices: DeviceChanges,
pub no_olm_sent: BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>,
}
/// A user for which we are tracking the list of devices.
@@ -143,6 +144,7 @@ impl Changes {
&& self.key_requests.is_empty()
&& self.identities.is_empty()
&& self.devices.is_empty()
&& self.no_olm_sent.is_empty()
}
}

View File

@@ -16,7 +16,7 @@ use std::{collections::HashMap, fmt, sync::Arc};
use async_trait::async_trait;
use matrix_sdk_common::{locks::Mutex, AsyncTraitDeps};
use ruma::{DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId};
use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, RoomId, TransactionId, UserId};
use super::{BackupKeys, Changes, CryptoStoreError, Result, RoomKeyCounts};
use crate::{
@@ -186,6 +186,11 @@ pub trait CryptoStore: AsyncTraitDeps {
&self,
request_id: &TransactionId,
) -> Result<(), Self::Error>;
/// Remembers when a no_olm withheld message was already sent to avoid doing
/// it again.
/// Help: Using user/device as key here because it's sent in clear, but
/// should we use the sender_key?
async fn is_no_olm_sent(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result<bool>;
}
#[repr(transparent)]

View File

@@ -388,7 +388,6 @@ pub(super) mod test {
#[test]
fn no_olm_should_not_have_room_and_session() {
let room_id = room_id!("!DwLygpkclUAfQNnfva:localhost:8481");
let user_id = user_id!("@alice:example.org");
let device_id = device_id!("DEV001");
let sender_key =
Curve25519PublicKey::from_base64("9n7mdWKOjr9c4NTlG6zV8dbFtNK79q9vZADoh7nMUwA")

View File

@@ -35,7 +35,7 @@ use matrix_sdk_crypto::{
TrackedUser,
};
use matrix_sdk_store_encryption::StoreCipher;
use ruma::{DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId};
use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, RoomId, TransactionId, UserId};
use serde::{de::DeserializeOwned, Serialize};
use sled::{
transaction::{ConflictableTransactionError, TransactionError},
@@ -58,6 +58,7 @@ const INBOUND_GROUP_TABLE_NAME: &str = "crypto-store-inbound-group-sessions";
const OUTBOUND_GROUP_TABLE_NAME: &str = "crypto-store-outbound-group-sessions";
const SECRET_REQUEST_BY_INFO_TABLE: &str = "crypto-store-secret-request-by-info";
const TRACKED_USERS_TABLE: &str = "crypto-store-secret-tracked-users";
const NO_OLM_SENT_TABLE: &str = "crypto-store-no-olm-sent";
impl EncodeKey for InboundGroupSession {
fn encode(&self) -> Vec<u8> {
@@ -185,6 +186,8 @@ pub struct SledCryptoStore {
identities: Tree,
tracked_users: Tree,
no_olm_sent: Tree,
}
impl std::fmt::Debug for SledCryptoStore {
@@ -388,6 +391,8 @@ impl SledCryptoStore {
let session_cache = SessionStore::new();
let no_olm_sent = db.open_tree("no_olm_sent")?;
let database = Self {
account_info: RwLock::new(None).into(),
path,
@@ -406,6 +411,7 @@ impl SledCryptoStore {
tracked_users,
olm_hashes,
identities,
no_olm_sent,
};
database.upgrade().await?;
@@ -1010,6 +1016,11 @@ impl CryptoStore for SledCryptoStore {
Ok(key)
}
async fn is_no_olm_sent(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result<bool> {
// TODO, help not sure how to to that?
Ok(false)
}
}
#[cfg(test)]

View File

@@ -946,6 +946,15 @@ impl CryptoStore for SqliteCryptoStore {
let request_id = self.encode_key("key_requests", request_id.as_bytes());
Ok(self.acquire().await?.delete_key_request(request_id).await?)
}
async fn is_no_olm_sent(
&self,
user_id: OwnedUserId,
device_id: OwnedDeviceId,
) -> StoreResult<bool> {
// TODO help I need some guidance here?
Ok(false)
}
}
#[cfg(test)]