base: Add state store method to fetch several state events at once

Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
Kévin Commaille
2023-06-01 16:58:35 +02:00
committed by Jonas Platte
parent ae4ead5550
commit f740195eb6
5 changed files with 368 additions and 86 deletions

View File

@@ -26,7 +26,7 @@ use ruma::{
},
mxc_uri, room_id,
serde::Raw,
uint, user_id, EventId, OwnedEventId, RoomId, UserId,
uint, user_id, EventId, OwnedEventId, OwnedUserId, RoomId, UserId,
};
use serde_json::{json, value::Value as JsonValue};
@@ -374,8 +374,21 @@ impl StateStoreIntegrationTests for DynStateStore {
async fn test_member_saving(&self) {
let room_id = room_id!("!test_member_saving:localhost");
let user_id = user_id();
let second_user_id = user_id!("@second:localhost");
let third_user_id = user_id!("@third:localhost");
let unknown_user_id = user_id!("@unknown:localhost");
// No event in store.
assert!(self.get_member_event(room_id, user_id).await.unwrap().is_none());
let member_events = self
.get_state_events_for_keys_static::<RoomMemberEventContent, _, _>(
room_id,
&[user_id.to_owned()],
)
.await;
assert!(member_events.unwrap().is_empty());
// One event in store.
let mut changes = StateChanges::default();
changes
.state
@@ -384,12 +397,71 @@ impl StateStoreIntegrationTests for DynStateStore {
.entry(StateEventType::RoomMember)
.or_default()
.insert(user_id.into(), membership_event().cast());
self.save_changes(&changes).await.unwrap();
assert!(self.get_member_event(room_id, user_id).await.unwrap().is_some());
assert!(self.get_member_event(room_id, user_id).await.unwrap().is_some());
let member_events = self
.get_state_events_for_keys_static::<RoomMemberEventContent, _, _>(
room_id,
&[user_id.to_owned()],
)
.await;
assert_eq!(member_events.unwrap().len(), 1);
let members = self.get_user_ids(room_id, RoomMemberships::empty()).await.unwrap();
assert!(!members.is_empty(), "We expected to find members for the room")
assert_eq!(members.len(), 1, "We expected to find members for the room");
// Several events in store.
let mut changes = StateChanges::default();
let changes_members = changes
.state
.entry(room_id.to_owned())
.or_default()
.entry(StateEventType::RoomMember)
.or_default();
changes_members.insert(
second_user_id.into(),
custom_membership_event(second_user_id, event_id!("$second_member_event")).cast(),
);
changes_members.insert(
third_user_id.into(),
custom_membership_event(third_user_id, event_id!("$third_member_event")).cast(),
);
self.save_changes(&changes).await.unwrap();
assert!(self.get_member_event(room_id, second_user_id).await.unwrap().is_some());
assert!(self.get_member_event(room_id, third_user_id).await.unwrap().is_some());
let member_events = self
.get_state_events_for_keys_static::<RoomMemberEventContent, _, _>(
room_id,
&[user_id.to_owned(), second_user_id.to_owned(), third_user_id.to_owned()],
)
.await;
assert_eq!(member_events.unwrap().len(), 3);
let members = self.get_user_ids(room_id, RoomMemberships::empty()).await.unwrap();
assert_eq!(members.len(), 3, "We expected to find members for the room");
// Several events in store with one unknown.
let member_events = self
.get_state_events_for_keys_static::<RoomMemberEventContent, _, _>(
room_id,
&[
user_id.to_owned(),
second_user_id.to_owned(),
third_user_id.to_owned(),
unknown_user_id.to_owned(),
],
)
.await;
assert_eq!(member_events.unwrap().len(), 3);
// Empty user IDs list.
let member_events = self
.get_state_events_for_keys_static::<RoomMemberEventContent, OwnedUserId, _>(
room_id,
&[],
)
.await;
assert!(member_events.unwrap().is_empty());
}
async fn test_filter_saving(&self) {
@@ -446,8 +518,21 @@ impl StateStoreIntegrationTests for DynStateStore {
async fn test_stripped_member_saving(&self) {
let room_id = room_id!("!test_stripped_member_saving:localhost");
let user_id = user_id();
let second_user_id = user_id!("@second:localhost");
let third_user_id = user_id!("@third:localhost");
let unknown_user_id = user_id!("@unknown:localhost");
// No event in store.
assert!(self.get_member_event(room_id, user_id).await.unwrap().is_none());
let member_events = self
.get_state_events_for_keys_static::<RoomMemberEventContent, _, _>(
room_id,
&[user_id.to_owned()],
)
.await;
assert!(member_events.unwrap().is_empty());
// One event in store.
let mut changes = StateChanges::default();
changes
.stripped_state
@@ -456,12 +541,67 @@ impl StateStoreIntegrationTests for DynStateStore {
.entry(StateEventType::RoomMember)
.or_default()
.insert(user_id.into(), stripped_membership_event().cast());
self.save_changes(&changes).await.unwrap();
assert!(self.get_member_event(room_id, user_id).await.unwrap().is_some());
assert!(self.get_member_event(room_id, user_id).await.unwrap().is_some());
let member_events = self
.get_state_events_for_keys_static::<RoomMemberEventContent, _, _>(
room_id,
&[user_id.to_owned()],
)
.await;
assert_eq!(member_events.unwrap().len(), 1);
let members = self.get_user_ids(room_id, RoomMemberships::empty()).await.unwrap();
assert!(!members.is_empty(), "We expected to find members for the room")
assert_eq!(members.len(), 1, "We expected to find members for the room");
// Several events in store.
let mut changes = StateChanges::default();
let changes_members = changes
.stripped_state
.entry(room_id.to_owned())
.or_default()
.entry(StateEventType::RoomMember)
.or_default();
changes_members
.insert(second_user_id.into(), custom_stripped_membership_event(second_user_id).cast());
changes_members
.insert(third_user_id.into(), custom_stripped_membership_event(third_user_id).cast());
self.save_changes(&changes).await.unwrap();
assert!(self.get_member_event(room_id, second_user_id).await.unwrap().is_some());
assert!(self.get_member_event(room_id, third_user_id).await.unwrap().is_some());
let member_events = self
.get_state_events_for_keys_static::<RoomMemberEventContent, _, _>(
room_id,
&[user_id.to_owned(), second_user_id.to_owned(), third_user_id.to_owned()],
)
.await;
assert_eq!(member_events.unwrap().len(), 3);
let members = self.get_user_ids(room_id, RoomMemberships::empty()).await.unwrap();
assert_eq!(members.len(), 3, "We expected to find members for the room");
// Several events in store with one unknown.
let member_events = self
.get_state_events_for_keys_static::<RoomMemberEventContent, _, _>(
room_id,
&[
user_id.to_owned(),
second_user_id.to_owned(),
third_user_id.to_owned(),
unknown_user_id.to_owned(),
],
)
.await;
assert_eq!(member_events.unwrap().len(), 3);
// Empty user IDs list.
let member_events = self
.get_state_events_for_keys_static::<RoomMemberEventContent, OwnedUserId, _>(
room_id,
&[],
)
.await;
assert!(member_events.unwrap().is_empty());
}
async fn test_power_level_saving(&self) {
@@ -1044,10 +1184,10 @@ fn custom_stripped_membership_event(user_id: &UserId) -> Raw<StrippedRoomMemberE
}
fn membership_event() -> Raw<SyncRoomMemberEvent> {
custom_membership_event(user_id(), event_id!("$h29iv0s8:example.com").to_owned())
custom_membership_event(user_id(), event_id!("$h29iv0s8:example.com"))
}
fn custom_membership_event(user_id: &UserId, event_id: OwnedEventId) -> Raw<SyncRoomMemberEvent> {
fn custom_membership_event(user_id: &UserId, event_id: &EventId) -> Raw<SyncRoomMemberEvent> {
let ev_json = json!({
"type": "m.room.member",
"content": RoomMemberEventContent::new(MembershipState::Join),

View File

@@ -363,33 +363,6 @@ impl MemoryStore {
Ok(self.presence.get(user_id).map(|p| p.clone()))
}
async fn get_state_event(
&self,
room_id: &RoomId,
event_type: StateEventType,
state_key: &str,
) -> Result<Option<RawAnySyncOrStrippedState>> {
if let Some(e) = self
.stripped_room_state
.get(room_id)
.as_ref()
.and_then(|events| events.get(&event_type))
.and_then(|m| m.get(state_key).map(|m| m.clone()))
{
Ok(Some(RawAnySyncOrStrippedState::Stripped(e)))
} else if let Some(e) = self
.room_state
.get(room_id)
.as_ref()
.and_then(|events| events.get(&event_type))
.and_then(|m| m.get(state_key).map(|m| m.clone()))
{
Ok(Some(RawAnySyncOrStrippedState::Sync(e)))
} else {
Ok(None)
}
}
async fn get_state_events(
&self,
room_id: &RoomId,
@@ -412,6 +385,40 @@ impl MemoryStore {
}
}
async fn get_state_events_for_keys(
&self,
room_id: &RoomId,
event_type: StateEventType,
state_keys: &[&str],
) -> Result<Vec<RawAnySyncOrStrippedState>> {
if let Some(stripped_state_events) = self
.stripped_room_state
.get(room_id)
.as_ref()
.and_then(|events| events.get(&event_type))
{
Ok(state_keys
.iter()
.filter_map(|k| {
stripped_state_events
.get(*k)
.map(|e| RawAnySyncOrStrippedState::Stripped(e.clone()))
})
.collect())
} else if let Some(sync_state_events) =
self.room_state.get(room_id).as_ref().and_then(|events| events.get(&event_type))
{
Ok(state_keys
.iter()
.filter_map(|k| {
sync_state_events.get(*k).map(|e| RawAnySyncOrStrippedState::Sync(e.clone()))
})
.collect())
} else {
Ok(Vec::new())
}
}
async fn get_profile(
&self,
room_id: &RoomId,
@@ -583,7 +590,11 @@ impl StateStore for MemoryStore {
event_type: StateEventType,
state_key: &str,
) -> Result<Option<RawAnySyncOrStrippedState>> {
self.get_state_event(room_id, event_type, state_key).await
Ok(self
.get_state_events_for_keys(room_id, event_type, &[state_key])
.await?
.into_iter()
.next())
}
async fn get_state_events(
@@ -594,6 +605,15 @@ impl StateStore for MemoryStore {
self.get_state_events(room_id, event_type).await
}
async fn get_state_events_for_keys(
&self,
room_id: &RoomId,
event_type: StateEventType,
state_keys: &[&str],
) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
self.get_state_events_for_keys(room_id, event_type, state_keys).await
}
async fn get_profile(
&self,
room_id: &RoomId,

View File

@@ -117,6 +117,23 @@ pub trait StateStore: AsyncTraitDeps {
event_type: StateEventType,
) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error>;
/// Get a list of state events for a given room, `StateEventType`, and the
/// given state keys.
///
/// # Arguments
///
/// * `room_id` - The id of the room to find events for.
///
/// * `event_type` - The event type.
///
/// * `state_keys` - The list of state keys to find.
async fn get_state_events_for_keys(
&self,
room_id: &RoomId,
event_type: StateEventType,
state_keys: &[&str],
) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error>;
/// Get the current profile for the given user in the given room.
///
/// # Arguments
@@ -370,6 +387,15 @@ impl<T: StateStore> StateStore for EraseStateStoreError<T> {
self.0.get_state_events(room_id, event_type).await.map_err(Into::into)
}
async fn get_state_events_for_keys(
&self,
room_id: &RoomId,
event_type: StateEventType,
state_keys: &[&str],
) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
self.0.get_state_events_for_keys(room_id, event_type, state_keys).await.map_err(Into::into)
}
async fn get_profile(
&self,
room_id: &RoomId,
@@ -563,6 +589,39 @@ pub trait StateStoreExt: StateStore {
.collect())
}
/// Get a list of state events of a statically-known type for a given room
/// and given state keys.
///
/// # Arguments
///
/// * `room_id` - The id of the room to find events for.
///
/// * `state_keys` - The list of state keys to find.
async fn get_state_events_for_keys_static<'a, C, K, I>(
&self,
room_id: &RoomId,
state_keys: I,
) -> Result<Vec<RawSyncOrStrippedState<C>>, Self::Error>
where
C: StaticEventContent + StaticStateEventContent + RedactContent,
C::StateKey: Borrow<K>,
C::Redacted: RedactedStateEventContent,
K: AsRef<str> + Sized + Sync + 'a,
I: IntoIterator<Item = &'a K> + Send,
I::IntoIter: Send,
{
Ok(self
.get_state_events_for_keys(
room_id,
C::TYPE.into(),
&state_keys.into_iter().map(|k| k.as_ref()).collect::<Vec<_>>(),
)
.await?
.into_iter()
.map(|raw| raw.cast())
.collect())
}
/// Get an event of a statically-known type from the account data store.
async fn get_account_data_event_static<C>(
&self,

View File

@@ -802,29 +802,11 @@ impl_state_store!({
event_type: StateEventType,
state_key: &str,
) -> Result<Option<RawAnySyncOrStrippedState>> {
if let Some(e) = self
.inner
.transaction_on_one_with_mode(keys::STRIPPED_ROOM_STATE, IdbTransactionMode::Readonly)?
.object_store(keys::STRIPPED_ROOM_STATE)?
.get(&self.encode_key(keys::STRIPPED_ROOM_STATE, (room_id, &event_type, state_key)))?
Ok(self
.get_state_events_for_keys(room_id, event_type, &[state_key])
.await?
.map(|f| self.deserialize_event(&f))
.transpose()?
{
Ok(Some(RawAnySyncOrStrippedState::Stripped(e)))
} else if let Some(e) = self
.inner
.transaction_on_one_with_mode(keys::ROOM_STATE, IdbTransactionMode::Readonly)?
.object_store(keys::ROOM_STATE)?
.get(&self.encode_key(keys::ROOM_STATE, (room_id, event_type, state_key)))?
.await?
.map(|f| self.deserialize_event(&f))
.transpose()?
{
Ok(Some(RawAnySyncOrStrippedState::Sync(e)))
} else {
Ok(None)
}
.into_iter()
.next())
}
async fn get_state_events(
@@ -862,6 +844,64 @@ impl_state_store!({
.collect::<Vec<_>>())
}
async fn get_state_events_for_keys(
&self,
room_id: &RoomId,
event_type: StateEventType,
state_keys: &[&str],
) -> Result<Vec<RawAnySyncOrStrippedState>> {
if state_keys.is_empty() {
return Ok(Vec::new());
}
let mut events = Vec::with_capacity(state_keys.len());
{
let txn = self.inner.transaction_on_one_with_mode(
keys::STRIPPED_ROOM_STATE,
IdbTransactionMode::Readonly,
)?;
let store = txn.object_store(keys::STRIPPED_ROOM_STATE)?;
for state_key in state_keys {
if let Some(event) =
store
.get(&self.encode_key(
keys::STRIPPED_ROOM_STATE,
(room_id, &event_type, state_key),
))?
.await?
.map(|f| self.deserialize_event(&f))
.transpose()?
{
events.push(RawAnySyncOrStrippedState::Stripped(event));
}
}
if !events.is_empty() {
return Ok(events);
}
}
let txn = self
.inner
.transaction_on_one_with_mode(keys::ROOM_STATE, IdbTransactionMode::Readonly)?;
let store = txn.object_store(keys::ROOM_STATE)?;
for state_key in state_keys {
if let Some(event) = store
.get(&self.encode_key(keys::ROOM_STATE, (room_id, &event_type, state_key)))?
.await?
.map(|f| self.deserialize_event(&f))
.transpose()?
{
events.push(RawAnySyncOrStrippedState::Sync(event));
}
}
Ok(events)
}
async fn get_profile(
&self,
room_id: &RoomId,

View File

@@ -577,21 +577,26 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt {
}
}
async fn get_maybe_stripped_state_event(
async fn get_maybe_stripped_state_events_for_keys(
&self,
room_id: Key,
event_type: Key,
state_key: Key,
) -> Result<Option<(bool, Vec<u8>)>> {
state_keys: Vec<Key>,
) -> Result<Vec<(bool, Vec<u8>)>> {
let sql_params = vec!["?"; state_keys.len()].join(", ");
let sql = format!(
"SELECT stripped, data FROM state_event
WHERE room_id = ? AND event_type = ? AND state_key IN ({sql_params})"
);
let params = [room_id, event_type].into_iter().chain(state_keys);
Ok(self
.query_row(
"SELECT stripped, data FROM state_event
WHERE room_id = ? AND event_type = ? AND state_key = ?",
(room_id, event_type, state_key),
|row| Ok((row.get(0)?, row.get(1)?)),
)
.await
.optional()?)
.prepare(sql, move |mut stmt| {
stmt.query(rusqlite::params_from_iter(params))?
.mapped(|row| Ok((row.get(0)?, row.get(1)?)))
.collect()
})
.await?)
}
async fn get_maybe_stripped_state_events(
@@ -1102,23 +1107,11 @@ impl StateStore for SqliteStateStore {
event_type: StateEventType,
state_key: &str,
) -> Result<Option<RawAnySyncOrStrippedState>> {
let room_id = self.encode_key(keys::STATE_EVENT, room_id);
let event_type = self.encode_key(keys::STATE_EVENT, event_type.to_string());
let state_key = self.encode_key(keys::STATE_EVENT, state_key);
self.acquire()
Ok(self
.get_state_events_for_keys(room_id, event_type, &[state_key])
.await?
.get_maybe_stripped_state_event(room_id, event_type, state_key)
.await?
.map(|(stripped, data)| {
let ev = if stripped {
RawAnySyncOrStrippedState::Stripped(self.deserialize_json(&data)?)
} else {
RawAnySyncOrStrippedState::Sync(self.deserialize_json(&data)?)
};
Ok(ev)
})
.transpose()
.into_iter()
.next())
}
async fn get_state_events(
@@ -1145,6 +1138,36 @@ impl StateStore for SqliteStateStore {
.collect()
}
async fn get_state_events_for_keys(
&self,
room_id: &RoomId,
event_type: StateEventType,
state_keys: &[&str],
) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
if state_keys.is_empty() {
return Ok(Vec::new());
}
let room_id = self.encode_key(keys::STATE_EVENT, room_id);
let event_type = self.encode_key(keys::STATE_EVENT, event_type.to_string());
let state_keys = state_keys.iter().map(|k| self.encode_key(keys::STATE_EVENT, k)).collect();
self.acquire()
.await?
.get_maybe_stripped_state_events_for_keys(room_id, event_type, state_keys)
.await?
.into_iter()
.map(|(stripped, data)| {
let ev = if stripped {
RawAnySyncOrStrippedState::Stripped(self.deserialize_json(&data)?)
} else {
RawAnySyncOrStrippedState::Sync(self.deserialize_json(&data)?)
};
Ok(ev)
})
.collect()
}
async fn get_profile(
&self,
room_id: &RoomId,