feat(stores): allow saving thread subscriptions

This commit is contained in:
Benjamin Bouvier
2025-07-28 16:15:38 +02:00
parent b645c1101f
commit 1a5cb2beb8
8 changed files with 325 additions and 11 deletions

View File

@@ -43,7 +43,10 @@ use super::{
use crate::{
RoomInfo, RoomMemberships, RoomState, StateChanges, StateStoreDataKey, StateStoreDataValue,
deserialized_responses::MemberEvent,
store::{ChildTransactionId, QueueWedgeError, Result, SerializableEventContent, StateStoreExt},
store::{
ChildTransactionId, QueueWedgeError, Result, SerializableEventContent, StateStoreExt,
ThreadStatus,
},
};
/// `StateStore` integration tests.
@@ -98,6 +101,8 @@ pub trait StateStoreIntegrationTests {
async fn test_server_info_saving(&self);
/// Test fetching room infos based on [`RoomLoadSettings`].
async fn test_get_room_infos(&self);
/// Test loading thread subscriptions.
async fn test_thread_subscriptions(&self);
}
impl StateStoreIntegrationTests for DynStateStore {
@@ -1767,6 +1772,53 @@ impl StateStoreIntegrationTests for DynStateStore {
assert_eq!(all_rooms.len(), 0);
}
}
async fn test_thread_subscriptions(&self) {
let first_thread = event_id!("$t1");
let second_thread = event_id!("$t2");
// At first, there is no thread subscription.
let maybe_status = self.load_thread_subscription(room_id(), first_thread).await.unwrap();
assert!(maybe_status.is_none());
let maybe_status = self.load_thread_subscription(room_id(), second_thread).await.unwrap();
assert!(maybe_status.is_none());
// Setting the thread subscription works.
self.upsert_thread_subscription(
room_id(),
first_thread,
ThreadStatus::Subscribed { automatic: true },
)
.await
.unwrap();
self.upsert_thread_subscription(
room_id(),
second_thread,
ThreadStatus::Subscribed { automatic: false },
)
.await
.unwrap();
// Now, reading the thread subscription returns the expected status.
let maybe_status = self.load_thread_subscription(room_id(), first_thread).await.unwrap();
assert_eq!(maybe_status, Some(ThreadStatus::Subscribed { automatic: true }));
let maybe_status = self.load_thread_subscription(room_id(), second_thread).await.unwrap();
assert_eq!(maybe_status, Some(ThreadStatus::Subscribed { automatic: false }));
// We can override the thread subscription status.
self.upsert_thread_subscription(room_id(), first_thread, ThreadStatus::Unsubscribed)
.await
.unwrap();
// And it's correctly reflected.
let maybe_status = self.load_thread_subscription(room_id(), first_thread).await.unwrap();
assert_eq!(maybe_status, Some(ThreadStatus::Unsubscribed));
// And the second thread is still subscribed.
let maybe_status = self.load_thread_subscription(room_id(), second_thread).await.unwrap();
assert_eq!(maybe_status, Some(ThreadStatus::Subscribed { automatic: false }));
}
}
/// Macro building to allow your StateStore implementation to run the entire
@@ -1937,6 +1989,12 @@ macro_rules! statestore_integration_tests {
let store = get_store().await.expect("creating store failed").into_state_store();
store.test_get_room_infos().await;
}
#[async_test]
async fn test_thread_subscriptions() {
let store = get_store().await.expect("creating store failed").into_state_store();
store.test_thread_subscriptions().await;
}
}
};
}

View File

@@ -45,7 +45,7 @@ use super::{
use crate::{
MinimalRoomMemberEvent, RoomMemberships, StateStoreDataKey, StateStoreDataValue,
deserialized_responses::{DisplayName, RawAnySyncOrStrippedState},
store::QueueWedgeError,
store::{QueueWedgeError, ThreadStatus},
};
#[derive(Debug, Default)]
@@ -75,7 +75,6 @@ struct MemoryStoreInner {
OwnedRoomId,
HashMap<(String, Option<String>), HashMap<OwnedUserId, (OwnedEventId, Receipt)>>,
>,
room_event_receipts: HashMap<
OwnedRoomId,
HashMap<(String, Option<String>), HashMap<OwnedEventId, HashMap<OwnedUserId, Receipt>>>,
@@ -84,6 +83,7 @@ struct MemoryStoreInner {
send_queue_events: BTreeMap<OwnedRoomId, Vec<QueuedRequest>>,
dependent_send_queue_events: BTreeMap<OwnedRoomId, Vec<DependentQueuedRequest>>,
seen_knock_requests: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, OwnedUserId>>,
thread_subscriptions: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadStatus>>,
}
/// In-memory, non-persistent implementation of the `StateStore`.
@@ -938,10 +938,6 @@ impl StateStore for MemoryStore {
}
}
/// List all the dependent send queue events.
///
/// This returns absolutely all the dependent send queue events, whether
/// they have an event id or not.
async fn load_dependent_queued_requests(
&self,
room: &RoomId,
@@ -955,6 +951,35 @@ impl StateStore for MemoryStore {
.cloned()
.unwrap_or_default())
}
async fn upsert_thread_subscription(
&self,
room: &RoomId,
thread_id: &EventId,
status: ThreadStatus,
) -> Result<(), Self::Error> {
self.inner
.write()
.unwrap()
.thread_subscriptions
.entry(room.to_owned())
.or_default()
.insert(thread_id.to_owned(), status);
Ok(())
}
async fn load_thread_subscription(
&self,
room: &RoomId,
thread_id: &EventId,
) -> Result<Option<ThreadStatus>, Self::Error> {
let inner = self.inner.read().unwrap();
Ok(inner
.thread_subscriptions
.get(room)
.and_then(|subscriptions| subscriptions.get(thread_id))
.copied())
}
}
#[cfg(test)]

View File

@@ -102,20 +102,25 @@ pub enum StoreError {
/// An error happened in the underlying database backend.
#[error(transparent)]
Backend(Box<dyn std::error::Error + Send + Sync>),
/// An error happened while serializing or deserializing some data.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// An error happened while deserializing a Matrix identifier, e.g. an user
/// id.
#[error(transparent)]
Identifier(#[from] ruma::IdParseError),
/// The store is locked with a passphrase and an incorrect passphrase was
/// given.
#[error("The store failed to be unlocked")]
StoreLocked,
/// An unencrypted store was tried to be unlocked with a passphrase.
#[error("The store is not encrypted but was tried to be opened with a passphrase")]
UnencryptedStore,
/// The store failed to encrypt or decrypt some data.
#[error("Error encrypting or decrypting data from the store: {0}")]
Encryption(#[from] StoreEncryptionError),
@@ -130,11 +135,19 @@ pub enum StoreError {
version: {0}, latest version: {1}"
)]
UnsupportedDatabaseVersion(usize, usize),
/// Redacting an event in the store has failed.
///
/// This should never happen.
#[error("Redaction failed: {0}")]
Redaction(#[source] ruma::canonical_json::RedactionError),
/// The store contains invalid data.
#[error("The store contains invalid data: {details}")]
InvalidData {
/// Details about which data is invalid, and how.
details: String,
},
}
impl StoreError {
@@ -439,6 +452,47 @@ pub enum RoomLoadSettings {
One(OwnedRoomId),
}
/// Status of a thread subscription, as saved in the state store.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ThreadStatus {
/// The thread is subscribed to.
Subscribed {
/// Whether the subscription was made automatically by a client, not by
/// manual user choice.
automatic: bool,
},
/// The thread is unsubscribed to (it won't cause any notifications or
/// automatic subscription anymore).
Unsubscribed,
}
impl ThreadStatus {
/// Convert the current [`ThreadStatus`] into a string representation.
pub fn as_str(&self) -> &'static str {
match self {
ThreadStatus::Subscribed { automatic } => {
if *automatic {
"automatic"
} else {
"manual"
}
}
ThreadStatus::Unsubscribed => "unsubscribed",
}
}
/// Convert a string representation into a [`ThreadStatus`], if it is a
/// valid one, or `None` otherwise.
pub fn from_value(s: &str) -> Option<Self> {
match s {
"automatic" => Some(ThreadStatus::Subscribed { automatic: true }),
"manual" => Some(ThreadStatus::Subscribed { automatic: false }),
"unsubscribed" => Some(ThreadStatus::Unsubscribed),
_ => None,
}
}
}
/// Store state changes and pass them to the StateStore.
#[derive(Clone, Debug, Default)]
pub struct StateChanges {

View File

@@ -55,6 +55,7 @@ use crate::{
deserialized_responses::{
DisplayName, RawAnySyncOrStrippedState, RawMemberEvent, RawSyncOrStrippedState,
},
store::ThreadStatus,
};
/// An abstract state store trait that can be used to implement different stores
@@ -478,6 +479,27 @@ pub trait StateStore: AsyncTraitDeps {
&self,
room: &RoomId,
) -> Result<Vec<DependentQueuedRequest>, Self::Error>;
/// Insert or update a thread subscription for a given room and thread.
///
/// Note: there's no way to remove a thread subscription, because it's
/// either subscribed to, or unsubscribed to, after it's been saved for
/// the first time.
async fn upsert_thread_subscription(
&self,
room: &RoomId,
thread_id: &EventId,
status: ThreadStatus,
) -> Result<(), Self::Error>;
/// Loads the current thread subscription for a given room and thread.
///
/// Returns `None` if there was no entry for the given room/thread pair.
async fn load_thread_subscription(
&self,
room: &RoomId,
thread_id: &EventId,
) -> Result<Option<ThreadStatus>, Self::Error>;
}
#[repr(transparent)]
@@ -772,6 +794,23 @@ impl<T: StateStore> StateStore for EraseStateStoreError<T> {
.await
.map_err(Into::into)
}
async fn upsert_thread_subscription(
&self,
room: &RoomId,
thread_id: &EventId,
status: ThreadStatus,
) -> Result<(), Self::Error> {
self.0.upsert_thread_subscription(room, thread_id, status).await.map_err(Into::into)
}
async fn load_thread_subscription(
&self,
room: &RoomId,
thread_id: &EventId,
) -> Result<Option<ThreadStatus>, Self::Error> {
self.0.load_thread_subscription(room, thread_id).await.map_err(Into::into)
}
}
/// Convenience functionality for state stores.

View File

@@ -46,7 +46,7 @@ use super::{
};
use crate::IndexeddbStateStoreError;
const CURRENT_DB_VERSION: u32 = 12;
const CURRENT_DB_VERSION: u32 = 13;
const CURRENT_META_DB_VERSION: u32 = 2;
/// Sometimes Migrations can't proceed without having to drop existing
@@ -237,6 +237,9 @@ pub async fn upgrade_inner_db(
if old_version < 12 {
db = migrate_to_v12(db).await?;
}
if old_version < 13 {
db = migrate_to_v13(db).await?;
}
}
db.close();
@@ -793,6 +796,16 @@ async fn migrate_to_v12(db: IdbDatabase) -> Result<IdbDatabase> {
Ok(IdbDatabase::open_u32(&name, 12)?.await?)
}
/// Add the thread subscriptions table.
async fn migrate_to_v13(db: IdbDatabase) -> Result<IdbDatabase> {
let migration = OngoingMigration {
drop_stores: [].into(),
create_stores: [keys::THREAD_SUBSCRIPTIONS].into_iter().collect(),
data: Default::default(),
};
apply_migration(db, 13, migration).await
}
#[cfg(all(test, target_family = "wasm"))]
mod tests {
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

View File

@@ -27,7 +27,7 @@ use matrix_sdk_base::{
store::{
ChildTransactionId, ComposerDraft, DependentQueuedRequest, DependentQueuedRequestKind,
QueuedRequest, QueuedRequestKind, RoomLoadSettings, SentRequestKey,
SerializableEventContent, ServerInfo, StateChanges, StateStore, StoreError,
SerializableEventContent, ServerInfo, StateChanges, StateStore, StoreError, ThreadStatus,
},
MinimalRoomMemberEvent, RoomInfo, RoomMemberships, StateStoreDataKey, StateStoreDataValue,
ROOM_VERSION_FALLBACK, ROOM_VERSION_RULES_FALLBACK,
@@ -115,6 +115,7 @@ mod keys {
pub const ROOM_SEND_QUEUE: &str = "room_send_queue";
/// Table used to save dependent send queue events.
pub const DEPENDENT_SEND_QUEUE: &str = "room_dependent_send_queue";
pub const THREAD_SUBSCRIPTIONS: &str = "room_thread_subscriptions";
pub const STRIPPED_ROOM_STATE: &str = "stripped_room_state";
pub const STRIPPED_USER_IDS: &str = "stripped_user_ids";
@@ -140,6 +141,7 @@ mod keys {
ROOM_USER_RECEIPTS,
ROOM_EVENT_RECEIPTS,
ROOM_SEND_QUEUE,
THREAD_SUBSCRIPTIONS,
DEPENDENT_SEND_QUEUE,
CUSTOM,
KV,
@@ -1360,6 +1362,7 @@ impl_state_store!({
keys::ROOM_USER_RECEIPTS,
keys::STRIPPED_ROOM_STATE,
keys::STRIPPED_USER_IDS,
keys::THREAD_SUBSCRIPTIONS,
];
let all_stores = {
@@ -1782,6 +1785,58 @@ impl_state_store!({
|val| self.deserialize_value::<Vec<DependentQueuedRequest>>(&val),
)
}
async fn upsert_thread_subscription(
&self,
room: &RoomId,
thread_id: &EventId,
status: ThreadStatus,
) -> Result<()> {
let encoded_key = self.encode_key(keys::THREAD_SUBSCRIPTIONS, (room, thread_id));
let tx = self.inner.transaction_on_one_with_mode(
keys::THREAD_SUBSCRIPTIONS,
IdbTransactionMode::Readwrite,
)?;
let obj = tx.object_store(keys::THREAD_SUBSCRIPTIONS)?;
let serialized_value = self.serialize_value(&status.as_str().to_owned());
obj.put_key_val(&encoded_key, &serialized_value?)?;
tx.await.into_result()?;
Ok(())
}
async fn load_thread_subscription(
&self,
room: &RoomId,
thread_id: &EventId,
) -> Result<Option<ThreadStatus>> {
let encoded_key = self.encode_key(keys::THREAD_SUBSCRIPTIONS, (room, thread_id));
let js_value = self
.inner
.transaction_on_one_with_mode(keys::THREAD_SUBSCRIPTIONS, IdbTransactionMode::Readonly)?
.object_store(keys::THREAD_SUBSCRIPTIONS)?
.get(&encoded_key)?
.await?;
let Some(js_value) = js_value else {
// We didn't have a previous subscription for this thread.
return Ok(None);
};
let status_string: String = self.deserialize_value(&js_value)?;
let status =
ThreadStatus::from_value(&status_string).ok_or_else(|| StoreError::InvalidData {
details: format!(
"invalid thread status for room {room} and thread {thread_id}: {status_string}"
),
})?;
Ok(Some(status))
}
});
/// A room member.

View File

@@ -0,0 +1,9 @@
-- Thread subscriptions.
CREATE TABLE "thread_subscriptions" (
"room_id" BLOB NOT NULL,
-- Event ID of the thread root.
"event_id" BLOB NOT NULL,
"status" TEXT NOT NULL,
PRIMARY KEY ("room_id", "event_id")
);

View File

@@ -13,7 +13,7 @@ use matrix_sdk_base::{
store::{
migration_helpers::RoomInfoV1, ChildTransactionId, DependentQueuedRequest,
DependentQueuedRequestKind, QueueWedgeError, QueuedRequest, QueuedRequestKind,
RoomLoadSettings, SentRequestKey,
RoomLoadSettings, SentRequestKey, ThreadStatus,
},
MinimalRoomMemberEvent, RoomInfo, RoomMemberships, RoomState, StateChanges, StateStore,
StateStoreDataKey, StateStoreDataValue, ROOM_VERSION_FALLBACK, ROOM_VERSION_RULES_FALLBACK,
@@ -62,6 +62,7 @@ mod keys {
pub const DISPLAY_NAME: &str = "display_name";
pub const SEND_QUEUE: &str = "send_queue_events";
pub const DEPENDENTS_SEND_QUEUE: &str = "dependent_send_queue_events";
pub const THREAD_SUBSCRIPTIONS: &str = "thread_subscriptions";
}
/// The filename used for the SQLITE database file used by the state store.
@@ -72,7 +73,7 @@ pub const DATABASE_NAME: &str = "matrix-sdk-state.sqlite3";
/// This is used to figure whether the SQLite database requires a migration.
/// Every new SQL migration should imply a bump of this number, and changes in
/// the [`SqliteStateStore::run_migrations`] function.
const DATABASE_VERSION: u8 = 12;
const DATABASE_VERSION: u8 = 13;
/// An SQLite-based state store.
#[derive(Clone)]
@@ -356,6 +357,17 @@ impl SqliteStateStore {
conn.set_kv("version", vec![12]).await?;
}
if from < 13 && to >= 13 {
conn.with_transaction(move |txn| {
// Run the migration.
txn.execute_batch(include_str!(
"../migrations/state_store/011_thread_subscriptions.sql"
))?;
txn.set_db_version(13)
})
.await?;
}
Ok(())
}
@@ -2074,6 +2086,55 @@ impl StateStore for SqliteStateStore {
Ok(dependent_events)
}
async fn upsert_thread_subscription(
&self,
room_id: &RoomId,
thread_id: &EventId,
status: ThreadStatus,
) -> Result<(), Self::Error> {
let room_id = self.encode_key(keys::THREAD_SUBSCRIPTIONS, room_id);
let thread_id = self.encode_key(keys::THREAD_SUBSCRIPTIONS, thread_id);
let status = status.as_str();
self.acquire()
.await?
.with_transaction(move |txn| {
txn.prepare_cached(
"INSERT OR REPLACE INTO thread_subscriptions (room_id, event_id, status)
VALUES (?, ?, ?)",
)?
.execute((room_id, thread_id, status))
})
.await?;
Ok(())
}
async fn load_thread_subscription(
&self,
room_id: &RoomId,
thread_id: &EventId,
) -> Result<Option<ThreadStatus>, Self::Error> {
let room_id = self.encode_key(keys::THREAD_SUBSCRIPTIONS, room_id);
let thread_id = self.encode_key(keys::THREAD_SUBSCRIPTIONS, thread_id);
Ok(self
.acquire()
.await?
.query_row(
"SELECT status FROM thread_subscriptions WHERE room_id = ? AND event_id = ?",
(room_id, thread_id),
|row| row.get::<_, String>(0),
)
.await
.optional()?
.map(|data| {
ThreadStatus::from_value(&data).ok_or_else(|| Error::InvalidData {
details: format!("Invalid thread status: {data}"),
})
})
.transpose()?)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]