mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-05-06 15:04:11 -04:00
fix: Use the DisplayName struct to protect against homoglyph attacks
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
#[cfg(feature = "e2e-encryption")]
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
collections::{BTreeMap, BTreeSet, HashMap},
|
||||
fmt, iter,
|
||||
ops::Deref,
|
||||
};
|
||||
@@ -68,7 +68,7 @@ use crate::latest_event::{is_suitable_for_latest_event, LatestEvent, PossibleLat
|
||||
#[cfg(feature = "e2e-encryption")]
|
||||
use crate::RoomMemberships;
|
||||
use crate::{
|
||||
deserialized_responses::{RawAnySyncOrStrippedTimelineEvent, SyncTimelineEvent},
|
||||
deserialized_responses::{DisplayName, RawAnySyncOrStrippedTimelineEvent, SyncTimelineEvent},
|
||||
error::{Error, Result},
|
||||
event_cache::store::EventCacheStoreLock,
|
||||
response_processors::AccountDataProcessor,
|
||||
@@ -1332,7 +1332,7 @@ impl BaseClient {
|
||||
#[cfg(feature = "e2e-encryption")]
|
||||
let mut user_ids = BTreeSet::new();
|
||||
|
||||
let mut ambiguity_map: BTreeMap<String, BTreeSet<OwnedUserId>> = BTreeMap::new();
|
||||
let mut ambiguity_map: HashMap<DisplayName, BTreeSet<OwnedUserId>> = Default::default();
|
||||
|
||||
for raw_event in &response.chunk {
|
||||
let member = match raw_event.deserialize() {
|
||||
@@ -1363,7 +1363,11 @@ impl BaseClient {
|
||||
|
||||
if let StateEvent::Original(e) = &member {
|
||||
if let Some(d) = &e.content.displayname {
|
||||
ambiguity_map.entry(d.clone()).or_default().insert(member.state_key().clone());
|
||||
let display_name = DisplayName::new(d);
|
||||
ambiguity_map
|
||||
.entry(display_name)
|
||||
.or_default()
|
||||
.insert(member.state_key().clone());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -469,10 +469,12 @@ impl MemberEvent {
|
||||
///
|
||||
/// It there is no `displayname` in the event's content, the localpart or
|
||||
/// the user ID is returned.
|
||||
pub fn display_name(&self) -> &str {
|
||||
self.original_content()
|
||||
.and_then(|c| c.displayname.as_deref())
|
||||
.unwrap_or_else(|| self.user_id().localpart())
|
||||
pub fn display_name(&self) -> DisplayName {
|
||||
DisplayName::new(
|
||||
self.original_content()
|
||||
.and_then(|c| c.displayname.as_deref())
|
||||
.unwrap_or_else(|| self.user_id().localpart()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
collections::{BTreeSet, HashMap},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
@@ -30,7 +30,8 @@ use ruma::{
|
||||
};
|
||||
|
||||
use crate::{
|
||||
deserialized_responses::{MemberEvent, SyncOrStrippedState},
|
||||
deserialized_responses::{DisplayName, MemberEvent, SyncOrStrippedState},
|
||||
store::ambiguity_map::is_display_name_ambiguous,
|
||||
MinimalRoomMemberEvent,
|
||||
};
|
||||
|
||||
@@ -67,8 +68,10 @@ impl RoomMember {
|
||||
} = room_info;
|
||||
|
||||
let is_room_creator = room_creator.as_deref() == Some(event.user_id());
|
||||
let display_name_ambiguous =
|
||||
users_display_names.get(event.display_name()).is_some_and(|s| s.len() > 1);
|
||||
let display_name = event.display_name();
|
||||
let display_name_ambiguous = users_display_names
|
||||
.get(&display_name)
|
||||
.is_some_and(|s| is_display_name_ambiguous(&display_name, s));
|
||||
let is_ignored = ignored_users.as_ref().is_some_and(|s| s.contains(event.user_id()));
|
||||
|
||||
Self {
|
||||
@@ -245,6 +248,6 @@ pub(crate) struct MemberRoomInfo<'a> {
|
||||
pub(crate) power_levels: Arc<Option<SyncOrStrippedState<RoomPowerLevelsEventContent>>>,
|
||||
pub(crate) max_power_level: i64,
|
||||
pub(crate) room_creator: Option<OwnedUserId>,
|
||||
pub(crate) users_display_names: BTreeMap<&'a str, BTreeSet<OwnedUserId>>,
|
||||
pub(crate) users_display_names: HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>,
|
||||
pub(crate) ignored_users: Option<BTreeSet<OwnedUserId>>,
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ use super::{
|
||||
#[cfg(feature = "experimental-sliding-sync")]
|
||||
use crate::latest_event::LatestEvent;
|
||||
use crate::{
|
||||
deserialized_responses::{MemberEvent, RawSyncOrStrippedState},
|
||||
deserialized_responses::{DisplayName, MemberEvent, RawSyncOrStrippedState},
|
||||
notification_settings::RoomNotificationMode,
|
||||
read_receipts::RoomReadReceipts,
|
||||
store::{DynStateStore, Result as StoreResult, StateStoreExt},
|
||||
@@ -819,8 +819,7 @@ impl Room {
|
||||
})
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
|
||||
let display_names =
|
||||
member_events.iter().map(|e| e.display_name().to_owned()).collect::<Vec<_>>();
|
||||
let display_names = member_events.iter().map(|e| e.display_name()).collect::<Vec<_>>();
|
||||
let room_info = self.member_room_info(&display_names).await?;
|
||||
|
||||
let mut members = Vec::new();
|
||||
@@ -900,7 +899,7 @@ impl Room {
|
||||
|
||||
let profile = self.store.get_profile(self.room_id(), user_id).await?;
|
||||
|
||||
let display_names = [event.display_name().to_owned()];
|
||||
let display_names = [event.display_name()];
|
||||
let room_info = self.member_room_info(&display_names).await?;
|
||||
|
||||
Ok(Some(RoomMember::from_parts(event, profile, presence, &room_info)))
|
||||
@@ -911,7 +910,7 @@ impl Room {
|
||||
/// Async because it can read from storage.
|
||||
async fn member_room_info<'a>(
|
||||
&self,
|
||||
display_names: &'a [String],
|
||||
display_names: &'a [DisplayName],
|
||||
) -> StoreResult<MemberRoomInfo<'a>> {
|
||||
let max_power_level = self.max_power_level();
|
||||
let room_creator = self.inner.read().creator().map(ToOwned::to_owned);
|
||||
|
||||
@@ -695,6 +695,10 @@ async fn cache_latest_events(
|
||||
changes: Option<&StateChanges>,
|
||||
store: Option<&Store>,
|
||||
) {
|
||||
use crate::{
|
||||
deserialized_responses::DisplayName, store::ambiguity_map::is_display_name_ambiguous,
|
||||
};
|
||||
|
||||
let mut encrypted_events =
|
||||
Vec::with_capacity(room.latest_encrypted_events.read().unwrap().capacity());
|
||||
|
||||
@@ -752,11 +756,13 @@ async fn cache_latest_events(
|
||||
.as_original()
|
||||
.and_then(|profile| profile.content.displayname.as_ref())
|
||||
.and_then(|display_name| {
|
||||
let display_name = DisplayName::new(display_name);
|
||||
|
||||
changes.ambiguity_maps.get(room.room_id()).and_then(
|
||||
|map_for_room| {
|
||||
map_for_room
|
||||
.get(display_name)
|
||||
.map(|user_ids| user_ids.len() > 1)
|
||||
map_for_room.get(&display_name).map(|users| {
|
||||
is_display_name_ambiguous(&display_name, users)
|
||||
})
|
||||
},
|
||||
)
|
||||
});
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
collections::{BTreeMap, BTreeSet, HashMap},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
@@ -24,18 +24,18 @@ use ruma::{
|
||||
},
|
||||
OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
|
||||
};
|
||||
use tracing::trace;
|
||||
use tracing::{instrument, trace};
|
||||
|
||||
use super::{DynStateStore, Result, StateChanges};
|
||||
use crate::{
|
||||
deserialized_responses::{AmbiguityChange, RawMemberEvent},
|
||||
deserialized_responses::{AmbiguityChange, DisplayName, RawMemberEvent},
|
||||
store::StateStoreExt,
|
||||
};
|
||||
|
||||
/// A map of users that use a certain display name.
|
||||
#[derive(Debug, Clone)]
|
||||
struct DisplayNameUsers {
|
||||
display_name: String,
|
||||
display_name: DisplayName,
|
||||
users: BTreeSet<OwnedUserId>,
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ impl DisplayNameUsers {
|
||||
|
||||
/// Is the display name considered to be ambiguous.
|
||||
fn is_ambiguous(&self) -> bool {
|
||||
self.user_count() > 1
|
||||
is_display_name_ambiguous(&self.display_name, &self.users)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,10 +82,19 @@ fn is_member_active(membership: &MembershipState) -> bool {
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct AmbiguityCache {
|
||||
pub store: Arc<DynStateStore>,
|
||||
pub cache: BTreeMap<OwnedRoomId, BTreeMap<String, BTreeSet<OwnedUserId>>>,
|
||||
pub cache: BTreeMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>,
|
||||
pub changes: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, AmbiguityChange>>,
|
||||
}
|
||||
|
||||
#[instrument(ret)]
|
||||
pub(crate) fn is_display_name_ambiguous(
|
||||
display_name: &DisplayName,
|
||||
users_with_display_name: &BTreeSet<OwnedUserId>,
|
||||
) -> bool {
|
||||
trace!("Checking if a display name is ambiguous");
|
||||
display_name.is_inherently_ambiguous() || users_with_display_name.len() > 1
|
||||
}
|
||||
|
||||
impl AmbiguityCache {
|
||||
/// Create a new [`AmbiguityCache`] backed by the given state store.
|
||||
pub fn new(store: Arc<DynStateStore>) -> Self {
|
||||
@@ -224,18 +233,15 @@ impl AmbiguityCache {
|
||||
async fn get_users_with_display_name(
|
||||
&mut self,
|
||||
room_id: &RoomId,
|
||||
display_name: &str,
|
||||
display_name: &DisplayName,
|
||||
) -> Result<DisplayNameUsers> {
|
||||
Ok(if let Some(u) = self.cache.entry(room_id.to_owned()).or_default().get(display_name) {
|
||||
DisplayNameUsers { display_name: display_name.to_owned(), users: u.clone() }
|
||||
DisplayNameUsers { display_name: display_name.clone(), users: u.clone() }
|
||||
} else {
|
||||
let users_with_display_name =
|
||||
self.store.get_users_with_display_name(room_id, display_name).await?;
|
||||
|
||||
DisplayNameUsers {
|
||||
display_name: display_name.to_owned(),
|
||||
users: users_with_display_name,
|
||||
}
|
||||
DisplayNameUsers { display_name: display_name.clone(), users: users_with_display_name }
|
||||
})
|
||||
}
|
||||
|
||||
@@ -254,7 +260,8 @@ impl AmbiguityCache {
|
||||
let old_display_name = self.get_old_display_name(changes, room_id, member_event).await?;
|
||||
|
||||
let old_map = if let Some(old_name) = old_display_name.as_deref() {
|
||||
Some(self.get_users_with_display_name(room_id, old_name).await?)
|
||||
let old_display_name = DisplayName::new(old_name);
|
||||
Some(self.get_users_with_display_name(room_id, &old_display_name).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -275,11 +282,221 @@ impl AmbiguityCache {
|
||||
new
|
||||
};
|
||||
|
||||
Some(self.get_users_with_display_name(room_id, new_display_name).await?)
|
||||
let new_display_name = DisplayName::new(new_display_name);
|
||||
|
||||
Some(self.get_users_with_display_name(room_id, &new_display_name).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok((old_map, new_map))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn check(&self, room_id: &RoomId, display_name: &DisplayName) -> bool {
|
||||
self.cache
|
||||
.get(room_id)
|
||||
.and_then(|display_names| {
|
||||
display_names
|
||||
.get(display_name)
|
||||
.map(|user_ids| is_display_name_ambiguous(display_name, user_ids))
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"The display name {:?} should be part of the cache {:?}",
|
||||
display_name, self.cache
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use matrix_sdk_test::async_test;
|
||||
use ruma::{room_id, server_name, user_id, EventId};
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
use crate::store::{IntoStateStore, MemoryStore};
|
||||
|
||||
fn generate_event(user_id: &UserId, display_name: &str) -> SyncRoomMemberEvent {
|
||||
let server_name = server_name!("localhost");
|
||||
serde_json::from_value(json!({
|
||||
"content": {
|
||||
"displayname": display_name,
|
||||
"membership": "join"
|
||||
},
|
||||
"event_id": EventId::new(server_name),
|
||||
"origin_server_ts": 152037280,
|
||||
"sender": user_id,
|
||||
"state_key": user_id,
|
||||
"type": "m.room.member",
|
||||
|
||||
}))
|
||||
.expect("We should be able to deserialize the static member event")
|
||||
}
|
||||
|
||||
macro_rules! assert_ambiguity {
|
||||
(
|
||||
[ $( ($user:literal, $display_name:literal) ),* ],
|
||||
[ $( ($check_display_name:literal, $ambiguous:expr) ),* ] $(,)?
|
||||
) => {
|
||||
assert_ambiguity!(
|
||||
[ $( ($user, $display_name) ),* ],
|
||||
[ $( ($check_display_name, $ambiguous) ),* ],
|
||||
"The test failed the ambiguity assertions"
|
||||
)
|
||||
};
|
||||
|
||||
(
|
||||
[ $( ($user:literal, $display_name:literal) ),* ],
|
||||
[ $( ($check_display_name:literal, $ambiguous:expr) ),* ],
|
||||
$description:literal $(,)?
|
||||
) => {
|
||||
let store = MemoryStore::new();
|
||||
let mut ambiguity_cache = AmbiguityCache::new(store.into_state_store());
|
||||
|
||||
let changes = Default::default();
|
||||
let room_id = room_id!("!foo:bar");
|
||||
|
||||
macro_rules! add_display_name {
|
||||
($u:literal, $n:literal) => {
|
||||
let event = generate_event(user_id!($u), $n);
|
||||
|
||||
ambiguity_cache
|
||||
.handle_event(&changes, room_id, &event)
|
||||
.await
|
||||
.expect("We should be able to handle a member event to calculate the ambiguity.");
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! assert_display_name_ambiguity {
|
||||
($n:literal, $a:expr) => {
|
||||
let display_name = DisplayName::new($n);
|
||||
|
||||
if ambiguity_cache.check(room_id, &display_name) != $a {
|
||||
let foo = if $a { "be" } else { "not be" };
|
||||
panic!("{}: the display name {} should {} ambiguous", $description, $n, foo);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
$(
|
||||
add_display_name!($user, $display_name);
|
||||
)*
|
||||
|
||||
$(
|
||||
assert_display_name_ambiguity!($check_display_name, $ambiguous);
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn test_disambiguation() {
|
||||
assert_ambiguity!(
|
||||
[("@alice:localhost", "alice")],
|
||||
[("alice", false)],
|
||||
"Alice is alone in the room"
|
||||
);
|
||||
|
||||
assert_ambiguity!(
|
||||
[("@alice:localhost", "alice")],
|
||||
[("Alice", false)],
|
||||
"Alice is alone in the room and has a capitalized display name"
|
||||
);
|
||||
|
||||
assert_ambiguity!(
|
||||
[("@alice:localhost", "alice"), ("@bob:localhost", "alice")],
|
||||
[("alice", true)],
|
||||
"Alice and bob share a display name"
|
||||
);
|
||||
|
||||
assert_ambiguity!(
|
||||
[
|
||||
("@alice:localhost", "alice"),
|
||||
("@bob:localhost", "alice"),
|
||||
("@carol:localhost", "carol")
|
||||
],
|
||||
[("alice", true), ("carol", false)],
|
||||
"Alice and Bob share a display name, while Carol is unique"
|
||||
);
|
||||
|
||||
assert_ambiguity!(
|
||||
[("@alice:localhost", "alice"), ("@bob:localhost", "ALICE")],
|
||||
[("alice", true)],
|
||||
"Alice and Bob share a display name that is differently capitalized"
|
||||
);
|
||||
|
||||
assert_ambiguity!(
|
||||
[("@alice:localhost", "alice"), ("@bob:localhost", "аlice")],
|
||||
[("alice", true)],
|
||||
"Bob tries to impersonate Alice using a cyrilic а"
|
||||
);
|
||||
|
||||
assert_ambiguity!(
|
||||
[("@alice:localhost", "@bob:localhost"), ("@bob:localhost", "аlice")],
|
||||
[("@bob:localhost", true)],
|
||||
"Alice tries to impersonate bob using an mxid"
|
||||
);
|
||||
|
||||
assert_ambiguity!(
|
||||
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "𝒮𝒶𝒽𝒶𝓈𝓇𝒶𝒽𝓁𝒶")],
|
||||
[("Sahasrahla", true)],
|
||||
"Bob tries to impersonate Alice using scripture symbols"
|
||||
);
|
||||
|
||||
assert_ambiguity!(
|
||||
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "𝔖𝔞𝔥𝔞𝔰𝔯𝔞𝔥𝔩𝔞")],
|
||||
[("Sahasrahla", true)],
|
||||
"Bob tries to impersonate Alice using fraktur symbols"
|
||||
);
|
||||
|
||||
assert_ambiguity!(
|
||||
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Ⓢⓐⓗⓐⓢⓡⓐⓗⓛⓐ")],
|
||||
[("Sahasrahla", true)],
|
||||
"Bob tries to impersonate Alice using circled symbols"
|
||||
);
|
||||
|
||||
assert_ambiguity!(
|
||||
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "🅂🄰🄷🄰🅂🅁🄰🄷🄻🄰")],
|
||||
[("Sahasrahla", true)],
|
||||
"Bob tries to impersonate Alice using squared symbols"
|
||||
);
|
||||
|
||||
assert_ambiguity!(
|
||||
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Sahasrahla")],
|
||||
[("Sahasrahla", true)],
|
||||
"Bob tries to impersonate Alice using big unicode letters"
|
||||
);
|
||||
|
||||
assert_ambiguity!(
|
||||
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "\u{202e}alharsahas")],
|
||||
[("Sahasrahla", true)],
|
||||
"Bob tries to impersonate Alice using left to right shenanigans"
|
||||
);
|
||||
|
||||
assert_ambiguity!(
|
||||
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Sa̴hasrahla")],
|
||||
[("Sahasrahla", true)],
|
||||
"Bob tries to impersonate Alice using a diacritical mark"
|
||||
);
|
||||
|
||||
assert_ambiguity!(
|
||||
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Sahas\u{200B}rahla")],
|
||||
[("Sahasrahla", true)],
|
||||
"Bob tries to impersonate Alice using a zero-width space"
|
||||
);
|
||||
|
||||
assert_ambiguity!(
|
||||
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Sahas\u{200D}rahla")],
|
||||
[("Sahasrahla", true)],
|
||||
"Bob tries to impersonate Alice using a zero-width space"
|
||||
);
|
||||
|
||||
assert_ambiguity!(
|
||||
[("@alice:localhost", "ff"), ("@bob:localhost", "\u{FB00}")],
|
||||
[("ff", true)],
|
||||
"Bob tries to impersonate Alice using a ligature"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//! Trait and macro of integration tests for StateStore implementations.
|
||||
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::collections::{BTreeMap, BTreeSet, HashMap};
|
||||
|
||||
use assert_matches::assert_matches;
|
||||
use assert_matches2::assert_let;
|
||||
@@ -34,7 +34,8 @@ use ruma::{
|
||||
use serde_json::{json, value::Value as JsonValue};
|
||||
|
||||
use super::{
|
||||
send_queue::SentRequestKey, DependentQueuedRequestKind, DynStateStore, ServerCapabilities,
|
||||
send_queue::SentRequestKey, DependentQueuedRequestKind, DisplayName, DynStateStore,
|
||||
ServerCapabilities,
|
||||
};
|
||||
use crate::{
|
||||
deserialized_responses::MemberEvent,
|
||||
@@ -141,13 +142,15 @@ impl StateStoreIntegrationTests for DynStateStore {
|
||||
room.handle_state_event(&topic_event);
|
||||
changes.add_state_event(room_id, topic_event, topic_raw);
|
||||
|
||||
let mut room_ambiguity_map = BTreeMap::new();
|
||||
let mut room_ambiguity_map = HashMap::new();
|
||||
let mut room_profiles = BTreeMap::new();
|
||||
|
||||
let member_json: &JsonValue = &test_json::MEMBER;
|
||||
let member_event: SyncRoomMemberEvent =
|
||||
serde_json::from_value(member_json.clone()).unwrap();
|
||||
let displayname = member_event.as_original().unwrap().content.displayname.clone().unwrap();
|
||||
let displayname = DisplayName::new(
|
||||
member_event.as_original().unwrap().content.displayname.as_ref().unwrap(),
|
||||
);
|
||||
room_ambiguity_map.insert(displayname.clone(), BTreeSet::from([user_id.to_owned()]));
|
||||
room_profiles.insert(user_id.to_owned(), (&member_event).into());
|
||||
|
||||
@@ -256,6 +259,8 @@ impl StateStoreIntegrationTests for DynStateStore {
|
||||
async fn test_populate_store(&self) -> Result<()> {
|
||||
let room_id = room_id();
|
||||
let user_id = user_id();
|
||||
let display_name = DisplayName::new("example");
|
||||
|
||||
self.populate().await?;
|
||||
|
||||
assert!(self.get_kv_data(StateStoreDataKey::SyncToken).await?.is_some());
|
||||
@@ -290,7 +295,7 @@ impl StateStoreIntegrationTests for DynStateStore {
|
||||
"Expected to find 1 joined user ids"
|
||||
);
|
||||
assert_eq!(
|
||||
self.get_users_with_display_name(room_id, "example").await?.len(),
|
||||
self.get_users_with_display_name(room_id, &display_name).await?.len(),
|
||||
2,
|
||||
"Expected to find 2 display names for room"
|
||||
);
|
||||
@@ -962,6 +967,7 @@ impl StateStoreIntegrationTests for DynStateStore {
|
||||
async fn test_room_removal(&self) -> Result<()> {
|
||||
let room_id = room_id();
|
||||
let user_id = user_id();
|
||||
let display_name = DisplayName::new("example");
|
||||
let stripped_room_id = stripped_room_id();
|
||||
|
||||
self.populate().await?;
|
||||
@@ -990,7 +996,7 @@ impl StateStoreIntegrationTests for DynStateStore {
|
||||
"still joined users found"
|
||||
);
|
||||
assert!(
|
||||
self.get_users_with_display_name(room_id, "example").await?.is_empty(),
|
||||
self.get_users_with_display_name(room_id, &display_name).await?.is_empty(),
|
||||
"still display names found"
|
||||
);
|
||||
assert!(self
|
||||
@@ -1145,15 +1151,15 @@ impl StateStoreIntegrationTests for DynStateStore {
|
||||
async fn test_display_names_saving(&self) {
|
||||
let room_id = room_id!("!test_display_names_saving:localhost");
|
||||
let user_id = user_id();
|
||||
let user_display_name = "User";
|
||||
let user_display_name = DisplayName::new("User");
|
||||
let second_user_id = user_id!("@second:localhost");
|
||||
let third_user_id = user_id!("@third:localhost");
|
||||
let other_display_name = "Raoul";
|
||||
let unknown_display_name = "Unknown";
|
||||
let other_display_name = DisplayName::new("Raoul");
|
||||
let unknown_display_name = DisplayName::new("Unknown");
|
||||
|
||||
// No event in store.
|
||||
let mut display_names = vec![user_display_name.to_owned()];
|
||||
let users = self.get_users_with_display_name(room_id, user_display_name).await.unwrap();
|
||||
let users = self.get_users_with_display_name(room_id, &user_display_name).await.unwrap();
|
||||
assert!(users.is_empty());
|
||||
let names = self.get_users_with_display_names(room_id, &display_names).await.unwrap();
|
||||
assert!(names.is_empty());
|
||||
@@ -1167,7 +1173,7 @@ impl StateStoreIntegrationTests for DynStateStore {
|
||||
.insert(user_display_name.to_owned(), [user_id.to_owned()].into());
|
||||
self.save_changes(&changes).await.unwrap();
|
||||
|
||||
let users = self.get_users_with_display_name(room_id, user_display_name).await.unwrap();
|
||||
let users = self.get_users_with_display_name(room_id, &user_display_name).await.unwrap();
|
||||
assert_eq!(users.len(), 1);
|
||||
let names = self.get_users_with_display_names(room_id, &display_names).await.unwrap();
|
||||
assert_eq!(names.len(), 1);
|
||||
@@ -1182,9 +1188,9 @@ impl StateStoreIntegrationTests for DynStateStore {
|
||||
self.save_changes(&changes).await.unwrap();
|
||||
|
||||
display_names.push(other_display_name.to_owned());
|
||||
let users = self.get_users_with_display_name(room_id, user_display_name).await.unwrap();
|
||||
let users = self.get_users_with_display_name(room_id, &user_display_name).await.unwrap();
|
||||
assert_eq!(users.len(), 1);
|
||||
let users = self.get_users_with_display_name(room_id, other_display_name).await.unwrap();
|
||||
let users = self.get_users_with_display_name(room_id, &other_display_name).await.unwrap();
|
||||
assert_eq!(users.len(), 2);
|
||||
let names = self.get_users_with_display_names(room_id, &display_names).await.unwrap();
|
||||
assert_eq!(names.len(), 2);
|
||||
|
||||
@@ -42,7 +42,8 @@ use super::{
|
||||
StateChanges, StateStore, StoreError,
|
||||
};
|
||||
use crate::{
|
||||
deserialized_responses::RawAnySyncOrStrippedState, store::QueueWedgeError,
|
||||
deserialized_responses::{DisplayName, RawAnySyncOrStrippedState},
|
||||
store::QueueWedgeError,
|
||||
MinimalRoomMemberEvent, RoomMemberships, StateStoreDataKey, StateStoreDataValue,
|
||||
};
|
||||
|
||||
@@ -61,7 +62,7 @@ pub struct MemoryStore {
|
||||
utd_hook_manager_data: StdRwLock<Option<GrowableBloom>>,
|
||||
account_data: StdRwLock<HashMap<GlobalAccountDataEventType, Raw<AnyGlobalAccountDataEvent>>>,
|
||||
profiles: StdRwLock<HashMap<OwnedRoomId, HashMap<OwnedUserId, MinimalRoomMemberEvent>>>,
|
||||
display_names: StdRwLock<HashMap<OwnedRoomId, HashMap<String, BTreeSet<OwnedUserId>>>>,
|
||||
display_names: StdRwLock<HashMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>>,
|
||||
members: StdRwLock<HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>>,
|
||||
room_info: StdRwLock<HashMap<OwnedRoomId, RoomInfo>>,
|
||||
room_state: StdRwLock<
|
||||
@@ -701,7 +702,7 @@ impl StateStore for MemoryStore {
|
||||
async fn get_users_with_display_name(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
display_name: &str,
|
||||
display_name: &DisplayName,
|
||||
) -> Result<BTreeSet<OwnedUserId>> {
|
||||
Ok(self
|
||||
.display_names
|
||||
@@ -715,21 +716,18 @@ impl StateStore for MemoryStore {
|
||||
async fn get_users_with_display_names<'a>(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
display_names: &'a [String],
|
||||
) -> Result<BTreeMap<&'a str, BTreeSet<OwnedUserId>>> {
|
||||
display_names: &'a [DisplayName],
|
||||
) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
|
||||
if display_names.is_empty() {
|
||||
return Ok(BTreeMap::new());
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
|
||||
let read_guard = &self.display_names.read().unwrap();
|
||||
let Some(room_names) = read_guard.get(room_id) else {
|
||||
return Ok(BTreeMap::new());
|
||||
return Ok(HashMap::new());
|
||||
};
|
||||
|
||||
Ok(display_names
|
||||
.iter()
|
||||
.filter_map(|n| room_names.get(n).map(|d| (n.as_str(), d.clone())))
|
||||
.collect())
|
||||
Ok(display_names.iter().filter_map(|n| room_names.get(n).map(|d| (n, d.clone()))).collect())
|
||||
}
|
||||
|
||||
async fn get_account_data_event(
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
//! store.
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
collections::{BTreeMap, BTreeSet, HashMap},
|
||||
fmt,
|
||||
ops::Deref,
|
||||
result::Result as StdResult,
|
||||
@@ -58,6 +58,7 @@ use tokio::sync::{broadcast, Mutex, RwLock};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::{
|
||||
deserialized_responses::DisplayName,
|
||||
event_cache::store as event_cache_store,
|
||||
rooms::{normal::RoomInfoNotableUpdate, RoomInfo, RoomState},
|
||||
MinimalRoomMemberEvent, Room, RoomStateFilter, SessionMeta,
|
||||
@@ -384,7 +385,7 @@ pub struct StateChanges {
|
||||
|
||||
/// A map from room id to a map of a display name and a set of user ids that
|
||||
/// share that display name in the given room.
|
||||
pub ambiguity_maps: BTreeMap<OwnedRoomId, BTreeMap<String, BTreeSet<OwnedUserId>>>,
|
||||
pub ambiguity_maps: BTreeMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>,
|
||||
}
|
||||
|
||||
impl StateChanges {
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
collections::{BTreeMap, BTreeSet, HashMap},
|
||||
fmt,
|
||||
sync::Arc,
|
||||
};
|
||||
@@ -46,7 +46,9 @@ use super::{
|
||||
StoreError,
|
||||
};
|
||||
use crate::{
|
||||
deserialized_responses::{RawAnySyncOrStrippedState, RawMemberEvent, RawSyncOrStrippedState},
|
||||
deserialized_responses::{
|
||||
DisplayName, RawAnySyncOrStrippedState, RawMemberEvent, RawSyncOrStrippedState,
|
||||
},
|
||||
MinimalRoomMemberEvent, RoomInfo, RoomMemberships,
|
||||
};
|
||||
|
||||
@@ -206,7 +208,7 @@ pub trait StateStore: AsyncTraitDeps {
|
||||
async fn get_users_with_display_name(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
display_name: &str,
|
||||
display_name: &DisplayName,
|
||||
) -> Result<BTreeSet<OwnedUserId>, Self::Error>;
|
||||
|
||||
/// Get all the users that use the given display names in the given room.
|
||||
@@ -219,8 +221,8 @@ pub trait StateStore: AsyncTraitDeps {
|
||||
async fn get_users_with_display_names<'a>(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
display_names: &'a [String],
|
||||
) -> Result<BTreeMap<&'a str, BTreeSet<OwnedUserId>>, Self::Error>;
|
||||
display_names: &'a [DisplayName],
|
||||
) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>, Self::Error>;
|
||||
|
||||
/// Get an event out of the account data store.
|
||||
///
|
||||
@@ -567,7 +569,7 @@ impl<T: StateStore> StateStore for EraseStateStoreError<T> {
|
||||
async fn get_users_with_display_name(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
display_name: &str,
|
||||
display_name: &DisplayName,
|
||||
) -> Result<BTreeSet<OwnedUserId>, Self::Error> {
|
||||
self.0.get_users_with_display_name(room_id, display_name).await.map_err(Into::into)
|
||||
}
|
||||
@@ -575,8 +577,8 @@ impl<T: StateStore> StateStore for EraseStateStoreError<T> {
|
||||
async fn get_users_with_display_names<'a>(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
display_names: &'a [String],
|
||||
) -> Result<BTreeMap<&'a str, BTreeSet<OwnedUserId>>, Self::Error> {
|
||||
display_names: &'a [DisplayName],
|
||||
) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>, Self::Error> {
|
||||
self.0.get_users_with_display_names(room_id, display_names).await.map_err(Into::into)
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet, HashSet},
|
||||
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
@@ -23,7 +23,7 @@ use gloo_utils::format::JsValueSerdeExt;
|
||||
use growable_bloom_filter::GrowableBloom;
|
||||
use indexed_db_futures::prelude::*;
|
||||
use matrix_sdk_base::{
|
||||
deserialized_responses::RawAnySyncOrStrippedState,
|
||||
deserialized_responses::{DisplayName, RawAnySyncOrStrippedState},
|
||||
store::{
|
||||
ChildTransactionId, ComposerDraft, DependentQueuedRequest, DependentQueuedRequestKind,
|
||||
QueuedRequest, QueuedRequestKind, SentRequestKey, SerializableEventContent,
|
||||
@@ -665,7 +665,15 @@ impl_state_store!({
|
||||
let store = tx.object_store(keys::DISPLAY_NAMES)?;
|
||||
for (room_id, ambiguity_maps) in &changes.ambiguity_maps {
|
||||
for (display_name, map) in ambiguity_maps {
|
||||
let key = self.encode_key(keys::DISPLAY_NAMES, (room_id, display_name));
|
||||
let key = self.encode_key(
|
||||
keys::DISPLAY_NAMES,
|
||||
(
|
||||
room_id,
|
||||
display_name
|
||||
.as_normalized_str()
|
||||
.unwrap_or_else(|| display_name.as_raw_str()),
|
||||
),
|
||||
);
|
||||
|
||||
store.put_key_val(&key, &self.serialize_value(&map)?)?;
|
||||
}
|
||||
@@ -1122,12 +1130,18 @@ impl_state_store!({
|
||||
async fn get_users_with_display_name(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
display_name: &str,
|
||||
display_name: &DisplayName,
|
||||
) -> Result<BTreeSet<OwnedUserId>> {
|
||||
self.inner
|
||||
.transaction_on_one_with_mode(keys::DISPLAY_NAMES, IdbTransactionMode::Readonly)?
|
||||
.object_store(keys::DISPLAY_NAMES)?
|
||||
.get(&self.encode_key(keys::DISPLAY_NAMES, (room_id, display_name)))?
|
||||
.get(&self.encode_key(
|
||||
keys::DISPLAY_NAMES,
|
||||
(
|
||||
room_id,
|
||||
display_name.as_normalized_str().unwrap_or_else(|| display_name.as_raw_str()),
|
||||
),
|
||||
))?
|
||||
.await?
|
||||
.map(|f| self.deserialize_value::<BTreeSet<OwnedUserId>>(&f))
|
||||
.unwrap_or_else(|| Ok(Default::default()))
|
||||
@@ -1136,10 +1150,12 @@ impl_state_store!({
|
||||
async fn get_users_with_display_names<'a>(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
display_names: &'a [String],
|
||||
) -> Result<BTreeMap<&'a str, BTreeSet<OwnedUserId>>> {
|
||||
display_names: &'a [DisplayName],
|
||||
) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
|
||||
let mut map = HashMap::new();
|
||||
|
||||
if display_names.is_empty() {
|
||||
return Ok(BTreeMap::new());
|
||||
return Ok(map);
|
||||
}
|
||||
|
||||
let txn = self
|
||||
@@ -1147,15 +1163,24 @@ impl_state_store!({
|
||||
.transaction_on_one_with_mode(keys::DISPLAY_NAMES, IdbTransactionMode::Readonly)?;
|
||||
let store = txn.object_store(keys::DISPLAY_NAMES)?;
|
||||
|
||||
let mut map = BTreeMap::new();
|
||||
for display_name in display_names {
|
||||
if let Some(user_ids) = store
|
||||
.get(&self.encode_key(keys::DISPLAY_NAMES, (room_id, display_name)))?
|
||||
.get(
|
||||
&self.encode_key(
|
||||
keys::DISPLAY_NAMES,
|
||||
(
|
||||
room_id,
|
||||
display_name
|
||||
.as_normalized_str()
|
||||
.unwrap_or_else(|| display_name.as_raw_str()),
|
||||
),
|
||||
),
|
||||
)?
|
||||
.await?
|
||||
.map(|f| self.deserialize_value::<BTreeSet<OwnedUserId>>(&f))
|
||||
.transpose()?
|
||||
{
|
||||
map.insert(display_name.as_ref(), user_ids);
|
||||
map.insert(display_name, user_ids);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
collections::{BTreeMap, BTreeSet, HashMap},
|
||||
fmt, iter,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
@@ -9,7 +9,7 @@ use std::{
|
||||
use async_trait::async_trait;
|
||||
use deadpool_sqlite::{Object as SqliteAsyncConn, Pool as SqlitePool, Runtime};
|
||||
use matrix_sdk_base::{
|
||||
deserialized_responses::{RawAnySyncOrStrippedState, SyncOrStrippedState},
|
||||
deserialized_responses::{DisplayName, RawAnySyncOrStrippedState, SyncOrStrippedState},
|
||||
store::{
|
||||
migration_helpers::RoomInfoV1, ChildTransactionId, DependentQueuedRequest,
|
||||
DependentQueuedRequestKind, QueueWedgeError, QueuedRequest, QueuedRequestKind,
|
||||
@@ -1305,13 +1305,34 @@ impl StateStore for SqliteStateStore {
|
||||
let room_id = this.encode_key(keys::DISPLAY_NAME, room_id);
|
||||
|
||||
for (name, user_ids) in display_names {
|
||||
let name = this.encode_key(keys::DISPLAY_NAME, name);
|
||||
let encoded_name = this.encode_key(
|
||||
keys::DISPLAY_NAME,
|
||||
name.as_normalized_str().unwrap_or_else(|| name.as_raw_str()),
|
||||
);
|
||||
let data = this.serialize_json(&user_ids)?;
|
||||
|
||||
if user_ids.is_empty() {
|
||||
txn.remove_display_name(&room_id, &name)?;
|
||||
txn.remove_display_name(&room_id, &encoded_name)?;
|
||||
|
||||
// We can't do a migration to merge the previously distinct buckets of
|
||||
// user IDs since the display names themselves are hashed before they
|
||||
// are persisted in the store. So the store will always retain two
|
||||
// buckets: one for raw display names and one for normalised ones.
|
||||
//
|
||||
// We therefore do the next best thing, which is a sort of a soft
|
||||
// migration: we fetch both the raw and normalised buckets, then merge
|
||||
// the user IDs contained in them into a separate, temporary merged
|
||||
// bucket. The SDK then operates on the merged buckets exclusively. See
|
||||
// the comment in `get_users_with_display_names` for details.
|
||||
//
|
||||
// If the merged bucket is empty, that must mean that both the raw and
|
||||
// normalised buckets were also empty, so we can remove both from the
|
||||
// store.
|
||||
let raw_name = this.encode_key(keys::DISPLAY_NAME, name.as_raw_str());
|
||||
txn.remove_display_name(&room_id, &raw_name)?;
|
||||
} else {
|
||||
txn.set_display_name(&room_id, &name, &data)?;
|
||||
// We only create new buckets with the normalized display name.
|
||||
txn.set_display_name(&room_id, &encoded_name, &data)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1500,10 +1521,13 @@ impl StateStore for SqliteStateStore {
|
||||
async fn get_users_with_display_name(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
display_name: &str,
|
||||
display_name: &DisplayName,
|
||||
) -> Result<BTreeSet<OwnedUserId>> {
|
||||
let room_id = self.encode_key(keys::DISPLAY_NAME, room_id);
|
||||
let names = vec![self.encode_key(keys::DISPLAY_NAME, display_name)];
|
||||
let names = vec![self.encode_key(
|
||||
keys::DISPLAY_NAME,
|
||||
display_name.as_normalized_str().unwrap_or_else(|| display_name.as_raw_str()),
|
||||
)];
|
||||
|
||||
Ok(self
|
||||
.acquire()
|
||||
@@ -1520,33 +1544,49 @@ impl StateStore for SqliteStateStore {
|
||||
async fn get_users_with_display_names<'a>(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
display_names: &'a [String],
|
||||
) -> Result<BTreeMap<&'a str, BTreeSet<OwnedUserId>>> {
|
||||
display_names: &'a [DisplayName],
|
||||
) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
|
||||
let mut result = HashMap::new();
|
||||
|
||||
if display_names.is_empty() {
|
||||
return Ok(BTreeMap::new());
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
let room_id = self.encode_key(keys::DISPLAY_NAME, room_id);
|
||||
let mut names_map = display_names
|
||||
.iter()
|
||||
.map(|n| (self.encode_key(keys::DISPLAY_NAME, n), n.as_ref()))
|
||||
.flat_map(|display_name| {
|
||||
// We encode the display name as the `raw_str()` and the normalized string.
|
||||
//
|
||||
// This is for compatibility reasons since:
|
||||
// 1. Previously "Alice" and "alice" were considered to be distinct display
|
||||
// names, while we now consider them to be the same so we need to merge the
|
||||
// previously distinct buckets of user IDs.
|
||||
// 2. We can't do a migration to merge the previously distinct buckets of user
|
||||
// IDs since the display names itself are hashed before they are persisted
|
||||
// in the store.
|
||||
let raw =
|
||||
(self.encode_key(keys::DISPLAY_NAME, display_name.as_raw_str()), display_name);
|
||||
let normalized = display_name.as_normalized_str().map(|normalized| {
|
||||
(self.encode_key(keys::DISPLAY_NAME, normalized), display_name)
|
||||
});
|
||||
|
||||
iter::once(raw).chain(normalized.into_iter())
|
||||
})
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
let names = names_map.keys().cloned().collect();
|
||||
|
||||
self.acquire()
|
||||
.await?
|
||||
.get_display_names(room_id, names)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|(name, data)| {
|
||||
Ok((
|
||||
names_map
|
||||
.remove(name.as_slice())
|
||||
.expect("returned display names were requested"),
|
||||
self.deserialize_json(&data)?,
|
||||
))
|
||||
})
|
||||
.collect::<Result<BTreeMap<_, _>>>()
|
||||
for (name, data) in
|
||||
self.acquire().await?.get_display_names(room_id, names).await?.into_iter()
|
||||
{
|
||||
let display_name =
|
||||
names_map.remove(name.as_slice()).expect("returned display names were requested");
|
||||
let user_ids: BTreeSet<_> = self.deserialize_json(&data)?;
|
||||
|
||||
result.entry(display_name).or_insert_with(BTreeSet::new).extend(user_ids);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn get_account_data_event(
|
||||
|
||||
Reference in New Issue
Block a user