fix: Use the DisplayName struct to protect against homoglyph attacks

This commit is contained in:
Damir Jelić
2024-11-04 15:30:08 +01:00
parent e4ebeb8a42
commit d40aac89cb
12 changed files with 408 additions and 105 deletions

View File

@@ -16,7 +16,7 @@
#[cfg(feature = "e2e-encryption")]
use std::sync::Arc;
use std::{
collections::{BTreeMap, BTreeSet},
collections::{BTreeMap, BTreeSet, HashMap},
fmt, iter,
ops::Deref,
};
@@ -68,7 +68,7 @@ use crate::latest_event::{is_suitable_for_latest_event, LatestEvent, PossibleLat
#[cfg(feature = "e2e-encryption")]
use crate::RoomMemberships;
use crate::{
deserialized_responses::{RawAnySyncOrStrippedTimelineEvent, SyncTimelineEvent},
deserialized_responses::{DisplayName, RawAnySyncOrStrippedTimelineEvent, SyncTimelineEvent},
error::{Error, Result},
event_cache::store::EventCacheStoreLock,
response_processors::AccountDataProcessor,
@@ -1332,7 +1332,7 @@ impl BaseClient {
#[cfg(feature = "e2e-encryption")]
let mut user_ids = BTreeSet::new();
let mut ambiguity_map: BTreeMap<String, BTreeSet<OwnedUserId>> = BTreeMap::new();
let mut ambiguity_map: HashMap<DisplayName, BTreeSet<OwnedUserId>> = Default::default();
for raw_event in &response.chunk {
let member = match raw_event.deserialize() {
@@ -1363,7 +1363,11 @@ impl BaseClient {
if let StateEvent::Original(e) = &member {
if let Some(d) = &e.content.displayname {
ambiguity_map.entry(d.clone()).or_default().insert(member.state_key().clone());
let display_name = DisplayName::new(d);
ambiguity_map
.entry(display_name)
.or_default()
.insert(member.state_key().clone());
}
}

View File

@@ -469,10 +469,12 @@ impl MemberEvent {
///
/// It there is no `displayname` in the event's content, the localpart or
/// the user ID is returned.
pub fn display_name(&self) -> &str {
self.original_content()
.and_then(|c| c.displayname.as_deref())
.unwrap_or_else(|| self.user_id().localpart())
pub fn display_name(&self) -> DisplayName {
DisplayName::new(
self.original_content()
.and_then(|c| c.displayname.as_deref())
.unwrap_or_else(|| self.user_id().localpart()),
)
}
}

View File

@@ -13,7 +13,7 @@
// limitations under the License.
use std::{
collections::{BTreeMap, BTreeSet},
collections::{BTreeSet, HashMap},
sync::Arc,
};
@@ -30,7 +30,8 @@ use ruma::{
};
use crate::{
deserialized_responses::{MemberEvent, SyncOrStrippedState},
deserialized_responses::{DisplayName, MemberEvent, SyncOrStrippedState},
store::ambiguity_map::is_display_name_ambiguous,
MinimalRoomMemberEvent,
};
@@ -67,8 +68,10 @@ impl RoomMember {
} = room_info;
let is_room_creator = room_creator.as_deref() == Some(event.user_id());
let display_name_ambiguous =
users_display_names.get(event.display_name()).is_some_and(|s| s.len() > 1);
let display_name = event.display_name();
let display_name_ambiguous = users_display_names
.get(&display_name)
.is_some_and(|s| is_display_name_ambiguous(&display_name, s));
let is_ignored = ignored_users.as_ref().is_some_and(|s| s.contains(event.user_id()));
Self {
@@ -245,6 +248,6 @@ pub(crate) struct MemberRoomInfo<'a> {
pub(crate) power_levels: Arc<Option<SyncOrStrippedState<RoomPowerLevelsEventContent>>>,
pub(crate) max_power_level: i64,
pub(crate) room_creator: Option<OwnedUserId>,
pub(crate) users_display_names: BTreeMap<&'a str, BTreeSet<OwnedUserId>>,
pub(crate) users_display_names: HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>,
pub(crate) ignored_users: Option<BTreeSet<OwnedUserId>>,
}

View File

@@ -67,7 +67,7 @@ use super::{
#[cfg(feature = "experimental-sliding-sync")]
use crate::latest_event::LatestEvent;
use crate::{
deserialized_responses::{MemberEvent, RawSyncOrStrippedState},
deserialized_responses::{DisplayName, MemberEvent, RawSyncOrStrippedState},
notification_settings::RoomNotificationMode,
read_receipts::RoomReadReceipts,
store::{DynStateStore, Result as StoreResult, StateStoreExt},
@@ -819,8 +819,7 @@ impl Room {
})
.collect::<BTreeMap<_, _>>();
let display_names =
member_events.iter().map(|e| e.display_name().to_owned()).collect::<Vec<_>>();
let display_names = member_events.iter().map(|e| e.display_name()).collect::<Vec<_>>();
let room_info = self.member_room_info(&display_names).await?;
let mut members = Vec::new();
@@ -900,7 +899,7 @@ impl Room {
let profile = self.store.get_profile(self.room_id(), user_id).await?;
let display_names = [event.display_name().to_owned()];
let display_names = [event.display_name()];
let room_info = self.member_room_info(&display_names).await?;
Ok(Some(RoomMember::from_parts(event, profile, presence, &room_info)))
@@ -911,7 +910,7 @@ impl Room {
/// Async because it can read from storage.
async fn member_room_info<'a>(
&self,
display_names: &'a [String],
display_names: &'a [DisplayName],
) -> StoreResult<MemberRoomInfo<'a>> {
let max_power_level = self.max_power_level();
let room_creator = self.inner.read().creator().map(ToOwned::to_owned);

View File

@@ -695,6 +695,10 @@ async fn cache_latest_events(
changes: Option<&StateChanges>,
store: Option<&Store>,
) {
use crate::{
deserialized_responses::DisplayName, store::ambiguity_map::is_display_name_ambiguous,
};
let mut encrypted_events =
Vec::with_capacity(room.latest_encrypted_events.read().unwrap().capacity());
@@ -752,11 +756,13 @@ async fn cache_latest_events(
.as_original()
.and_then(|profile| profile.content.displayname.as_ref())
.and_then(|display_name| {
let display_name = DisplayName::new(display_name);
changes.ambiguity_maps.get(room.room_id()).and_then(
|map_for_room| {
map_for_room
.get(display_name)
.map(|user_ids| user_ids.len() > 1)
map_for_room.get(&display_name).map(|users| {
is_display_name_ambiguous(&display_name, users)
})
},
)
});

View File

@@ -13,7 +13,7 @@
// limitations under the License.
use std::{
collections::{BTreeMap, BTreeSet},
collections::{BTreeMap, BTreeSet, HashMap},
sync::Arc,
};
@@ -24,18 +24,18 @@ use ruma::{
},
OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
};
use tracing::trace;
use tracing::{instrument, trace};
use super::{DynStateStore, Result, StateChanges};
use crate::{
deserialized_responses::{AmbiguityChange, RawMemberEvent},
deserialized_responses::{AmbiguityChange, DisplayName, RawMemberEvent},
store::StateStoreExt,
};
/// A map of users that use a certain display name.
#[derive(Debug, Clone)]
struct DisplayNameUsers {
display_name: String,
display_name: DisplayName,
users: BTreeSet<OwnedUserId>,
}
@@ -70,7 +70,7 @@ impl DisplayNameUsers {
/// Is the display name considered to be ambiguous.
fn is_ambiguous(&self) -> bool {
self.user_count() > 1
is_display_name_ambiguous(&self.display_name, &self.users)
}
}
@@ -82,10 +82,19 @@ fn is_member_active(membership: &MembershipState) -> bool {
#[derive(Debug)]
pub(crate) struct AmbiguityCache {
pub store: Arc<DynStateStore>,
pub cache: BTreeMap<OwnedRoomId, BTreeMap<String, BTreeSet<OwnedUserId>>>,
pub cache: BTreeMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>,
pub changes: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, AmbiguityChange>>,
}
#[instrument(ret)]
pub(crate) fn is_display_name_ambiguous(
display_name: &DisplayName,
users_with_display_name: &BTreeSet<OwnedUserId>,
) -> bool {
trace!("Checking if a display name is ambiguous");
display_name.is_inherently_ambiguous() || users_with_display_name.len() > 1
}
impl AmbiguityCache {
/// Create a new [`AmbiguityCache`] backed by the given state store.
pub fn new(store: Arc<DynStateStore>) -> Self {
@@ -224,18 +233,15 @@ impl AmbiguityCache {
async fn get_users_with_display_name(
&mut self,
room_id: &RoomId,
display_name: &str,
display_name: &DisplayName,
) -> Result<DisplayNameUsers> {
Ok(if let Some(u) = self.cache.entry(room_id.to_owned()).or_default().get(display_name) {
DisplayNameUsers { display_name: display_name.to_owned(), users: u.clone() }
DisplayNameUsers { display_name: display_name.clone(), users: u.clone() }
} else {
let users_with_display_name =
self.store.get_users_with_display_name(room_id, display_name).await?;
DisplayNameUsers {
display_name: display_name.to_owned(),
users: users_with_display_name,
}
DisplayNameUsers { display_name: display_name.clone(), users: users_with_display_name }
})
}
@@ -254,7 +260,8 @@ impl AmbiguityCache {
let old_display_name = self.get_old_display_name(changes, room_id, member_event).await?;
let old_map = if let Some(old_name) = old_display_name.as_deref() {
Some(self.get_users_with_display_name(room_id, old_name).await?)
let old_display_name = DisplayName::new(old_name);
Some(self.get_users_with_display_name(room_id, &old_display_name).await?)
} else {
None
};
@@ -275,11 +282,221 @@ impl AmbiguityCache {
new
};
Some(self.get_users_with_display_name(room_id, new_display_name).await?)
let new_display_name = DisplayName::new(new_display_name);
Some(self.get_users_with_display_name(room_id, &new_display_name).await?)
} else {
None
};
Ok((old_map, new_map))
}
#[cfg(test)]
fn check(&self, room_id: &RoomId, display_name: &DisplayName) -> bool {
self.cache
.get(room_id)
.and_then(|display_names| {
display_names
.get(display_name)
.map(|user_ids| is_display_name_ambiguous(display_name, user_ids))
})
.unwrap_or_else(|| {
panic!(
"The display name {:?} should be part of the cache {:?}",
display_name, self.cache
)
})
}
}
#[cfg(test)]
mod test {
use matrix_sdk_test::async_test;
use ruma::{room_id, server_name, user_id, EventId};
use serde_json::json;
use super::*;
use crate::store::{IntoStateStore, MemoryStore};
fn generate_event(user_id: &UserId, display_name: &str) -> SyncRoomMemberEvent {
let server_name = server_name!("localhost");
serde_json::from_value(json!({
"content": {
"displayname": display_name,
"membership": "join"
},
"event_id": EventId::new(server_name),
"origin_server_ts": 152037280,
"sender": user_id,
"state_key": user_id,
"type": "m.room.member",
}))
.expect("We should be able to deserialize the static member event")
}
macro_rules! assert_ambiguity {
(
[ $( ($user:literal, $display_name:literal) ),* ],
[ $( ($check_display_name:literal, $ambiguous:expr) ),* ] $(,)?
) => {
assert_ambiguity!(
[ $( ($user, $display_name) ),* ],
[ $( ($check_display_name, $ambiguous) ),* ],
"The test failed the ambiguity assertions"
)
};
(
[ $( ($user:literal, $display_name:literal) ),* ],
[ $( ($check_display_name:literal, $ambiguous:expr) ),* ],
$description:literal $(,)?
) => {
let store = MemoryStore::new();
let mut ambiguity_cache = AmbiguityCache::new(store.into_state_store());
let changes = Default::default();
let room_id = room_id!("!foo:bar");
macro_rules! add_display_name {
($u:literal, $n:literal) => {
let event = generate_event(user_id!($u), $n);
ambiguity_cache
.handle_event(&changes, room_id, &event)
.await
.expect("We should be able to handle a member event to calculate the ambiguity.");
};
}
macro_rules! assert_display_name_ambiguity {
($n:literal, $a:expr) => {
let display_name = DisplayName::new($n);
if ambiguity_cache.check(room_id, &display_name) != $a {
let foo = if $a { "be" } else { "not be" };
panic!("{}: the display name {} should {} ambiguous", $description, $n, foo);
}
};
}
$(
add_display_name!($user, $display_name);
)*
$(
assert_display_name_ambiguity!($check_display_name, $ambiguous);
)*
};
}
#[async_test]
async fn test_disambiguation() {
assert_ambiguity!(
[("@alice:localhost", "alice")],
[("alice", false)],
"Alice is alone in the room"
);
assert_ambiguity!(
[("@alice:localhost", "alice")],
[("Alice", false)],
"Alice is alone in the room and has a capitalized display name"
);
assert_ambiguity!(
[("@alice:localhost", "alice"), ("@bob:localhost", "alice")],
[("alice", true)],
"Alice and bob share a display name"
);
assert_ambiguity!(
[
("@alice:localhost", "alice"),
("@bob:localhost", "alice"),
("@carol:localhost", "carol")
],
[("alice", true), ("carol", false)],
"Alice and Bob share a display name, while Carol is unique"
);
assert_ambiguity!(
[("@alice:localhost", "alice"), ("@bob:localhost", "ALICE")],
[("alice", true)],
"Alice and Bob share a display name that is differently capitalized"
);
assert_ambiguity!(
[("@alice:localhost", "alice"), ("@bob:localhost", "аlice")],
[("alice", true)],
"Bob tries to impersonate Alice using a cyrilic а"
);
assert_ambiguity!(
[("@alice:localhost", "@bob:localhost"), ("@bob:localhost", "аlice")],
[("@bob:localhost", true)],
"Alice tries to impersonate bob using an mxid"
);
assert_ambiguity!(
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "𝒮𝒶𝒽𝒶𝓈𝓇𝒶𝒽𝓁𝒶")],
[("Sahasrahla", true)],
"Bob tries to impersonate Alice using scripture symbols"
);
assert_ambiguity!(
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "𝔖𝔞𝔥𝔞𝔰𝔯𝔞𝔥𝔩𝔞")],
[("Sahasrahla", true)],
"Bob tries to impersonate Alice using fraktur symbols"
);
assert_ambiguity!(
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Ⓢⓐⓗⓐⓢⓡⓐⓗⓛⓐ")],
[("Sahasrahla", true)],
"Bob tries to impersonate Alice using circled symbols"
);
assert_ambiguity!(
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "🅂🄰🄷🄰🅂🅁🄰🄷🄻🄰")],
[("Sahasrahla", true)],
"Bob tries to impersonate Alice using squared symbols"
);
assert_ambiguity!(
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "")],
[("Sahasrahla", true)],
"Bob tries to impersonate Alice using big unicode letters"
);
assert_ambiguity!(
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "\u{202e}alharsahas")],
[("Sahasrahla", true)],
"Bob tries to impersonate Alice using left to right shenanigans"
);
assert_ambiguity!(
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Sa̴hasrahla")],
[("Sahasrahla", true)],
"Bob tries to impersonate Alice using a diacritical mark"
);
assert_ambiguity!(
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Sahas\u{200B}rahla")],
[("Sahasrahla", true)],
"Bob tries to impersonate Alice using a zero-width space"
);
assert_ambiguity!(
[("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Sahas\u{200D}rahla")],
[("Sahasrahla", true)],
"Bob tries to impersonate Alice using a zero-width space"
);
assert_ambiguity!(
[("@alice:localhost", "ff"), ("@bob:localhost", "\u{FB00}")],
[("ff", true)],
"Bob tries to impersonate Alice using a ligature"
);
}
}

View File

@@ -1,6 +1,6 @@
//! Trait and macro of integration tests for StateStore implementations.
use std::collections::{BTreeMap, BTreeSet};
use std::collections::{BTreeMap, BTreeSet, HashMap};
use assert_matches::assert_matches;
use assert_matches2::assert_let;
@@ -34,7 +34,8 @@ use ruma::{
use serde_json::{json, value::Value as JsonValue};
use super::{
send_queue::SentRequestKey, DependentQueuedRequestKind, DynStateStore, ServerCapabilities,
send_queue::SentRequestKey, DependentQueuedRequestKind, DisplayName, DynStateStore,
ServerCapabilities,
};
use crate::{
deserialized_responses::MemberEvent,
@@ -141,13 +142,15 @@ impl StateStoreIntegrationTests for DynStateStore {
room.handle_state_event(&topic_event);
changes.add_state_event(room_id, topic_event, topic_raw);
let mut room_ambiguity_map = BTreeMap::new();
let mut room_ambiguity_map = HashMap::new();
let mut room_profiles = BTreeMap::new();
let member_json: &JsonValue = &test_json::MEMBER;
let member_event: SyncRoomMemberEvent =
serde_json::from_value(member_json.clone()).unwrap();
let displayname = member_event.as_original().unwrap().content.displayname.clone().unwrap();
let displayname = DisplayName::new(
member_event.as_original().unwrap().content.displayname.as_ref().unwrap(),
);
room_ambiguity_map.insert(displayname.clone(), BTreeSet::from([user_id.to_owned()]));
room_profiles.insert(user_id.to_owned(), (&member_event).into());
@@ -256,6 +259,8 @@ impl StateStoreIntegrationTests for DynStateStore {
async fn test_populate_store(&self) -> Result<()> {
let room_id = room_id();
let user_id = user_id();
let display_name = DisplayName::new("example");
self.populate().await?;
assert!(self.get_kv_data(StateStoreDataKey::SyncToken).await?.is_some());
@@ -290,7 +295,7 @@ impl StateStoreIntegrationTests for DynStateStore {
"Expected to find 1 joined user ids"
);
assert_eq!(
self.get_users_with_display_name(room_id, "example").await?.len(),
self.get_users_with_display_name(room_id, &display_name).await?.len(),
2,
"Expected to find 2 display names for room"
);
@@ -962,6 +967,7 @@ impl StateStoreIntegrationTests for DynStateStore {
async fn test_room_removal(&self) -> Result<()> {
let room_id = room_id();
let user_id = user_id();
let display_name = DisplayName::new("example");
let stripped_room_id = stripped_room_id();
self.populate().await?;
@@ -990,7 +996,7 @@ impl StateStoreIntegrationTests for DynStateStore {
"still joined users found"
);
assert!(
self.get_users_with_display_name(room_id, "example").await?.is_empty(),
self.get_users_with_display_name(room_id, &display_name).await?.is_empty(),
"still display names found"
);
assert!(self
@@ -1145,15 +1151,15 @@ impl StateStoreIntegrationTests for DynStateStore {
async fn test_display_names_saving(&self) {
let room_id = room_id!("!test_display_names_saving:localhost");
let user_id = user_id();
let user_display_name = "User";
let user_display_name = DisplayName::new("User");
let second_user_id = user_id!("@second:localhost");
let third_user_id = user_id!("@third:localhost");
let other_display_name = "Raoul";
let unknown_display_name = "Unknown";
let other_display_name = DisplayName::new("Raoul");
let unknown_display_name = DisplayName::new("Unknown");
// No event in store.
let mut display_names = vec![user_display_name.to_owned()];
let users = self.get_users_with_display_name(room_id, user_display_name).await.unwrap();
let users = self.get_users_with_display_name(room_id, &user_display_name).await.unwrap();
assert!(users.is_empty());
let names = self.get_users_with_display_names(room_id, &display_names).await.unwrap();
assert!(names.is_empty());
@@ -1167,7 +1173,7 @@ impl StateStoreIntegrationTests for DynStateStore {
.insert(user_display_name.to_owned(), [user_id.to_owned()].into());
self.save_changes(&changes).await.unwrap();
let users = self.get_users_with_display_name(room_id, user_display_name).await.unwrap();
let users = self.get_users_with_display_name(room_id, &user_display_name).await.unwrap();
assert_eq!(users.len(), 1);
let names = self.get_users_with_display_names(room_id, &display_names).await.unwrap();
assert_eq!(names.len(), 1);
@@ -1182,9 +1188,9 @@ impl StateStoreIntegrationTests for DynStateStore {
self.save_changes(&changes).await.unwrap();
display_names.push(other_display_name.to_owned());
let users = self.get_users_with_display_name(room_id, user_display_name).await.unwrap();
let users = self.get_users_with_display_name(room_id, &user_display_name).await.unwrap();
assert_eq!(users.len(), 1);
let users = self.get_users_with_display_name(room_id, other_display_name).await.unwrap();
let users = self.get_users_with_display_name(room_id, &other_display_name).await.unwrap();
assert_eq!(users.len(), 2);
let names = self.get_users_with_display_names(room_id, &display_names).await.unwrap();
assert_eq!(names.len(), 2);

View File

@@ -42,7 +42,8 @@ use super::{
StateChanges, StateStore, StoreError,
};
use crate::{
deserialized_responses::RawAnySyncOrStrippedState, store::QueueWedgeError,
deserialized_responses::{DisplayName, RawAnySyncOrStrippedState},
store::QueueWedgeError,
MinimalRoomMemberEvent, RoomMemberships, StateStoreDataKey, StateStoreDataValue,
};
@@ -61,7 +62,7 @@ pub struct MemoryStore {
utd_hook_manager_data: StdRwLock<Option<GrowableBloom>>,
account_data: StdRwLock<HashMap<GlobalAccountDataEventType, Raw<AnyGlobalAccountDataEvent>>>,
profiles: StdRwLock<HashMap<OwnedRoomId, HashMap<OwnedUserId, MinimalRoomMemberEvent>>>,
display_names: StdRwLock<HashMap<OwnedRoomId, HashMap<String, BTreeSet<OwnedUserId>>>>,
display_names: StdRwLock<HashMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>>,
members: StdRwLock<HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>>,
room_info: StdRwLock<HashMap<OwnedRoomId, RoomInfo>>,
room_state: StdRwLock<
@@ -701,7 +702,7 @@ impl StateStore for MemoryStore {
async fn get_users_with_display_name(
&self,
room_id: &RoomId,
display_name: &str,
display_name: &DisplayName,
) -> Result<BTreeSet<OwnedUserId>> {
Ok(self
.display_names
@@ -715,21 +716,18 @@ impl StateStore for MemoryStore {
async fn get_users_with_display_names<'a>(
&self,
room_id: &RoomId,
display_names: &'a [String],
) -> Result<BTreeMap<&'a str, BTreeSet<OwnedUserId>>> {
display_names: &'a [DisplayName],
) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
if display_names.is_empty() {
return Ok(BTreeMap::new());
return Ok(HashMap::new());
}
let read_guard = &self.display_names.read().unwrap();
let Some(room_names) = read_guard.get(room_id) else {
return Ok(BTreeMap::new());
return Ok(HashMap::new());
};
Ok(display_names
.iter()
.filter_map(|n| room_names.get(n).map(|d| (n.as_str(), d.clone())))
.collect())
Ok(display_names.iter().filter_map(|n| room_names.get(n).map(|d| (n, d.clone()))).collect())
}
async fn get_account_data_event(

View File

@@ -21,7 +21,7 @@
//! store.
use std::{
collections::{BTreeMap, BTreeSet},
collections::{BTreeMap, BTreeSet, HashMap},
fmt,
ops::Deref,
result::Result as StdResult,
@@ -58,6 +58,7 @@ use tokio::sync::{broadcast, Mutex, RwLock};
use tracing::warn;
use crate::{
deserialized_responses::DisplayName,
event_cache::store as event_cache_store,
rooms::{normal::RoomInfoNotableUpdate, RoomInfo, RoomState},
MinimalRoomMemberEvent, Room, RoomStateFilter, SessionMeta,
@@ -384,7 +385,7 @@ pub struct StateChanges {
/// A map from room id to a map of a display name and a set of user ids that
/// share that display name in the given room.
pub ambiguity_maps: BTreeMap<OwnedRoomId, BTreeMap<String, BTreeSet<OwnedUserId>>>,
pub ambiguity_maps: BTreeMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>,
}
impl StateChanges {

View File

@@ -14,7 +14,7 @@
use std::{
borrow::Borrow,
collections::{BTreeMap, BTreeSet},
collections::{BTreeMap, BTreeSet, HashMap},
fmt,
sync::Arc,
};
@@ -46,7 +46,9 @@ use super::{
StoreError,
};
use crate::{
deserialized_responses::{RawAnySyncOrStrippedState, RawMemberEvent, RawSyncOrStrippedState},
deserialized_responses::{
DisplayName, RawAnySyncOrStrippedState, RawMemberEvent, RawSyncOrStrippedState,
},
MinimalRoomMemberEvent, RoomInfo, RoomMemberships,
};
@@ -206,7 +208,7 @@ pub trait StateStore: AsyncTraitDeps {
async fn get_users_with_display_name(
&self,
room_id: &RoomId,
display_name: &str,
display_name: &DisplayName,
) -> Result<BTreeSet<OwnedUserId>, Self::Error>;
/// Get all the users that use the given display names in the given room.
@@ -219,8 +221,8 @@ pub trait StateStore: AsyncTraitDeps {
async fn get_users_with_display_names<'a>(
&self,
room_id: &RoomId,
display_names: &'a [String],
) -> Result<BTreeMap<&'a str, BTreeSet<OwnedUserId>>, Self::Error>;
display_names: &'a [DisplayName],
) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>, Self::Error>;
/// Get an event out of the account data store.
///
@@ -567,7 +569,7 @@ impl<T: StateStore> StateStore for EraseStateStoreError<T> {
async fn get_users_with_display_name(
&self,
room_id: &RoomId,
display_name: &str,
display_name: &DisplayName,
) -> Result<BTreeSet<OwnedUserId>, Self::Error> {
self.0.get_users_with_display_name(room_id, display_name).await.map_err(Into::into)
}
@@ -575,8 +577,8 @@ impl<T: StateStore> StateStore for EraseStateStoreError<T> {
async fn get_users_with_display_names<'a>(
&self,
room_id: &RoomId,
display_names: &'a [String],
) -> Result<BTreeMap<&'a str, BTreeSet<OwnedUserId>>, Self::Error> {
display_names: &'a [DisplayName],
) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>, Self::Error> {
self.0.get_users_with_display_names(room_id, display_names).await.map_err(Into::into)
}

View File

@@ -13,7 +13,7 @@
// limitations under the License.
use std::{
collections::{BTreeMap, BTreeSet, HashSet},
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
sync::Arc,
};
@@ -23,7 +23,7 @@ use gloo_utils::format::JsValueSerdeExt;
use growable_bloom_filter::GrowableBloom;
use indexed_db_futures::prelude::*;
use matrix_sdk_base::{
deserialized_responses::RawAnySyncOrStrippedState,
deserialized_responses::{DisplayName, RawAnySyncOrStrippedState},
store::{
ChildTransactionId, ComposerDraft, DependentQueuedRequest, DependentQueuedRequestKind,
QueuedRequest, QueuedRequestKind, SentRequestKey, SerializableEventContent,
@@ -665,7 +665,15 @@ impl_state_store!({
let store = tx.object_store(keys::DISPLAY_NAMES)?;
for (room_id, ambiguity_maps) in &changes.ambiguity_maps {
for (display_name, map) in ambiguity_maps {
let key = self.encode_key(keys::DISPLAY_NAMES, (room_id, display_name));
let key = self.encode_key(
keys::DISPLAY_NAMES,
(
room_id,
display_name
.as_normalized_str()
.unwrap_or_else(|| display_name.as_raw_str()),
),
);
store.put_key_val(&key, &self.serialize_value(&map)?)?;
}
@@ -1122,12 +1130,18 @@ impl_state_store!({
async fn get_users_with_display_name(
&self,
room_id: &RoomId,
display_name: &str,
display_name: &DisplayName,
) -> Result<BTreeSet<OwnedUserId>> {
self.inner
.transaction_on_one_with_mode(keys::DISPLAY_NAMES, IdbTransactionMode::Readonly)?
.object_store(keys::DISPLAY_NAMES)?
.get(&self.encode_key(keys::DISPLAY_NAMES, (room_id, display_name)))?
.get(&self.encode_key(
keys::DISPLAY_NAMES,
(
room_id,
display_name.as_normalized_str().unwrap_or_else(|| display_name.as_raw_str()),
),
))?
.await?
.map(|f| self.deserialize_value::<BTreeSet<OwnedUserId>>(&f))
.unwrap_or_else(|| Ok(Default::default()))
@@ -1136,10 +1150,12 @@ impl_state_store!({
async fn get_users_with_display_names<'a>(
&self,
room_id: &RoomId,
display_names: &'a [String],
) -> Result<BTreeMap<&'a str, BTreeSet<OwnedUserId>>> {
display_names: &'a [DisplayName],
) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
let mut map = HashMap::new();
if display_names.is_empty() {
return Ok(BTreeMap::new());
return Ok(map);
}
let txn = self
@@ -1147,15 +1163,24 @@ impl_state_store!({
.transaction_on_one_with_mode(keys::DISPLAY_NAMES, IdbTransactionMode::Readonly)?;
let store = txn.object_store(keys::DISPLAY_NAMES)?;
let mut map = BTreeMap::new();
for display_name in display_names {
if let Some(user_ids) = store
.get(&self.encode_key(keys::DISPLAY_NAMES, (room_id, display_name)))?
.get(
&self.encode_key(
keys::DISPLAY_NAMES,
(
room_id,
display_name
.as_normalized_str()
.unwrap_or_else(|| display_name.as_raw_str()),
),
),
)?
.await?
.map(|f| self.deserialize_value::<BTreeSet<OwnedUserId>>(&f))
.transpose()?
{
map.insert(display_name.as_ref(), user_ids);
map.insert(display_name, user_ids);
}
}

View File

@@ -1,6 +1,6 @@
use std::{
borrow::Cow,
collections::{BTreeMap, BTreeSet},
collections::{BTreeMap, BTreeSet, HashMap},
fmt, iter,
path::Path,
sync::Arc,
@@ -9,7 +9,7 @@ use std::{
use async_trait::async_trait;
use deadpool_sqlite::{Object as SqliteAsyncConn, Pool as SqlitePool, Runtime};
use matrix_sdk_base::{
deserialized_responses::{RawAnySyncOrStrippedState, SyncOrStrippedState},
deserialized_responses::{DisplayName, RawAnySyncOrStrippedState, SyncOrStrippedState},
store::{
migration_helpers::RoomInfoV1, ChildTransactionId, DependentQueuedRequest,
DependentQueuedRequestKind, QueueWedgeError, QueuedRequest, QueuedRequestKind,
@@ -1305,13 +1305,34 @@ impl StateStore for SqliteStateStore {
let room_id = this.encode_key(keys::DISPLAY_NAME, room_id);
for (name, user_ids) in display_names {
let name = this.encode_key(keys::DISPLAY_NAME, name);
let encoded_name = this.encode_key(
keys::DISPLAY_NAME,
name.as_normalized_str().unwrap_or_else(|| name.as_raw_str()),
);
let data = this.serialize_json(&user_ids)?;
if user_ids.is_empty() {
txn.remove_display_name(&room_id, &name)?;
txn.remove_display_name(&room_id, &encoded_name)?;
// We can't do a migration to merge the previously distinct buckets of
// user IDs since the display names themselves are hashed before they
// are persisted in the store. So the store will always retain two
// buckets: one for raw display names and one for normalised ones.
//
// We therefore do the next best thing, which is a sort of a soft
// migration: we fetch both the raw and normalised buckets, then merge
// the user IDs contained in them into a separate, temporary merged
// bucket. The SDK then operates on the merged buckets exclusively. See
// the comment in `get_users_with_display_names` for details.
//
// If the merged bucket is empty, that must mean that both the raw and
// normalised buckets were also empty, so we can remove both from the
// store.
let raw_name = this.encode_key(keys::DISPLAY_NAME, name.as_raw_str());
txn.remove_display_name(&room_id, &raw_name)?;
} else {
txn.set_display_name(&room_id, &name, &data)?;
// We only create new buckets with the normalized display name.
txn.set_display_name(&room_id, &encoded_name, &data)?;
}
}
}
@@ -1500,10 +1521,13 @@ impl StateStore for SqliteStateStore {
async fn get_users_with_display_name(
&self,
room_id: &RoomId,
display_name: &str,
display_name: &DisplayName,
) -> Result<BTreeSet<OwnedUserId>> {
let room_id = self.encode_key(keys::DISPLAY_NAME, room_id);
let names = vec![self.encode_key(keys::DISPLAY_NAME, display_name)];
let names = vec![self.encode_key(
keys::DISPLAY_NAME,
display_name.as_normalized_str().unwrap_or_else(|| display_name.as_raw_str()),
)];
Ok(self
.acquire()
@@ -1520,33 +1544,49 @@ impl StateStore for SqliteStateStore {
async fn get_users_with_display_names<'a>(
&self,
room_id: &RoomId,
display_names: &'a [String],
) -> Result<BTreeMap<&'a str, BTreeSet<OwnedUserId>>> {
display_names: &'a [DisplayName],
) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
let mut result = HashMap::new();
if display_names.is_empty() {
return Ok(BTreeMap::new());
return Ok(result);
}
let room_id = self.encode_key(keys::DISPLAY_NAME, room_id);
let mut names_map = display_names
.iter()
.map(|n| (self.encode_key(keys::DISPLAY_NAME, n), n.as_ref()))
.flat_map(|display_name| {
// We encode the display name as the `raw_str()` and the normalized string.
//
// This is for compatibility reasons since:
// 1. Previously "Alice" and "alice" were considered to be distinct display
// names, while we now consider them to be the same so we need to merge the
// previously distinct buckets of user IDs.
// 2. We can't do a migration to merge the previously distinct buckets of user
// IDs since the display names itself are hashed before they are persisted
// in the store.
let raw =
(self.encode_key(keys::DISPLAY_NAME, display_name.as_raw_str()), display_name);
let normalized = display_name.as_normalized_str().map(|normalized| {
(self.encode_key(keys::DISPLAY_NAME, normalized), display_name)
});
iter::once(raw).chain(normalized.into_iter())
})
.collect::<BTreeMap<_, _>>();
let names = names_map.keys().cloned().collect();
self.acquire()
.await?
.get_display_names(room_id, names)
.await?
.into_iter()
.map(|(name, data)| {
Ok((
names_map
.remove(name.as_slice())
.expect("returned display names were requested"),
self.deserialize_json(&data)?,
))
})
.collect::<Result<BTreeMap<_, _>>>()
for (name, data) in
self.acquire().await?.get_display_names(room_id, names).await?.into_iter()
{
let display_name =
names_map.remove(name.as_slice()).expect("returned display names were requested");
let user_ids: BTreeSet<_> = self.deserialize_json(&data)?;
result.entry(display_name).or_insert_with(BTreeSet::new).extend(user_ids);
}
Ok(result)
}
async fn get_account_data_event(