From 0716f6afaaacd22f61fc275d832db33680cec455 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Fri, 29 Jul 2022 17:18:34 +0200 Subject: [PATCH] feat(sdk): Add Client::add_room_event_handler --- crates/matrix-sdk/src/client/mod.rs | 46 ++++++-- crates/matrix-sdk/src/event_handler.rs | 139 +++++++++++++++++++++---- 2 files changed, 157 insertions(+), 28 deletions(-) diff --git a/crates/matrix-sdk/src/client/mod.rs b/crates/matrix-sdk/src/client/mod.rs index 594e1dcce..f874802a7 100644 --- a/crates/matrix-sdk/src/client/mod.rs +++ b/crates/matrix-sdk/src/client/mod.rs @@ -84,8 +84,8 @@ use crate::{ config::RequestConfig, error::{HttpError, HttpResult}, event_handler::{ - EventHandler, EventHandlerFn, EventHandlerFut, EventHandlerHandle, EventHandlerResult, - EventHandlerWrapper, EventKind, SyncEvent, + EventHandler, EventHandlerFn, EventHandlerFut, EventHandlerHandle, EventHandlerKey, + EventHandlerResult, EventHandlerWrapper, SyncEvent, }, http_client::HttpClient, room, Account, Error, Result, @@ -106,7 +106,7 @@ const DEFAULT_UPLOAD_SPEED: u64 = 125_000; /// 5 min minimal upload request timeout, used to clamp the request timeout. const MIN_UPLOAD_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 5); -type EventHandlerMap = BTreeMap<(EventKind, &'static str), Vec>; +type EventHandlerMap = BTreeMap, Vec>; type NotificationHandlerFut = EventHandlerFut; type NotificationHandlerFn = @@ -455,8 +455,38 @@ impl Client { H: EventHandler, ::Output: EventHandlerResult, { - let key = (Ev::KIND, Ev::TYPE); + self.add_event_handler_impl(handler, None).await + } + /// Register a handler for a specific room, and event type. + /// + /// This method works the same way as + /// [`add_event_handler`][Self::add_event_handler], except that the handler + /// will only be called for events in the room with the specified ID. See + /// that method for more details on event handler functions. + pub async fn add_room_event_handler( + &self, + room_id: &RoomId, + handler: H, + ) -> EventHandlerHandle + where + Ev: SyncEvent + DeserializeOwned + Send + 'static, + H: EventHandler, + ::Output: EventHandlerResult, + { + self.add_event_handler_impl(handler, Some(room_id.to_owned())).await + } + + async fn add_event_handler_impl( + &self, + handler: H, + room_id: Option, + ) -> EventHandlerHandle + where + Ev: SyncEvent + DeserializeOwned + Send + 'static, + H: EventHandler, + ::Output: EventHandlerResult, + { let handler_fn: Box = Box::new(move |data| { let maybe_fut = serde_json::from_str(data.raw.get()) .map(|ev| handler.clone().handle_event(ev, data)); @@ -484,14 +514,13 @@ impl Client { }); let handler_id = self.inner.event_handler_counter.fetch_add(1, SeqCst); - - let handle = EventHandlerHandle { handler_id, ev_id: key }; + let handle = EventHandlerHandle { handler_id, ev_kind: Ev::KIND, ev_type: Ev::TYPE }; self.inner .event_handlers .write() .await - .entry(key) + .entry(handle.to_key(room_id)) .or_default() .push(EventHandlerWrapper { handler_fn, handle }); @@ -560,9 +589,10 @@ impl Client { /// # }); /// ``` pub async fn remove_event_handler(&self, handle: EventHandlerHandle) { + let key = handle.to_key(None); let mut event_handlers = self.inner.event_handlers.write().await; - if let btree_map::Entry::Occupied(mut entry) = event_handlers.entry(handle.ev_id) { + if let btree_map::Entry::Occupied(mut entry) = event_handlers.entry(key) { let v = entry.get_mut(); v.retain(|e| e.handle.handler_id != handle.handler_id); diff --git a/crates/matrix-sdk/src/event_handler.rs b/crates/matrix-sdk/src/event_handler.rs index 7fbb4bb6e..a1d7b715f 100644 --- a/crates/matrix-sdk/src/event_handler.rs +++ b/crates/matrix-sdk/src/event_handler.rs @@ -36,7 +36,7 @@ use std::any::TypeId; use std::{borrow::Cow, fmt, future::Future, ops::Deref, pin::Pin}; use matrix_sdk_base::deserialized_responses::{EncryptionInfo, SyncRoomEvent}; -use ruma::{events::AnySyncStateEvent, serde::Raw}; +use ruma::{events::AnySyncStateEvent, serde::Raw, OwnedRoomId}; use serde::Deserialize; use serde_json::value::RawValue as RawJsonValue; use tracing::error; @@ -90,6 +90,13 @@ pub trait SyncEvent { const TYPE: &'static str; } +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) struct EventHandlerKey<'a> { + pub ev_kind: EventKind, + pub ev_type: &'a str, + pub room_id: Option, +} + pub(crate) struct EventHandlerWrapper { pub handler_fn: Box, pub handle: EventHandlerHandle, @@ -99,10 +106,17 @@ pub(crate) struct EventHandlerWrapper { /// [`Client::remove_event_handler`]. #[derive(Clone, Copy, Debug)] pub struct EventHandlerHandle { - pub(crate) ev_id: (EventKind, &'static str), + pub(crate) ev_kind: EventKind, + pub(crate) ev_type: &'static str, pub(crate) handler_id: u64, } +impl EventHandlerHandle { + pub(crate) fn to_key(self, room_id: Option) -> EventHandlerKey<'static> { + EventHandlerKey { ev_kind: self.ev_kind, ev_type: self.ev_type, room_id } + } +} + impl EventHandlerContext for EventHandlerHandle { fn from_data(data: &EventHandlerData<'_>) -> Option { Some(data.handle) @@ -396,26 +410,39 @@ impl Client { for x in list { let (raw_event, encryption_info) = get_event_details(x); let (ev_kind, ev_type) = get_id(raw_event)?; - let event_handler_id = (ev_kind, &*ev_type); + let ev_type: &str = ev_type.as_ref(); // Construct event handler futures - let futures: Vec<_> = self - .event_handlers() - .await - .get(&event_handler_id) - .into_iter() - .flatten() - .map(|handler_wrapper| { - let data = EventHandlerData { - client: self.clone(), - room: room.clone(), - raw: raw_event.json(), - encryption_info, - handle: handler_wrapper.handle, - }; - (handler_wrapper.handler_fn)(data) - }) - .collect(); + let futures: Vec<_> = { + let handlers_lock = self.event_handlers().await; + + let non_room_event_handlers = handlers_lock + .get(&EventHandlerKey { ev_kind, ev_type, room_id: None }) + .into_iter() + .flatten(); + + let room_event_handlers = room.iter().flat_map(|r| { + let room_id = Some(r.room_id().to_owned()); + handlers_lock + .get(&EventHandlerKey { ev_kind, ev_type, room_id }) + .into_iter() + .flatten() + }); + + non_room_event_handlers + .chain(room_event_handlers) + .map(|handler_wrapper| { + let data = EventHandlerData { + client: self.clone(), + room: room.clone(), + raw: raw_event.json(), + encryption_info, + handle: handler_wrapper.handle, + }; + (handler_wrapper.handler_fn)(data) + }) + .collect() + }; // Run the event handler futures with the `self.event_handlers` lock // no longer being held, in order. @@ -601,6 +628,7 @@ mod tests { events::{ room::{ member::{OriginalSyncRoomMemberEvent, StrippedRoomMemberEvent}, + name::OriginalSyncRoomNameEvent, power_levels::OriginalSyncRoomPowerLevelsEvent, }, typing::SyncTypingEvent, @@ -715,6 +743,77 @@ mod tests { Ok(()) } + #[async_test] + async fn add_room_event_handler() -> crate::Result<()> { + use std::sync::atomic::{AtomicU8, Ordering::SeqCst}; + + let client = crate::client::tests::logged_in_client(None).await; + + let room_id_a = room_id!("!foo:example.org"); + let room_id_b = room_id!("!bar:matrix.org"); + + let member_count = Arc::new(AtomicU8::new(0)); + let power_levels_count = Arc::new(AtomicU8::new(0)); + + // Room event handlers for member events in both rooms + client + .add_room_event_handler(room_id_a, { + let member_count = member_count.clone(); + move |_ev: OriginalSyncRoomMemberEvent, _room: room::Room| { + member_count.fetch_add(1, SeqCst); + future::ready(()) + } + }) + .await; + client + .add_room_event_handler(room_id_b, { + let member_count = member_count.clone(); + move |_ev: OriginalSyncRoomMemberEvent, _room: room::Room| { + member_count.fetch_add(1, SeqCst); + future::ready(()) + } + }) + .await; + + // Power levels event handlers for member events in room A + client + .add_room_event_handler(room_id_a, { + let power_levels_count = power_levels_count.clone(); + move |_ev: OriginalSyncRoomPowerLevelsEvent, _client: Client, _room: room::Room| { + power_levels_count.fetch_add(1, SeqCst); + future::ready(()) + } + }) + .await; + + // Room name event handler for room name events in room B + client + .add_room_event_handler(room_id_b, move |_ev: OriginalSyncRoomNameEvent| async { + unreachable!("No room event in room B") + }) + .await; + + let response = EventBuilder::default() + .add_joined_room( + JoinedRoomBuilder::new(room_id_a) + .add_timeline_event(TimelineTestEvent::Member) + .add_state_event(StateTestEvent::PowerLevels) + .add_state_event(StateTestEvent::RoomName), + ) + .add_joined_room( + JoinedRoomBuilder::new(room_id_b) + .add_timeline_event(TimelineTestEvent::Member) + .add_state_event(StateTestEvent::PowerLevels), + ) + .build_sync_response(); + client.process_sync(response).await?; + + assert_eq!(member_count.load(SeqCst), 2); + assert_eq!(power_levels_count.load(SeqCst), 1); + + Ok(()) + } + #[async_test] async fn remove_event_handler() -> crate::Result<()> { use std::sync::atomic::{AtomicU8, Ordering::SeqCst};