mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-05-07 23:44:53 -04:00
feat(withheld) Remember no_olm and implement in memory store
This commit is contained in:
@@ -635,19 +635,26 @@ impl GroupSessionManager {
|
||||
})
|
||||
.collect();
|
||||
|
||||
no_olm.iter().for_each(|(u, list)| {
|
||||
list.into_iter().for_each(|d| {
|
||||
flat_with_code.push((u.to_owned(), d.to_owned(), WithheldCode::NoOlm))
|
||||
});
|
||||
});
|
||||
for (u, list_devices) in no_olm {
|
||||
for dev in list_devices {
|
||||
let no_olm_sent =
|
||||
self.store.is_no_olm_sent(u.to_owned(), dev.to_owned()).await.unwrap_or(false);
|
||||
if !no_olm_sent {
|
||||
flat_with_code.push((u.to_owned(), dev.to_owned(), WithheldCode::NoOlm));
|
||||
changes
|
||||
.no_olm_sent
|
||||
.entry(u.to_owned())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(dev.to_owned());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let withheld_to_device_request: Vec<Arc<ToDeviceRequest>> = flat_with_code
|
||||
.chunks(Self::MAX_TO_DEVICE_MESSAGES)
|
||||
.map(|chunk| {
|
||||
let mut messages = BTreeMap::new();
|
||||
|
||||
// TODO do not resend code if already sent for that session (or target if
|
||||
// no_olm)
|
||||
chunk.into_iter().for_each(|(user_id, device_id, code)| {
|
||||
let content = RoomKeyWithheldContent::create(
|
||||
algorithm.to_owned(),
|
||||
@@ -947,6 +954,77 @@ mod tests {
|
||||
.map(|r| r.message_count())
|
||||
.sum();
|
||||
assert_eq!(withheld_count, 2);
|
||||
|
||||
requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).for_each(|r| {
|
||||
println!("{:?}", r.messages);
|
||||
})
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn test_no_olm_sent_once() {
|
||||
let machine = machine().await;
|
||||
let keys_claim = keys_claim_response();
|
||||
|
||||
let users = keys_claim.one_time_keys.keys().map(Deref::deref);
|
||||
|
||||
let first_room_id = room_id!("!test:localhost");
|
||||
let requests = machine
|
||||
.share_room_key(first_room_id, users, EncryptionSettings::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// there will be two no_olm
|
||||
let withheld_count: usize = requests
|
||||
.iter()
|
||||
.filter(|r| r.event_type == "m.room_key.withheld".into())
|
||||
.map(|r| {
|
||||
let mut count = 0;
|
||||
// count targets
|
||||
for (_, message) in &r.messages {
|
||||
message.iter().for_each(|(_, content)| {
|
||||
let withheld: MegolmV1AesSha2WithheldContent =
|
||||
content.deserialize_as::<MegolmV1AesSha2WithheldContent>().unwrap();
|
||||
if withheld.code == WithheldCode::NoOlm {
|
||||
count += 1;
|
||||
}
|
||||
})
|
||||
}
|
||||
count
|
||||
})
|
||||
.sum();
|
||||
assert_eq!(withheld_count, 2);
|
||||
|
||||
// Use another session in an other room
|
||||
let second_room_id = room_id!("!other:localhost");
|
||||
let users = keys_claim.one_time_keys.keys().map(Deref::deref);
|
||||
let requests = machine
|
||||
.share_room_key(second_room_id, users, EncryptionSettings::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// there will be two no_olm
|
||||
let withheld_count: usize = requests
|
||||
.iter()
|
||||
.filter(|r| r.event_type == "m.room_key.withheld".into())
|
||||
.map(|r| {
|
||||
let mut count = 0;
|
||||
// count targets
|
||||
for (_, message) in &r.messages {
|
||||
message.iter().for_each(|(_, content)| {
|
||||
let withheld: MegolmV1AesSha2WithheldContent =
|
||||
content.deserialize_as::<MegolmV1AesSha2WithheldContent>().unwrap();
|
||||
if withheld.code == WithheldCode::NoOlm {
|
||||
count += 1;
|
||||
}
|
||||
})
|
||||
}
|
||||
count
|
||||
})
|
||||
.sum();
|
||||
assert_eq!(withheld_count, 0);
|
||||
|
||||
// Help how do I simulate the creation of a new session for the device
|
||||
// with no session now?
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
@@ -1168,6 +1246,6 @@ mod tests {
|
||||
withheld.code == WithheldCode::Blacklisted
|
||||
});
|
||||
|
||||
assert!(has_blacklist)
|
||||
assert!(has_blacklist);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::HashMap, convert::Infallible, sync::Arc};
|
||||
use std::{collections::HashMap, convert::Infallible, ops::Deref, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::{DashMap, DashSet};
|
||||
@@ -51,6 +51,7 @@ pub struct MemoryStore {
|
||||
identities: Arc<DashMap<OwnedUserId, ReadOnlyUserIdentities>>,
|
||||
outgoing_key_requests: Arc<DashMap<OwnedTransactionId, GossipRequest>>,
|
||||
key_requests_by_info: Arc<DashMap<String, OwnedTransactionId>>,
|
||||
no_olm_sent: Arc<DashMap<OwnedUserId, DashSet<OwnedDeviceId>>>,
|
||||
}
|
||||
|
||||
impl Default for MemoryStore {
|
||||
@@ -63,6 +64,7 @@ impl Default for MemoryStore {
|
||||
identities: Default::default(),
|
||||
outgoing_key_requests: Default::default(),
|
||||
key_requests_by_info: Default::default(),
|
||||
no_olm_sent: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -88,6 +90,11 @@ impl MemoryStore {
|
||||
async fn save_sessions(&self, sessions: Vec<Session>) {
|
||||
for session in sessions {
|
||||
let _ = self.sessions.add(session.clone()).await;
|
||||
// we have a new session for that device so we can forgot no_olm
|
||||
self.no_olm_sent
|
||||
.entry(session.user_id.deref().to_owned())
|
||||
.or_default()
|
||||
.remove(session.device_id.deref());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,6 +151,13 @@ impl CryptoStore for MemoryStore {
|
||||
self.key_requests_by_info.insert(info_string, id);
|
||||
}
|
||||
|
||||
for (user_id, device_id_list) in changes.no_olm_sent {
|
||||
self.no_olm_sent
|
||||
.entry(user_id)
|
||||
.or_insert_with(DashSet::new)
|
||||
.extend(device_id_list.into_iter());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -270,12 +284,27 @@ impl CryptoStore for MemoryStore {
|
||||
async fn load_backup_keys(&self) -> Result<BackupKeys> {
|
||||
Ok(BackupKeys::default())
|
||||
}
|
||||
|
||||
async fn is_no_olm_sent(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result<bool> {
|
||||
Ok(self.no_olm_sent.entry(user_id).or_default().contains(&device_id))
|
||||
}
|
||||
|
||||
/* async fn forget_no_olm_sent(
|
||||
&self,
|
||||
user_id: OwnedUserId,
|
||||
device_id: OwnedDeviceId,
|
||||
) -> Result<()> {
|
||||
self.no_olm_sent.entry(user_id).or_default().remove(&device_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
*/
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use matrix_sdk_test::async_test;
|
||||
use ruma::room_id;
|
||||
use ruma::{device_id, room_id, user_id};
|
||||
use vodozemac::{Curve25519PublicKey, Ed25519PublicKey};
|
||||
|
||||
use crate::{
|
||||
@@ -366,4 +395,32 @@ mod tests {
|
||||
store.save_changes(changes).await.unwrap();
|
||||
assert!(store.is_message_known(&hash).await.unwrap());
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn test_no_olm_withheld_store() {
|
||||
let store = MemoryStore::new();
|
||||
|
||||
let user_id = user_id!("@alice:example.com");
|
||||
let device_id_1 = device_id!("DEV0001");
|
||||
let device_id_2 = device_id!("DEV0002");
|
||||
|
||||
let mut changes = Changes::default();
|
||||
changes
|
||||
.no_olm_sent
|
||||
.entry(user_id.to_owned())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(device_id_1.to_owned());
|
||||
|
||||
store.save_changes(changes).await.unwrap();
|
||||
|
||||
let is_withheld_for =
|
||||
store.is_no_olm_sent(user_id.to_owned(), device_id_1.to_owned()).await.unwrap();
|
||||
|
||||
assert_eq!(true, is_withheld_for);
|
||||
|
||||
let is_withheld_for =
|
||||
store.is_no_olm_sent(user_id.to_owned(), device_id_2.to_owned()).await.unwrap();
|
||||
|
||||
assert_eq!(false, is_withheld_for);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@
|
||||
//! [`CryptoStore`]: trait.Cryptostore.html
|
||||
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
collections::{BTreeMap, HashMap, HashSet},
|
||||
fmt::Debug,
|
||||
ops::Deref,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
@@ -117,6 +117,7 @@ pub struct Changes {
|
||||
pub key_requests: Vec<GossipRequest>,
|
||||
pub identities: IdentityChanges,
|
||||
pub devices: DeviceChanges,
|
||||
pub no_olm_sent: BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>,
|
||||
}
|
||||
|
||||
/// A user for which we are tracking the list of devices.
|
||||
@@ -143,6 +144,7 @@ impl Changes {
|
||||
&& self.key_requests.is_empty()
|
||||
&& self.identities.is_empty()
|
||||
&& self.devices.is_empty()
|
||||
&& self.no_olm_sent.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ use std::{collections::HashMap, fmt, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use matrix_sdk_common::{locks::Mutex, AsyncTraitDeps};
|
||||
use ruma::{DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId};
|
||||
use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, RoomId, TransactionId, UserId};
|
||||
|
||||
use super::{BackupKeys, Changes, CryptoStoreError, Result, RoomKeyCounts};
|
||||
use crate::{
|
||||
@@ -186,6 +186,11 @@ pub trait CryptoStore: AsyncTraitDeps {
|
||||
&self,
|
||||
request_id: &TransactionId,
|
||||
) -> Result<(), Self::Error>;
|
||||
/// Remembers when a no_olm withheld message was already sent to avoid doing
|
||||
/// it again.
|
||||
/// Help: Using user/device as key here because it's sent in clear, but
|
||||
/// should we use the sender_key?
|
||||
async fn is_no_olm_sent(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result<bool>;
|
||||
}
|
||||
|
||||
#[repr(transparent)]
|
||||
|
||||
@@ -388,7 +388,6 @@ pub(super) mod test {
|
||||
#[test]
|
||||
fn no_olm_should_not_have_room_and_session() {
|
||||
let room_id = room_id!("!DwLygpkclUAfQNnfva:localhost:8481");
|
||||
let user_id = user_id!("@alice:example.org");
|
||||
let device_id = device_id!("DEV001");
|
||||
let sender_key =
|
||||
Curve25519PublicKey::from_base64("9n7mdWKOjr9c4NTlG6zV8dbFtNK79q9vZADoh7nMUwA")
|
||||
|
||||
@@ -35,7 +35,7 @@ use matrix_sdk_crypto::{
|
||||
TrackedUser,
|
||||
};
|
||||
use matrix_sdk_store_encryption::StoreCipher;
|
||||
use ruma::{DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId};
|
||||
use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, RoomId, TransactionId, UserId};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use sled::{
|
||||
transaction::{ConflictableTransactionError, TransactionError},
|
||||
@@ -58,6 +58,7 @@ const INBOUND_GROUP_TABLE_NAME: &str = "crypto-store-inbound-group-sessions";
|
||||
const OUTBOUND_GROUP_TABLE_NAME: &str = "crypto-store-outbound-group-sessions";
|
||||
const SECRET_REQUEST_BY_INFO_TABLE: &str = "crypto-store-secret-request-by-info";
|
||||
const TRACKED_USERS_TABLE: &str = "crypto-store-secret-tracked-users";
|
||||
const NO_OLM_SENT_TABLE: &str = "crypto-store-no-olm-sent";
|
||||
|
||||
impl EncodeKey for InboundGroupSession {
|
||||
fn encode(&self) -> Vec<u8> {
|
||||
@@ -185,6 +186,8 @@ pub struct SledCryptoStore {
|
||||
identities: Tree,
|
||||
|
||||
tracked_users: Tree,
|
||||
|
||||
no_olm_sent: Tree,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SledCryptoStore {
|
||||
@@ -388,6 +391,8 @@ impl SledCryptoStore {
|
||||
|
||||
let session_cache = SessionStore::new();
|
||||
|
||||
let no_olm_sent = db.open_tree("no_olm_sent")?;
|
||||
|
||||
let database = Self {
|
||||
account_info: RwLock::new(None).into(),
|
||||
path,
|
||||
@@ -406,6 +411,7 @@ impl SledCryptoStore {
|
||||
tracked_users,
|
||||
olm_hashes,
|
||||
identities,
|
||||
no_olm_sent,
|
||||
};
|
||||
|
||||
database.upgrade().await?;
|
||||
@@ -1010,6 +1016,11 @@ impl CryptoStore for SledCryptoStore {
|
||||
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
async fn is_no_olm_sent(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result<bool> {
|
||||
// TODO, help not sure how to to that?
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -946,6 +946,15 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
let request_id = self.encode_key("key_requests", request_id.as_bytes());
|
||||
Ok(self.acquire().await?.delete_key_request(request_id).await?)
|
||||
}
|
||||
|
||||
async fn is_no_olm_sent(
|
||||
&self,
|
||||
user_id: OwnedUserId,
|
||||
device_id: OwnedDeviceId,
|
||||
) -> StoreResult<bool> {
|
||||
// TODO help I need some guidance here?
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
Reference in New Issue
Block a user