feat(sdk): Add Client::add_room_event_handler

This commit is contained in:
Jonas Platte
2022-07-29 17:18:34 +02:00
committed by Jonas Platte
parent 78169d516e
commit 0716f6afaa
2 changed files with 157 additions and 28 deletions

View File

@@ -84,8 +84,8 @@ use crate::{
config::RequestConfig,
error::{HttpError, HttpResult},
event_handler::{
EventHandler, EventHandlerFn, EventHandlerFut, EventHandlerHandle, EventHandlerResult,
EventHandlerWrapper, EventKind, SyncEvent,
EventHandler, EventHandlerFn, EventHandlerFut, EventHandlerHandle, EventHandlerKey,
EventHandlerResult, EventHandlerWrapper, SyncEvent,
},
http_client::HttpClient,
room, Account, Error, Result,
@@ -106,7 +106,7 @@ const DEFAULT_UPLOAD_SPEED: u64 = 125_000;
/// 5 min minimal upload request timeout, used to clamp the request timeout.
const MIN_UPLOAD_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 5);
type EventHandlerMap = BTreeMap<(EventKind, &'static str), Vec<EventHandlerWrapper>>;
type EventHandlerMap = BTreeMap<EventHandlerKey<'static>, Vec<EventHandlerWrapper>>;
type NotificationHandlerFut = EventHandlerFut;
type NotificationHandlerFn =
@@ -455,8 +455,38 @@ impl Client {
H: EventHandler<Ev, Ctx>,
<H::Future as Future>::Output: EventHandlerResult,
{
let key = (Ev::KIND, Ev::TYPE);
self.add_event_handler_impl(handler, None).await
}
/// Register a handler for a specific room, and event type.
///
/// This method works the same way as
/// [`add_event_handler`][Self::add_event_handler], except that the handler
/// will only be called for events in the room with the specified ID. See
/// that method for more details on event handler functions.
pub async fn add_room_event_handler<Ev, Ctx, H>(
&self,
room_id: &RoomId,
handler: H,
) -> EventHandlerHandle
where
Ev: SyncEvent + DeserializeOwned + Send + 'static,
H: EventHandler<Ev, Ctx>,
<H::Future as Future>::Output: EventHandlerResult,
{
self.add_event_handler_impl(handler, Some(room_id.to_owned())).await
}
async fn add_event_handler_impl<Ev, Ctx, H>(
&self,
handler: H,
room_id: Option<OwnedRoomId>,
) -> EventHandlerHandle
where
Ev: SyncEvent + DeserializeOwned + Send + 'static,
H: EventHandler<Ev, Ctx>,
<H::Future as Future>::Output: EventHandlerResult,
{
let handler_fn: Box<EventHandlerFn> = Box::new(move |data| {
let maybe_fut = serde_json::from_str(data.raw.get())
.map(|ev| handler.clone().handle_event(ev, data));
@@ -484,14 +514,13 @@ impl Client {
});
let handler_id = self.inner.event_handler_counter.fetch_add(1, SeqCst);
let handle = EventHandlerHandle { handler_id, ev_id: key };
let handle = EventHandlerHandle { handler_id, ev_kind: Ev::KIND, ev_type: Ev::TYPE };
self.inner
.event_handlers
.write()
.await
.entry(key)
.entry(handle.to_key(room_id))
.or_default()
.push(EventHandlerWrapper { handler_fn, handle });
@@ -560,9 +589,10 @@ impl Client {
/// # });
/// ```
pub async fn remove_event_handler(&self, handle: EventHandlerHandle) {
let key = handle.to_key(None);
let mut event_handlers = self.inner.event_handlers.write().await;
if let btree_map::Entry::Occupied(mut entry) = event_handlers.entry(handle.ev_id) {
if let btree_map::Entry::Occupied(mut entry) = event_handlers.entry(key) {
let v = entry.get_mut();
v.retain(|e| e.handle.handler_id != handle.handler_id);

View File

@@ -36,7 +36,7 @@ use std::any::TypeId;
use std::{borrow::Cow, fmt, future::Future, ops::Deref, pin::Pin};
use matrix_sdk_base::deserialized_responses::{EncryptionInfo, SyncRoomEvent};
use ruma::{events::AnySyncStateEvent, serde::Raw};
use ruma::{events::AnySyncStateEvent, serde::Raw, OwnedRoomId};
use serde::Deserialize;
use serde_json::value::RawValue as RawJsonValue;
use tracing::error;
@@ -90,6 +90,13 @@ pub trait SyncEvent {
const TYPE: &'static str;
}
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct EventHandlerKey<'a> {
pub ev_kind: EventKind,
pub ev_type: &'a str,
pub room_id: Option<OwnedRoomId>,
}
pub(crate) struct EventHandlerWrapper {
pub handler_fn: Box<EventHandlerFn>,
pub handle: EventHandlerHandle,
@@ -99,10 +106,17 @@ pub(crate) struct EventHandlerWrapper {
/// [`Client::remove_event_handler`].
#[derive(Clone, Copy, Debug)]
pub struct EventHandlerHandle {
pub(crate) ev_id: (EventKind, &'static str),
pub(crate) ev_kind: EventKind,
pub(crate) ev_type: &'static str,
pub(crate) handler_id: u64,
}
impl EventHandlerHandle {
pub(crate) fn to_key(self, room_id: Option<OwnedRoomId>) -> EventHandlerKey<'static> {
EventHandlerKey { ev_kind: self.ev_kind, ev_type: self.ev_type, room_id }
}
}
impl EventHandlerContext for EventHandlerHandle {
fn from_data(data: &EventHandlerData<'_>) -> Option<Self> {
Some(data.handle)
@@ -396,26 +410,39 @@ impl Client {
for x in list {
let (raw_event, encryption_info) = get_event_details(x);
let (ev_kind, ev_type) = get_id(raw_event)?;
let event_handler_id = (ev_kind, &*ev_type);
let ev_type: &str = ev_type.as_ref();
// Construct event handler futures
let futures: Vec<_> = self
.event_handlers()
.await
.get(&event_handler_id)
.into_iter()
.flatten()
.map(|handler_wrapper| {
let data = EventHandlerData {
client: self.clone(),
room: room.clone(),
raw: raw_event.json(),
encryption_info,
handle: handler_wrapper.handle,
};
(handler_wrapper.handler_fn)(data)
})
.collect();
let futures: Vec<_> = {
let handlers_lock = self.event_handlers().await;
let non_room_event_handlers = handlers_lock
.get(&EventHandlerKey { ev_kind, ev_type, room_id: None })
.into_iter()
.flatten();
let room_event_handlers = room.iter().flat_map(|r| {
let room_id = Some(r.room_id().to_owned());
handlers_lock
.get(&EventHandlerKey { ev_kind, ev_type, room_id })
.into_iter()
.flatten()
});
non_room_event_handlers
.chain(room_event_handlers)
.map(|handler_wrapper| {
let data = EventHandlerData {
client: self.clone(),
room: room.clone(),
raw: raw_event.json(),
encryption_info,
handle: handler_wrapper.handle,
};
(handler_wrapper.handler_fn)(data)
})
.collect()
};
// Run the event handler futures with the `self.event_handlers` lock
// no longer being held, in order.
@@ -601,6 +628,7 @@ mod tests {
events::{
room::{
member::{OriginalSyncRoomMemberEvent, StrippedRoomMemberEvent},
name::OriginalSyncRoomNameEvent,
power_levels::OriginalSyncRoomPowerLevelsEvent,
},
typing::SyncTypingEvent,
@@ -715,6 +743,77 @@ mod tests {
Ok(())
}
#[async_test]
async fn add_room_event_handler() -> crate::Result<()> {
use std::sync::atomic::{AtomicU8, Ordering::SeqCst};
let client = crate::client::tests::logged_in_client(None).await;
let room_id_a = room_id!("!foo:example.org");
let room_id_b = room_id!("!bar:matrix.org");
let member_count = Arc::new(AtomicU8::new(0));
let power_levels_count = Arc::new(AtomicU8::new(0));
// Room event handlers for member events in both rooms
client
.add_room_event_handler(room_id_a, {
let member_count = member_count.clone();
move |_ev: OriginalSyncRoomMemberEvent, _room: room::Room| {
member_count.fetch_add(1, SeqCst);
future::ready(())
}
})
.await;
client
.add_room_event_handler(room_id_b, {
let member_count = member_count.clone();
move |_ev: OriginalSyncRoomMemberEvent, _room: room::Room| {
member_count.fetch_add(1, SeqCst);
future::ready(())
}
})
.await;
// Power levels event handlers for member events in room A
client
.add_room_event_handler(room_id_a, {
let power_levels_count = power_levels_count.clone();
move |_ev: OriginalSyncRoomPowerLevelsEvent, _client: Client, _room: room::Room| {
power_levels_count.fetch_add(1, SeqCst);
future::ready(())
}
})
.await;
// Room name event handler for room name events in room B
client
.add_room_event_handler(room_id_b, move |_ev: OriginalSyncRoomNameEvent| async {
unreachable!("No room event in room B")
})
.await;
let response = EventBuilder::default()
.add_joined_room(
JoinedRoomBuilder::new(room_id_a)
.add_timeline_event(TimelineTestEvent::Member)
.add_state_event(StateTestEvent::PowerLevels)
.add_state_event(StateTestEvent::RoomName),
)
.add_joined_room(
JoinedRoomBuilder::new(room_id_b)
.add_timeline_event(TimelineTestEvent::Member)
.add_state_event(StateTestEvent::PowerLevels),
)
.build_sync_response();
client.process_sync(response).await?;
assert_eq!(member_count.load(SeqCst), 2);
assert_eq!(power_levels_count.load(SeqCst), 1);
Ok(())
}
#[async_test]
async fn remove_event_handler() -> crate::Result<()> {
use std::sync::atomic::{AtomicU8, Ordering::SeqCst};