diff --git a/.github/workflows/bindings_ci.yml b/.github/workflows/bindings_ci.yml index 658142e17..140bb10b7 100644 --- a/.github/workflows/bindings_ci.yml +++ b/.github/workflows/bindings_ci.yml @@ -113,6 +113,8 @@ jobs: - name: Install Node.js uses: actions/setup-node@v3 + with: + node-version: 18.0 - name: Install NPM dependencies working-directory: ${{ env.MATRIX_SDK_CRYPTO_JS_PATH }} diff --git a/.rustfmt.toml b/.rustfmt.toml index a05345e41..7b1f8212c 100644 --- a/.rustfmt.toml +++ b/.rustfmt.toml @@ -1,7 +1,8 @@ -edition = "2018" max_width = 100 comment_width = 80 wrap_comments = true imports_granularity = "Crate" use_small_heuristics = "Max" group_imports = "StdExternalCrate" +format_code_in_doc_comments = true +doc_comment_code_block_width = 80 diff --git a/Cargo.toml b/Cargo.toml index c7f3fb182..bdee9206c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,12 @@ resolver = "2" [profile.release] lto = true +[profile.dev] +# Copied from rust-analyzer. Saves a lot of disk space and hopefully +# compilation time / mem usage too, at the expense of potentially having to +# change this setting here when you want to use a debugger. +debug = 0 + [profile.dev.package] # Optimize quote even in debug mode. Speeds up proc-macros enough to account # for the extra time of optimizing it for a clean build of matrix-sdk-ffi. diff --git a/benchmarks/benches/crypto_bench.rs b/benchmarks/benches/crypto_bench.rs index 8b0c929d2..66fd7081d 100644 --- a/benchmarks/benches/crypto_bench.rs +++ b/benchmarks/benches/crypto_bench.rs @@ -63,7 +63,7 @@ pub fn keys_query(c: &mut Criterion) { let mut group = c.benchmark_group("Keys querying"); group.throughput(Throughput::Elements(count as u64)); - let name = format!("{} device and cross signing keys", count); + let name = format!("{count} device and cross signing keys"); group.bench_with_input(BenchmarkId::new("memory store", &name), &response, |b, response| { b.to_async(&runtime) @@ -96,7 +96,7 @@ pub fn keys_claiming(c: &mut Criterion) { let mut group = c.benchmark_group("Olm session creation"); group.throughput(Throughput::Elements(count as u64)); - let name = format!("{} one-time keys", count); + let name = format!("{count} one-time keys"); group.bench_with_input(BenchmarkId::new("memory store", &name), &response, |b, response| { b.iter_batched( @@ -158,7 +158,7 @@ pub fn room_key_sharing(c: &mut Criterion) { let mut group = c.benchmark_group("Room key sharing"); group.throughput(Throughput::Elements(count as u64)); - let name = format!("{} devices", count); + let name = format!("{count} devices"); group.bench_function(BenchmarkId::new("memory store", &name), |b| { b.to_async(&runtime).iter(|| async { @@ -225,7 +225,7 @@ pub fn devices_missing_sessions_collecting(c: &mut Criterion) { let mut group = c.benchmark_group("Devices missing sessions collecting"); group.throughput(Throughput::Elements(count as u64)); - let name = format!("{} devices", count); + let name = format!("{count} devices"); runtime.block_on(machine.mark_request_as_sent(&txn_id, &response)).unwrap(); diff --git a/bindings/apple/build_xcframework.sh b/bindings/apple/build_xcframework.sh index d98438747..f1350d12d 100755 --- a/bindings/apple/build_xcframework.sh +++ b/bindings/apple/build_xcframework.sh @@ -17,16 +17,22 @@ REL_TYPE_DIR="release" # Build static libs for all the different architectures # iOS +echo -e "Building for iOS [1/5]" cargo build -p matrix-sdk-ffi ${REL_FLAG} --target "aarch64-apple-ios" # MacOS +echo -e "\nBuilding for macOS (Apple Silicon) [2/5]" cargo build -p matrix-sdk-ffi ${REL_FLAG} --target "aarch64-apple-darwin" +echo -e "\nBuilding for macOS (Intel) [3/5]" cargo build -p matrix-sdk-ffi ${REL_FLAG} --target "x86_64-apple-darwin" # iOS Simulator +echo -e "\nBuilding for iOS Simulator (Apple Silicon) [4/5]" cargo build -p matrix-sdk-ffi ${REL_FLAG} --target "aarch64-apple-ios-sim" +echo -e "\nBuilding for iOS Simulator (Intel) [5/5]" cargo build -p matrix-sdk-ffi ${REL_FLAG} --target "x86_64-apple-ios" +echo -e "\nCreating XCFramework" # Lipo together the libraries for the same platform # MacOS diff --git a/bindings/matrix-sdk-crypto-ffi/Cargo.toml b/bindings/matrix-sdk-crypto-ffi/Cargo.toml index 3b200e82a..a40c755df 100644 --- a/bindings/matrix-sdk-crypto-ffi/Cargo.toml +++ b/bindings/matrix-sdk-crypto-ffi/Cargo.toml @@ -2,7 +2,7 @@ name = "matrix-sdk-crypto-ffi" version = "0.1.0" authors = ["Damir Jelić "] -edition = "2018" +edition = "2021" rust-version = "1.60" description = "Uniffi based bindings for the Rust SDK crypto crate" repository = "https://github.com/matrix-org/matrix-rust-sdk" @@ -57,7 +57,7 @@ features = ["rt-multi-thread"] [dependencies.vodozemac] git = "https://github.com/matrix-org/vodozemac/" -rev = "2404f83f7d3a3779c1f518e4d949f7da9677c3dd" +rev = "18bcbc3359298894415931547ea41abb75af2d4a" [build-dependencies] uniffi_build = { version = "0.18.0", features = ["builtin-bindgen"] } diff --git a/bindings/matrix-sdk-crypto-ffi/src/lib.rs b/bindings/matrix-sdk-crypto-ffi/src/lib.rs index ce1f22d1e..850a9cc42 100644 --- a/bindings/matrix-sdk-crypto-ffi/src/lib.rs +++ b/bindings/matrix-sdk-crypto-ffi/src/lib.rs @@ -14,7 +14,7 @@ mod responses; mod users; mod verification; -use std::{borrow::Borrow, collections::HashMap, convert::TryFrom, str::FromStr, sync::Arc}; +use std::{borrow::Borrow, collections::HashMap, str::FromStr, sync::Arc}; pub use backup_recovery_key::{ BackupRecoveryKey, DecodeError, MegolmV1BackupKey, PassphraseInfo, PkDecryptionError, @@ -262,6 +262,7 @@ pub fn migrate( imported: session.imported, backed_up: session.backed_up, history_visibility: None, + algorithm: ruma::EventEncryptionAlgorithm::MegolmV1AesSha2, }; let session = matrix_sdk_crypto::olm::InboundGroupSession::from_pickle(pickle)?; diff --git a/bindings/matrix-sdk-crypto-ffi/src/machine.rs b/bindings/matrix-sdk-crypto-ffi/src/machine.rs index 79c595450..47210bc8d 100644 --- a/bindings/matrix-sdk-crypto-ffi/src/machine.rs +++ b/bindings/matrix-sdk-crypto-ffi/src/machine.rs @@ -1,6 +1,5 @@ use std::{ collections::{BTreeMap, HashMap}, - convert::TryInto, io::Cursor, ops::Deref, sync::Arc, @@ -32,10 +31,8 @@ use ruma::{ }, IncomingResponse, }, - events::{ - key::verification::VerificationMethod, room::encrypted::OriginalSyncRoomEncryptedEvent, - AnySyncMessageLikeEvent, - }, + events::{key::verification::VerificationMethod, AnySyncMessageLikeEvent}, + serde::Raw, DeviceKeyAlgorithm, EventId, OwnedTransactionId, OwnedUserId, RoomId, UserId, }; use serde::{Deserialize, Serialize}; @@ -602,7 +599,7 @@ impl OlmMachine { content: &'a RawValue, } - let event: OriginalSyncRoomEncryptedEvent = serde_json::from_str(event)?; + let event: Raw<_> = serde_json::from_str(event)?; let room_id = RoomId::parse(room_id)?; let decrypted = self.runtime.block_on(self.inner.decrypt_room_event(&event, &room_id))?; @@ -640,7 +637,7 @@ impl OlmMachine { event: &str, room_id: &str, ) -> Result { - let event: OriginalSyncRoomEncryptedEvent = serde_json::from_str(event)?; + let event: Raw<_> = serde_json::from_str(event)?; let room_id = RoomId::parse(room_id)?; let (cancel, request) = diff --git a/bindings/matrix-sdk-crypto-js/Cargo.toml b/bindings/matrix-sdk-crypto-js/Cargo.toml index 2b91685a9..006f01937 100644 --- a/bindings/matrix-sdk-crypto-js/Cargo.toml +++ b/bindings/matrix-sdk-crypto-js/Cargo.toml @@ -1,15 +1,15 @@ [package] -authors = ["Ivan Enderlin "] +name = "matrix-sdk-crypto-js" description = "Matrix encryption library, for JavaScript" +authors = ["Ivan Enderlin "] edition = "2021" homepage = "https://github.com/matrix-org/matrix-rust-sdk" keywords = ["matrix", "chat", "messaging", "ruma", "nio"] license = "Apache-2.0" -name = "matrix-sdk-crypto-js" readme = "README.md" repository = "https://github.com/matrix-org/matrix-rust-sdk" rust-version = "1.60" -version = "0.5.0" +version = "0.1.0-alpha.0" [package.metadata.docs.rs] features = ["docsrs"] @@ -31,7 +31,6 @@ tracing = [] matrix-sdk-common = { version = "0.5.0", path = "../../crates/matrix-sdk-common" } matrix-sdk-crypto = { version = "0.5.0", path = "../../crates/matrix-sdk-crypto" } ruma = { git = "https://github.com/ruma/ruma", rev = "ca8c66c885241a7ba3805399604eda4a38979f6b", features = ["client-api-c", "js", "rand", "unstable-msc2676", "unstable-msc2677"] } -vodozemac = { git = "https://github.com/matrix-org/vodozemac/", rev = "2404f83f7d3a3779c1f518e4d949f7da9677c3dd", features = ["js"] } wasm-bindgen = "0.2.80" wasm-bindgen-futures = "0.4.30" js-sys = "0.3.49" @@ -41,3 +40,8 @@ http = "0.2.6" anyhow = "1.0.58" tracing = { version = "0.1.35", default-features = false, features = ["attributes"] } tracing-subscriber = { version = "0.3.14", default-features = false, features = ["registry", "std"] } + +[dependencies.vodozemac] +git = "https://github.com/matrix-org/vodozemac/" +rev = "18bcbc3359298894415931547ea41abb75af2d4a" +features = ["js"] diff --git a/bindings/matrix-sdk-crypto-js/cliff.toml b/bindings/matrix-sdk-crypto-js/cliff.toml new file mode 100644 index 000000000..26f33b838 --- /dev/null +++ b/bindings/matrix-sdk-crypto-js/cliff.toml @@ -0,0 +1,61 @@ +# configuration file for git-cliff (0.1.0) + +[changelog] +# changelog header +header = """ +# Matrix SDK Crypto JavaScript Changelog\n +All notable changes to this project will be documented in this file.\n +""" +# template for the changelog body +# https://tera.netlify.app/docs/#introduction +body = """ +{% if version %}\ + ## [{{ version | trim_start_matches(pat="v") }}] - {{ timestamp | date(format="%Y-%m-%d") }} +{% else %}\ + ## [unreleased] +{% endif %}\ +{% for group, commits in commits | filter(attribute="scope", value="crypto-js") | group_by(attribute="group") %} + ### {{ group | upper_first }} + {% for commit in commits %} + - {% if commit.breaking %}[**breaking**] {% endif %}{{ commit.message | upper_first }}\ + {% endfor %} +{% endfor %}\n +""" +# remove the leading and trailing whitespace from the template +trim = true +# changelog footer +footer = """ +""" + +[git] +# parse the commits based on https://www.conventionalcommits.org +conventional_commits = true +# filter out the commits that are not conventional +filter_unconventional = true +# regex for preprocessing the commit messages +commit_preprocessors = [ + { pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](https://github.com/matrix-org/matrix-rust-sdk/issues/${2}))"}, +] +# regex for parsing and grouping commits +commit_parsers = [ + { message = "^feat", group = "Features"}, + { message = "^fix", group = "Bug Fixes"}, + { message = "^test", group = "Testing"}, + { message = "^doc", group = "Documentation"}, + { message = "^refactor", group = "Refactoring"}, + { message = "^ci", group = "Continuous Integration"}, + { message = "^chore", group = "Miscellaneous Tasks"}, + { body = ".*security", group = "Security"}, +] +# filter out the commits that are not matched by commit parsers +filter_commits = false +# glob pattern for matching git tags +tag_pattern = "v[0-9]*" +# regex for skipping tags +skip_tags = "" +# regex for ignoring tags +ignore_tags = "" +# sort the tags chronologically +date_order = false +# sort the commits inside sections by oldest/newest order +sort_commits = "oldest" diff --git a/bindings/matrix-sdk-crypto-js/package.json b/bindings/matrix-sdk-crypto-js/package.json index 45e7d85d2..9fdc08f7d 100644 --- a/bindings/matrix-sdk-crypto-js/package.json +++ b/bindings/matrix-sdk-crypto-js/package.json @@ -1,6 +1,6 @@ { "name": "@matrix-org/matrix-sdk-crypto-js", - "version": "0.5.0", + "version": "0.1.0-alpha.0", "homepage": "https://github.com/matrix-org/matrix-rust-sdk", "description": "Matrix encryption library, for JavaScript", "license": "Apache-2.0", @@ -35,7 +35,7 @@ "node": ">= 10" }, "scripts": { - "build": "cross-env RUSTFLAGS='-C opt-level=z' wasm-pack build --release --target nodejs --out-name matrix_sdk_crypto --out-dir ./pkg", + "build": "cross-env RUSTFLAGS='-C opt-level=z' wasm-pack build --release --target nodejs --scope matrix-org --out-dir ./pkg", "test": "jest --verbose", "doc": "typedoc --tsconfig ." } diff --git a/bindings/matrix-sdk-crypto-js/src/lib.rs b/bindings/matrix-sdk-crypto-js/src/lib.rs index b927fedca..4b22a69d9 100644 --- a/bindings/matrix-sdk-crypto-js/src/lib.rs +++ b/bindings/matrix-sdk-crypto-js/src/lib.rs @@ -62,8 +62,7 @@ where Ok(unsafe { T::ref_from_abi(pointer) }) } else { Err(JsError::new(&format!( - "Expect an `{}` instance, received `{}` instead", - classname, constructor_name, + "Expect an `{classname}` instance, received `{constructor_name}` instead", ))) } } diff --git a/bindings/matrix-sdk-crypto-js/src/machine.rs b/bindings/matrix-sdk-crypto-js/src/machine.rs index e6f45175d..d5be9449b 100644 --- a/bindings/matrix-sdk-crypto-js/src/machine.rs +++ b/bindings/matrix-sdk-crypto-js/src/machine.rs @@ -3,10 +3,7 @@ use std::collections::BTreeMap; use js_sys::{Array, Map, Promise, Set}; -use ruma::{ - events::room::encrypted::OriginalSyncRoomEncryptedEvent, DeviceKeyAlgorithm, - OwnedTransactionId, UInt, -}; +use ruma::{serde::Raw, DeviceKeyAlgorithm, OwnedTransactionId, UInt}; use serde_json::Value as JsonValue; use wasm_bindgen::prelude::*; @@ -274,7 +271,7 @@ impl OlmMachine { event: &str, room_id: &identifiers::RoomId, ) -> Result { - let event: OriginalSyncRoomEncryptedEvent = serde_json::from_str(event)?; + let event: Raw<_> = serde_json::from_str(event)?; let room_id = room_id.inner.clone(); let me = self.inner.clone(); diff --git a/bindings/matrix-sdk-crypto-js/src/tracing.rs b/bindings/matrix-sdk-crypto-js/src/tracing.rs index 75a7801aa..3d5ca7f9f 100644 --- a/bindings/matrix-sdk-crypto-js/src/tracing.rs +++ b/bindings/matrix-sdk-crypto-js/src/tracing.rs @@ -207,7 +207,7 @@ mod inner { let origin = metadata .file() - .and_then(|file| metadata.line().map(|ln| format!("{}:{}", file, ln))) + .and_then(|file| metadata.line().map(|ln| format!("{file}:{ln}"))) .unwrap_or_default(); let message = format!("{level} {origin}{recorder}"); @@ -240,11 +240,11 @@ mod inner { self.string.push('\n'); } - let _ = write!(self.string, "{:?}", value); + let _ = write!(self.string, "{value:?}"); } field_name => { - let _ = write!(self.string, "\n{} = {:?}", field_name, value); + let _ = write!(self.string, "\n{field_name} = {value:?}"); } } } diff --git a/bindings/matrix-sdk-crypto-js/tests/encryption.test.js b/bindings/matrix-sdk-crypto-js/tests/encryption.test.js index 75374822b..4307deaf3 100644 --- a/bindings/matrix-sdk-crypto-js/tests/encryption.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/encryption.test.js @@ -1,4 +1,4 @@ -const { EncryptionAlgorithm, EncryptionSettings, HistoryVisibility, VerificationState } = require('../pkg/matrix_sdk_crypto'); +const { EncryptionAlgorithm, EncryptionSettings, HistoryVisibility, VerificationState } = require('../pkg/matrix_sdk_crypto_js'); describe('EncryptionAlgorithm', () => { test('has the correct variant values', () => { diff --git a/bindings/matrix-sdk-crypto-js/tests/events.test.js b/bindings/matrix-sdk-crypto-js/tests/events.test.js index b478e5158..75ed2b610 100644 --- a/bindings/matrix-sdk-crypto-js/tests/events.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/events.test.js @@ -1,4 +1,4 @@ -const { HistoryVisibility } = require('../pkg/matrix_sdk_crypto'); +const { HistoryVisibility } = require('../pkg/matrix_sdk_crypto_js'); describe('HistoryVisibility', () => { test('has the correct variant values', () => { diff --git a/bindings/matrix-sdk-crypto-js/tests/identifiers.test.js b/bindings/matrix-sdk-crypto-js/tests/identifiers.test.js index 3c4668f5d..3b72da2ff 100644 --- a/bindings/matrix-sdk-crypto-js/tests/identifiers.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/identifiers.test.js @@ -1,4 +1,4 @@ -const { UserId, DeviceId, RoomId, ServerName } = require('../pkg/matrix_sdk_crypto'); +const { UserId, DeviceId, RoomId, ServerName } = require('../pkg/matrix_sdk_crypto_js'); describe(UserId.name, () => { test('cannot be invalid', () => { diff --git a/bindings/matrix-sdk-crypto-js/tests/machine.test.js b/bindings/matrix-sdk-crypto-js/tests/machine.test.js index dd759f1b8..d4e35641a 100644 --- a/bindings/matrix-sdk-crypto-js/tests/machine.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/machine.test.js @@ -1,4 +1,4 @@ -const { OlmMachine, UserId, DeviceId, RoomId, DeviceLists, RequestType, KeysUploadRequest, KeysQueryRequest, KeysClaimRequest, EncryptionSettings, DecryptedRoomEvent, VerificationState } = require('../pkg/matrix_sdk_crypto'); +const { OlmMachine, UserId, DeviceId, RoomId, DeviceLists, RequestType, KeysUploadRequest, KeysQueryRequest, KeysClaimRequest, EncryptionSettings, DecryptedRoomEvent, VerificationState } = require('../pkg/matrix_sdk_crypto_js'); describe(OlmMachine.name, () => { test('can be instantiated with the async initializer', async () => { diff --git a/bindings/matrix-sdk-crypto-js/tests/requests.test.js b/bindings/matrix-sdk-crypto-js/tests/requests.test.js index b23a2d31e..eb595ccc8 100644 --- a/bindings/matrix-sdk-crypto-js/tests/requests.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/requests.test.js @@ -1,4 +1,4 @@ -const { RequestType, KeysUploadRequest, KeysQueryRequest, KeysClaimRequest, ToDeviceRequest, SignatureUploadRequest, RoomMessageRequest, KeysBackupRequest } = require('../pkg/matrix_sdk_crypto'); +const { RequestType, KeysUploadRequest, KeysQueryRequest, KeysClaimRequest, ToDeviceRequest, SignatureUploadRequest, RoomMessageRequest, KeysBackupRequest } = require('../pkg/matrix_sdk_crypto_js'); describe('RequestType', () => { test('has the correct variant values', () => { diff --git a/bindings/matrix-sdk-crypto-js/tests/sync_events.test.js b/bindings/matrix-sdk-crypto-js/tests/sync_events.test.js index 0322d1317..305d22556 100644 --- a/bindings/matrix-sdk-crypto-js/tests/sync_events.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/sync_events.test.js @@ -1,4 +1,4 @@ -const { DeviceLists, UserId } = require('../pkg/matrix_sdk_crypto'); +const { DeviceLists, UserId } = require('../pkg/matrix_sdk_crypto_js'); describe(DeviceLists.name, () => { test('can be empty', () => { diff --git a/bindings/matrix-sdk-crypto-js/tests/tracing.test.js b/bindings/matrix-sdk-crypto-js/tests/tracing.test.js index bf27431ee..ecb7276f9 100644 --- a/bindings/matrix-sdk-crypto-js/tests/tracing.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/tracing.test.js @@ -1,4 +1,4 @@ -const { Tracing, LoggerLevel, OlmMachine, UserId, DeviceId } = require('../pkg/matrix_sdk_crypto'); +const { Tracing, LoggerLevel, OlmMachine, UserId, DeviceId } = require('../pkg/matrix_sdk_crypto_js'); describe('LoggerLevel', () => { test('has the correct variant values', () => { diff --git a/bindings/matrix-sdk-crypto-js/tsconfig.json b/bindings/matrix-sdk-crypto-js/tsconfig.json index 0f9ea102a..cca9bfa19 100644 --- a/bindings/matrix-sdk-crypto-js/tsconfig.json +++ b/bindings/matrix-sdk-crypto-js/tsconfig.json @@ -3,7 +3,7 @@ "strict": true }, "typedocOptions": { - "entryPoints": ["pkg/matrix_sdk_crypto.d.ts"], + "entryPoints": ["pkg/matrix_sdk_crypto_js.d.ts"], "out": "docs", "readme": "README.md", } diff --git a/bindings/matrix-sdk-crypto-nodejs/Cargo.toml b/bindings/matrix-sdk-crypto-nodejs/Cargo.toml index 82a51626a..389419f5a 100644 --- a/bindings/matrix-sdk-crypto-nodejs/Cargo.toml +++ b/bindings/matrix-sdk-crypto-nodejs/Cargo.toml @@ -29,7 +29,6 @@ matrix-sdk-crypto = { version = "0.5.0", path = "../../crates/matrix-sdk-crypto" matrix-sdk-common = { version = "0.5.0", path = "../../crates/matrix-sdk-common" } matrix-sdk-sled = { version = "0.1.0", path = "../../crates/matrix-sdk-sled", default-features = false, features = ["crypto-store"] } ruma = { git = "https://github.com/ruma/ruma", rev = "ca8c66c885241a7ba3805399604eda4a38979f6b", features = ["client-api-c", "rand", "unstable-msc2676", "unstable-msc2677"] } -vodozemac = { git = "https://github.com/matrix-org/vodozemac/", rev = "2404f83f7d3a3779c1f518e4d949f7da9677c3dd" } napi = { git = "https://github.com/Hywan/napi-rs", branch = "fix-napi-strict-on-t-and-ref-t", default-features = false, features = ["napi6", "tokio_rt"] } napi-derive = { git = "https://github.com/Hywan/napi-rs", branch = "fix-napi-strict-on-t-and-ref-t" } serde_json = "1.0.79" @@ -37,5 +36,10 @@ http = "0.2.6" zeroize = "1.3.0" tracing-subscriber = { version = "0.3", default-features = false, features = ["tracing-log", "time", "smallvec", "fmt", "env-filter"], optional = true } +[dependencies.vodozemac] +git = "https://github.com/matrix-org/vodozemac/" +rev = "18bcbc3359298894415931547ea41abb75af2d4a" +features = ["js"] + [build-dependencies] napi-build = "2.0.0" diff --git a/bindings/matrix-sdk-crypto-nodejs/cliff.toml b/bindings/matrix-sdk-crypto-nodejs/cliff.toml index e9fb2cf3b..f2c005219 100644 --- a/bindings/matrix-sdk-crypto-nodejs/cliff.toml +++ b/bindings/matrix-sdk-crypto-nodejs/cliff.toml @@ -40,9 +40,11 @@ commit_preprocessors = [ commit_parsers = [ { message = "^feat", group = "Features"}, { message = "^fix", group = "Bug Fixes"}, - { message = "^doc", group = "Documentation"}, - { message = "^perf", group = "Performance"}, { message = "^test", group = "Testing"}, + { message = "^doc", group = "Documentation"}, + { message = "^refactor", group = "Refactoring"}, + { message = "^ci", group = "Continuous Integration"}, + { message = "^chore", group = "Miscellaneous Tasks"}, { body = ".*security", group = "Security"}, ] # filter out the commits that are not matched by commit parsers diff --git a/bindings/matrix-sdk-crypto-nodejs/src/machine.rs b/bindings/matrix-sdk-crypto-nodejs/src/machine.rs index 5348ad54a..d6191d219 100644 --- a/bindings/matrix-sdk-crypto-nodejs/src/machine.rs +++ b/bindings/matrix-sdk-crypto-nodejs/src/machine.rs @@ -7,11 +7,8 @@ use std::{ use napi::bindgen_prelude::Either7; use napi_derive::*; -use ruma::{ - events::room::encrypted::OriginalSyncRoomEncryptedEvent, DeviceKeyAlgorithm, - OwnedTransactionId, UInt, -}; -use serde_json::Value as JsonValue; +use ruma::{serde::Raw, DeviceKeyAlgorithm, OwnedTransactionId, UInt}; +use serde_json::{value::RawValue, Value as JsonValue}; use zeroize::Zeroize; use crate::{ @@ -389,8 +386,7 @@ impl OlmMachine { event: String, room_id: &identifiers::RoomId, ) -> napi::Result { - let event: OriginalSyncRoomEncryptedEvent = - serde_json::from_str(event.as_str()).map_err(into_err)?; + let event = Raw::from_json(RawValue::from_string(event).map_err(into_err)?); let room_id = room_id.inner.clone(); let room_event = self.inner.decrypt_room_event(&event, &room_id).await.map_err(into_err)?; diff --git a/bindings/matrix-sdk-ffi/src/api.udl b/bindings/matrix-sdk-ffi/src/api.udl index 76d2ada9f..d5168fdad 100644 --- a/bindings/matrix-sdk-ffi/src/api.udl +++ b/bindings/matrix-sdk-ffi/src/api.udl @@ -42,7 +42,7 @@ interface Client { string homeserver(); - void start_sync(); + void start_sync(u16? timeline_limit); [Throws=ClientError] string restore_token(); @@ -78,6 +78,12 @@ callback interface RoomDelegate { void did_receive_message(AnyMessage message); }; +enum Membership { + "Invited", + "Joined", + "Left", +}; + interface Room { void set_delegate(RoomDelegate? delegate); @@ -85,7 +91,9 @@ interface Room { string? name(); string? topic(); string? avatar_url(); - + + Membership membership(); + boolean is_direct(); boolean is_public(); boolean is_space(); @@ -159,6 +167,7 @@ interface MediaSource { [Error] enum AuthenticationError { "ClientMissing", + "SessionMissing", "Generic", }; @@ -178,6 +187,9 @@ interface AuthenticationService { [Throws=AuthenticationError] Client login(string username, string password); + + [Throws=AuthenticationError] + Client restore_with_access_token(string token, string device_id); }; interface SessionVerificationEmoji { diff --git a/bindings/matrix-sdk-ffi/src/authentication_service.rs b/bindings/matrix-sdk-ffi/src/authentication_service.rs index 0c3df8f12..426fa6476 100644 --- a/bindings/matrix-sdk-ffi/src/authentication_service.rs +++ b/bindings/matrix-sdk-ffi/src/authentication_service.rs @@ -1,6 +1,10 @@ use std::sync::Arc; use futures_util::future::join3; +use matrix_sdk::{ + ruma::{OwnedDeviceId, UserId}, + Session, +}; use parking_lot::RwLock; use super::{client::Client, client_builder::ClientBuilder, RUNTIME}; @@ -15,6 +19,8 @@ pub struct AuthenticationService { pub enum AuthenticationError { #[error("A successful call to use_server must be made first.")] ClientMissing, + #[error("Login was successful but is missing a valid Session to configure the file store.")] + SessionMissing, #[error("An error occurred: {message}")] Generic { message: String }, } @@ -66,14 +72,12 @@ impl AuthenticationService { /// Updates the service to authenticate with the homeserver for the /// specified address. pub fn configure_homeserver(&self, server_name: String) -> Result<(), AuthenticationError> { - // Construct a username as the builder currently requires one. - let username = format!("@auth:{}", server_name); - - let mut builder = - Arc::new(ClientBuilder::new()).base_path(self.base_path.clone()).username(username); + let mut builder = Arc::new(ClientBuilder::new()).base_path(self.base_path.clone()); if server_name.starts_with("http://") || server_name.starts_with("https://") { builder = builder.homeserver_url(server_name) + } else { + builder = builder.server_name(server_name); } let client = builder.build().map_err(AuthenticationError::from)?; @@ -96,18 +100,74 @@ impl AuthenticationService { ) -> Result, AuthenticationError> { match self.client.read().as_ref() { Some(client) => { - let homeserver_url = client.homeserver(); + // Login and ask the server for the full user ID as this could be different from + // the username that was entered. + client.login(username, password).map_err(AuthenticationError::from)?; + let whoami = client.whoami()?; - // Create a new client to setup the store path for the username + // Create a new client to setup the store path now the user ID is known. + let homeserver_url = client.homeserver(); + let session = client.session().ok_or(AuthenticationError::SessionMissing)?; let client = Arc::new(ClientBuilder::new()) .base_path(self.base_path.clone()) .homeserver_url(homeserver_url) - .username(username.clone()) + .username(whoami.user_id.to_string()) .build() .map_err(AuthenticationError::from)?; + // Restore the client using the session from the login request. client - .login(username, password) + .restore_session(session.clone()) + .map(|_| client.clone()) + .map_err(AuthenticationError::from) + } + None => Err(AuthenticationError::ClientMissing), + } + } + + /// Restore an existing session on the current homeserver using an access + /// token issued by an authentication server. + /// # Arguments + /// + /// * `token` - The access token issued by the authentication server. + /// + /// * `device_id` - The device ID that the access token was scoped for. + pub fn restore_with_access_token( + &self, + token: String, + device_id: String, + ) -> Result, AuthenticationError> { + match self.client.read().as_ref() { + Some(client) => { + // Restore the client and ask the server for the full user ID as this + // could be different from the username that was entered. + let discovery_user_id = UserId::parse("@unknown:unknown") + .map_err(|e| AuthenticationError::Generic { message: e.to_string() })?; + let device_id: OwnedDeviceId = device_id.as_str().into(); + + let discovery_session = Session { + access_token: token.clone(), + user_id: discovery_user_id, + device_id: device_id.clone(), + }; + + client.restore_session(discovery_session).map_err(AuthenticationError::from)?; + let whoami = client.whoami()?; + + // Create the actual client with a store path from the user ID. + let homeserver_url = client.homeserver(); + let session = + Session { access_token: token, user_id: whoami.user_id.clone(), device_id }; + let client = Arc::new(ClientBuilder::new()) + .base_path(self.base_path.clone()) + .homeserver_url(homeserver_url) + .username(whoami.user_id.to_string()) + .build() + .map_err(AuthenticationError::from)?; + + // Restore the client using the session. + client + .restore_session(session) .map(|_| client.clone()) .map_err(AuthenticationError::from) } diff --git a/bindings/matrix-sdk-ffi/src/client.rs b/bindings/matrix-sdk-ffi/src/client.rs index 35538cb84..25bccc5e1 100644 --- a/bindings/matrix-sdk-ffi/src/client.rs +++ b/bindings/matrix-sdk-ffi/src/client.rs @@ -1,10 +1,12 @@ use std::sync::Arc; +use anyhow::anyhow; use matrix_sdk::{ config::SyncSettings, media::{MediaFormat, MediaRequest}, ruma::{ api::client::{ + account::whoami, filter::{FilterDefinition, LazyLoadOptions, RoomEventFilter, RoomFilter}, session::get_login_types, sync::sync_events::v3::Filter, @@ -12,7 +14,7 @@ use matrix_sdk::{ events::room::MediaSource, TransactionId, }, - Client as MatrixClient, LoopCtrl, + Client as MatrixClient, LoopCtrl, Session, }; use parking_lot::RwLock; @@ -51,6 +53,7 @@ impl Client { } } + /// Login using a username and password. pub fn login(&self, username: String, password: String) -> anyhow::Result<()> { RUNTIME.block_on(async move { self.client.login_username(&username, &password).send().await?; @@ -58,10 +61,16 @@ impl Client { }) } + /// Restores the client from a `RestoreToken`. pub fn restore_login(&self, restore_token: String) -> anyhow::Result<()> { let RestoreToken { session, homeurl: _, is_guest: _ } = serde_json::from_str(&restore_token)?; + self.restore_session(session) + } + + /// Restores the client from a `Session`. + pub fn restore_session(&self, session: Session) -> anyhow::Result<()> { RUNTIME.block_on(async move { self.client.restore_login(session).await?; Ok(()) @@ -97,7 +106,13 @@ impl Client { Ok(supports_password) } - pub fn start_sync(&self) { + /// Gets information about the owner of a given access token. + pub fn whoami(&self) -> anyhow::Result { + RUNTIME + .block_on(async move { self.client.whoami().await.map_err(|e| anyhow!(e.to_string())) }) + } + + pub fn start_sync(&self, timeline_limit: Option) { let client = self.client.clone(); let state = self.state.clone(); let delegate = self.delegate.clone(); @@ -106,12 +121,16 @@ impl Client { let mut filter = FilterDefinition::default(); let mut room_filter = RoomFilter::default(); let mut event_filter = RoomEventFilter::default(); + let mut timeline_filter = RoomEventFilter::default(); event_filter.lazy_load_options = LazyLoadOptions::Enabled { include_redundant_members: false }; room_filter.state = event_filter; filter.room = room_filter; + timeline_filter.limit = timeline_limit.map(|limit| limit.into()); + filter.room.timeline = timeline_filter; + let filter_id = client.get_or_upload_filter("sync", filter).await.unwrap(); let sync_settings = SyncSettings::new().filter(Filter::FilterId(&filter_id)); diff --git a/bindings/matrix-sdk-ffi/src/client_builder.rs b/bindings/matrix-sdk-ffi/src/client_builder.rs index 786fbdf18..0b677d5d1 100644 --- a/bindings/matrix-sdk-ffi/src/client_builder.rs +++ b/bindings/matrix-sdk-ffi/src/client_builder.rs @@ -1,9 +1,10 @@ use std::{fs, path::PathBuf, sync::Arc}; -use anyhow::Context; +use anyhow::anyhow; use matrix_sdk::{ - ruma::UserId, store::make_store_config, Client as MatrixClient, - ClientBuilder as MatrixClientBuilder, + ruma::{ServerName, UserId}, + store::make_store_config, + Client as MatrixClient, ClientBuilder as MatrixClientBuilder, }; use sanitize_filename_reader_friendly::sanitize; @@ -13,6 +14,7 @@ use super::{client::Client, ClientState, RUNTIME}; pub struct ClientBuilder { base_path: Option, username: Option, + server_name: Option, homeserver_url: Option, inner: MatrixClientBuilder, } @@ -22,6 +24,7 @@ impl ClientBuilder { Self { base_path: None, username: None, + server_name: None, homeserver_url: None, inner: MatrixClient::builder().user_agent("rust-sdk-ios"), } @@ -39,6 +42,12 @@ impl ClientBuilder { Arc::new(builder) } + pub fn server_name(self: Arc, server_name: String) -> Arc { + let mut builder = unwrap_or_clone_arc(self); + builder.server_name = Some(server_name); + Arc::new(builder) + } + pub fn homeserver_url(self: Arc, url: String) -> Arc { let mut builder = unwrap_or_clone_arc(self); builder.homeserver_url = Some(url); @@ -47,25 +56,30 @@ impl ClientBuilder { pub fn build(self: Arc) -> anyhow::Result> { let builder = unwrap_or_clone_arc(self); + let mut inner_builder = builder.inner; - let base_path = builder.base_path.context("Base path was not set")?; - let username = builder - .username - .context("Username to determine homeserver and home path was not set")?; + if let (Some(base_path), Some(username)) = (builder.base_path, &builder.username) { + // Determine store path + let data_path = PathBuf::from(base_path).join(sanitize(username)); + fs::create_dir_all(&data_path)?; + let store_config = make_store_config(&data_path, None)?; - // Determine store path - let data_path = PathBuf::from(base_path).join(sanitize(&username)); - fs::create_dir_all(&data_path)?; - let store_config = make_store_config(&data_path, None)?; + inner_builder = inner_builder.store_config(store_config); + } - let mut inner_builder = builder.inner.store_config(store_config); - - // Determine server either from explicitly set homeserver or from userId + // Determine server either from URL, server name or user ID. if let Some(homeserver_url) = builder.homeserver_url { inner_builder = inner_builder.homeserver_url(homeserver_url); - } else { + } else if let Some(server_name) = builder.server_name { + let server_name = ServerName::parse(server_name)?; + inner_builder = inner_builder.server_name(&server_name); + } else if let Some(username) = builder.username { let user = UserId::parse(username)?; inner_builder = inner_builder.server_name(user.server_name()); + } else { + return Err(anyhow!( + "Failed to build: One of homeserver_url, server_name or username must be called." + )); } RUNTIME.block_on(async move { diff --git a/bindings/matrix-sdk-ffi/src/room.rs b/bindings/matrix-sdk-ffi/src/room.rs index deacfe541..f3718f2e5 100644 --- a/bindings/matrix-sdk-ffi/src/room.rs +++ b/bindings/matrix-sdk-ffi/src/room.rs @@ -18,6 +18,12 @@ pub trait RoomDelegate: Sync + Send { fn did_receive_message(&self, messages: Arc); } +pub enum Membership { + Invited, + Joined, + Left, +} + pub struct Room { room: MatrixRoom, delegate: Arc>>>, @@ -80,6 +86,14 @@ impl Room { }) } + pub fn membership(&self) -> Membership { + match &self.room { + MatrixRoom::Invited(_) => Membership::Invited, + MatrixRoom::Joined(_) => Membership::Joined, + MatrixRoom::Left(_) => Membership::Left, + } + } + pub fn is_direct(&self) -> bool { self.room.is_direct() } diff --git a/crates/matrix-sdk-appservice/Cargo.toml b/crates/matrix-sdk-appservice/Cargo.toml index 90e569a60..23cafa95c 100644 --- a/crates/matrix-sdk-appservice/Cargo.toml +++ b/crates/matrix-sdk-appservice/Cargo.toml @@ -1,6 +1,6 @@ [package] authors = ["Johannes Becker "] -edition = "2018" +edition = "2021" homepage = "https://github.com/matrix-org/matrix-rust-sdk" repository = "https://github.com/matrix-org/matrix-rust-sdk" description = "Appservice SDK based on the matrix-sdk" diff --git a/crates/matrix-sdk-appservice/examples/appservice_autojoin.rs b/crates/matrix-sdk-appservice/examples/appservice_autojoin.rs index 8835449e4..470e3ba63 100644 --- a/crates/matrix-sdk-appservice/examples/appservice_autojoin.rs +++ b/crates/matrix-sdk-appservice/examples/appservice_autojoin.rs @@ -60,11 +60,12 @@ pub async fn main() -> Result<(), Box> { let appservice = AppService::new(homeserver_url, server_name, registration).await?; appservice.register_user_query(Box::new(|_, _| Box::pin(async { true }))).await; - appservice - .virtual_user(None) - .await? - .register_event_handler_context(appservice.clone()) - .register_event_handler( + + let virtual_user = appservice.virtual_user(None).await?; + + virtual_user.add_event_handler_context(appservice.clone()); + virtual_user + .add_event_handler( move |event: OriginalSyncRoomMemberEvent, room: Room, Ctx(appservice): Ctx| { diff --git a/crates/matrix-sdk-appservice/src/lib.rs b/crates/matrix-sdk-appservice/src/lib.rs index 7fb2519c8..292197496 100644 --- a/crates/matrix-sdk-appservice/src/lib.rs +++ b/crates/matrix-sdk-appservice/src/lib.rs @@ -43,8 +43,8 @@ //! # async { //! # //! use matrix_sdk_appservice::{ -//! ruma::events::room::member::SyncRoomMemberEvent, -//! AppService, AppServiceRegistration +//! ruma::events::room::member::SyncRoomMemberEvent, AppService, +//! AppServiceRegistration, //! }; //! //! let homeserver_url = "http://127.0.0.1:8008"; @@ -60,13 +60,15 @@ //! users: //! - exclusive: true //! regex: '@_appservice_.*' -//! ")?; +//! ", +//! )?; //! -//! let mut appservice = AppService::new(homeserver_url, server_name, registration).await?; +//! let mut appservice = +//! AppService::new(homeserver_url, server_name, registration).await?; //! appservice //! .virtual_user(None) //! .await? -//! .register_event_handler(|_ev: SyncRoomMemberEvent| async { +//! .add_event_handler(|_ev: SyncRoomMemberEvent| async { //! // do stuff //! }) //! .await; @@ -84,7 +86,7 @@ //! [matrix-org/matrix-rust-sdk#228]: https://github.com/matrix-org/matrix-rust-sdk/issues/228 //! [examples directory]: https://github.com/matrix-org/matrix-rust-sdk/tree/main/crates/matrix-sdk-appservice/examples -use std::{convert::TryInto, sync::Arc}; +use std::sync::Arc; use dashmap::DashMap; pub use error::Error; @@ -122,7 +124,7 @@ pub use virtual_user::VirtualUserBuilder; pub type Result = std::result::Result; const USER_KEY: &[u8] = b"appservice.users."; -pub const USER_MEMBER: &[u8] = b"appservice.users.membership."; +const USER_MEMBER: &[u8] = b"appservice.users.membership."; type Localpart = String; @@ -256,10 +258,12 @@ impl AppService { /// ```no_run /// # use matrix_sdk_appservice::AppService; /// # fn run(appservice: AppService) { - /// appservice.register_user_query(Box::new(|appservice, req| Box::pin(async move { - /// println!("Got request for {}", req.user_id); - /// true - /// }))); + /// appservice.register_user_query(Box::new(|appservice, req| { + /// Box::pin(async move { + /// println!("Got request for {}", req.user_id); + /// true + /// }) + /// })); /// # } /// ``` pub async fn register_user_query( @@ -278,10 +282,12 @@ impl AppService { /// ```no_run /// # use matrix_sdk_appservice::AppService; /// # fn run(appservice: AppService) { - /// appservice.register_room_query(Box::new(|appservice, req| Box::pin(async move { - /// println!("Got request for {}", req.room_alias); - /// true - /// }))); + /// appservice.register_room_query(Box::new(|appservice, req| { + /// Box::pin(async move { + /// println!("Got request for {}", req.room_alias); + /// true + /// }) + /// })); /// # } /// ``` pub async fn register_room_query( @@ -492,7 +498,7 @@ impl AppService { } for task in tasks { if let Err(e) = task.await { - warn!("Joining sync task failed: {}", e); + warn!("Joining sync task failed: {e}"); } } Ok(()) @@ -505,7 +511,7 @@ impl AppService { pub async fn run(&self, host: impl Into, port: impl Into) -> Result<()> { let host = host.into(); let port = port.into(); - info!("Starting AppService on {}:{}", &host, &port); + info!(host, port, "Starting AppService"); webserver::run_server(self.clone(), host, port).await?; Ok(()) @@ -640,7 +646,7 @@ mod tests { appservice .virtual_user(None) .await? - .register_event_handler({ + .add_event_handler({ let on_state_member = on_state_member.clone(); move |_ev: OriginalSyncRoomMemberEvent| { *on_state_member.lock().unwrap() = true; @@ -796,7 +802,7 @@ mod tests { appservice .virtual_user(None) .await? - .register_event_handler({ + .add_event_handler({ let on_state_member = on_state_member.clone(); move |_ev: OriginalSyncRoomMemberEvent| { *on_state_member.lock().unwrap() = true; diff --git a/crates/matrix-sdk-appservice/src/registration.rs b/crates/matrix-sdk-appservice/src/registration.rs index 2c581d28f..4da736588 100644 --- a/crates/matrix-sdk-appservice/src/registration.rs +++ b/crates/matrix-sdk-appservice/src/registration.rs @@ -14,7 +14,7 @@ //! AppService Registration. -use std::{convert::TryFrom, fs::File, ops::Deref, path::PathBuf}; +use std::{fs::File, ops::Deref, path::PathBuf}; use http::Uri; use regex::Regex; diff --git a/crates/matrix-sdk-appservice/src/webserver.rs b/crates/matrix-sdk-appservice/src/webserver.rs index 9e29a08a9..49fa38894 100644 --- a/crates/matrix-sdk-appservice/src/webserver.rs +++ b/crates/matrix-sdk-appservice/src/webserver.rs @@ -145,7 +145,7 @@ mod filters { .and(warp::body::bytes()) .and_then(|method, path: FullPath, query, headers, bytes| async move { let uri = http::uri::Builder::new() - .path_and_query(format!("{}?{}", path.as_str(), query)) + .path_and_query(format!("{}?{query}", path.as_str())) .build() .map_err(Error::from)?; @@ -164,9 +164,13 @@ mod filters { mod handlers { use percent_encoding::percent_decode_str; + use serde::Serialize; use super::*; + #[derive(Serialize)] + struct EmptyObject {} + pub async fn user( user_id: String, appservice: AppService, @@ -177,12 +181,12 @@ mod handlers { let request = query_user::IncomingRequest::try_from_http_request(request, &[user_id]) .map_err(Error::from)?; return if user_exists(appservice.clone(), request).await { - Ok(warp::reply::json(&String::from("{}"))) + Ok(warp::reply::json(&EmptyObject {})) } else { Err(warp::reject::not_found()) }; } - Ok(warp::reply::json(&String::from("{}"))) + Ok(warp::reply::json(&EmptyObject {})) } pub async fn room( @@ -195,12 +199,12 @@ mod handlers { let request = query_room::IncomingRequest::try_from_http_request(request, &[room_id]) .map_err(Error::from)?; return if room_exists(appservice.clone(), request).await { - Ok(warp::reply::json(&String::from("{}"))) + Ok(warp::reply::json(&EmptyObject {})) } else { Err(warp::reject::not_found()) }; } - Ok(warp::reply::json(&String::from("{}"))) + Ok(warp::reply::json(&EmptyObject {})) } pub async fn transaction( @@ -213,7 +217,7 @@ mod handlers { .map_err(Error::from)?; appservice.receive_transaction(incoming_transaction).await?; - Ok(warp::reply::json(&String::from("{}"))) + Ok(warp::reply::json(&EmptyObject {})) } } diff --git a/crates/matrix-sdk-base/src/client.rs b/crates/matrix-sdk-base/src/client.rs index 5d58c040a..6786a691e 100644 --- a/crates/matrix-sdk-base/src/client.rs +++ b/crates/matrix-sdk-base/src/client.rs @@ -305,11 +305,12 @@ impl BaseClient { #[cfg(feature = "e2e-encryption")] AnySyncRoomEvent::MessageLike(e) => match e { AnySyncMessageLikeEvent::RoomEncrypted( - SyncMessageLikeEvent::Original(encrypted), + SyncMessageLikeEvent::Original(_), ) => { if let Some(olm) = self.olm_machine() { - if let Ok(decrypted) = - olm.decrypt_room_event(encrypted, room_id).await + if let Ok(decrypted) = olm + .decrypt_room_event(event.event.cast_ref(), room_id) + .await { event = decrypted.into(); } @@ -400,8 +401,8 @@ impl BaseClient { } Err(err) => { warn!( - "Couldn't deserialize stripped state event for room {}: {:?}", - room_info.room_id, err + room_id = %room_info.room_id, + "Couldn't deserialize stripped state event: {err:?}", ); } } @@ -429,10 +430,7 @@ impl BaseClient { let event = match raw_event.deserialize() { Ok(e) => e, Err(e) => { - warn!( - "Couldn't deserialize state event for room {}: {:?} {:#?}", - room_id, e, raw_event - ); + warn!(%room_id, "Couldn't deserialize state event: {e:?}"); continue; } }; @@ -831,7 +829,7 @@ impl BaseClient { .filter_map(|event| match event.deserialize() { Ok(ev) => Some(ev), Err(e) => { - debug!(?event, "Failed to deserialize m.room.member event: {}", e); + debug!(?event, "Failed to deserialize m.room.member event: {e}"); None } }) diff --git a/crates/matrix-sdk-base/src/rooms/mod.rs b/crates/matrix-sdk-base/src/rooms/mod.rs index 049854a09..3d3e4d7c5 100644 --- a/crates/matrix-sdk-base/src/rooms/mod.rs +++ b/crates/matrix-sdk-base/src/rooms/mod.rs @@ -47,9 +47,9 @@ impl fmt::Display for DisplayName { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { DisplayName::Named(s) | DisplayName::Calculated(s) | DisplayName::Aliased(s) => { - write!(f, "{}", s) + write!(f, "{s}") } - DisplayName::EmptyWas(s) => write!(f, "Empty Room (was {})", s), + DisplayName::EmptyWas(s) => write!(f, "Empty Room (was {s})"), DisplayName::Empty => write!(f, "Empty Room"), } } diff --git a/crates/matrix-sdk-base/src/rooms/normal.rs b/crates/matrix-sdk-base/src/rooms/normal.rs index fe286f965..1ff395a72 100644 --- a/crates/matrix-sdk-base/src/rooms/normal.rs +++ b/crates/matrix-sdk-base/src/rooms/normal.rs @@ -44,6 +44,7 @@ use ruma::{ RoomVersionId, UserId, }; use serde::{Deserialize, Serialize}; +use tracing::debug; use super::{BaseRoomInfo, DisplayName, RoomMember}; use crate::{ @@ -396,7 +397,7 @@ impl Room { _ => (summary.joined_member_count, summary.invited_member_count), }; - tracing::debug!( + debug!( room_id = self.room_id().as_str(), own_user = self.own_user_id.as_str(), joined, invited, @@ -589,6 +590,8 @@ impl Room { /// Add a new timeline slice to the timeline streams. #[cfg(feature = "experimental-timeline")] pub async fn add_timeline_slice(&self, timeline: &TimelineSlice) { + use tracing::warn; + if timeline.sync { let mut streams = self.forward_timeline_streams.lock().await; let mut remaining_streams = Vec::with_capacity(streams.len()); @@ -596,7 +599,11 @@ impl Room { if !forward.is_closed() { if let Err(error) = forward.try_send(timeline.clone()) { if error.is_full() { - tracing::warn!("Drop timeline slice because the limit of the buffer for the forward stream is reached"); + warn!( + room_id = %self.room_id(), + "Dropping timeline slice because the limit of the buffer for the \ + forward stream is reached" + ); } } else { remaining_streams.push(forward); @@ -611,7 +618,11 @@ impl Room { if !backward.is_closed() { if let Err(error) = backward.try_send(timeline.clone()) { if error.is_full() { - tracing::warn!("Drop timeline slice because the limit of the buffer for the backward stream is reached"); + warn!( + room_id = %self.room_id(), + "Dropping timeline slice because the limit of the buffer for the \ + backward stream is reached" + ); } } else { remaining_streams.push(backward); diff --git a/crates/matrix-sdk-base/src/store/ambiguity_map.rs b/crates/matrix-sdk-base/src/store/ambiguity_map.rs index 57e8b18b0..0455904d6 100644 --- a/crates/matrix-sdk-base/src/store/ambiguity_map.rs +++ b/crates/matrix-sdk-base/src/store/ambiguity_map.rs @@ -123,7 +123,7 @@ impl AmbiguityCache { member_ambiguous: ambiguous, }; - trace!("Handling display name ambiguity for {}: {:#?}", member_event.state_key(), change); + trace!(user_id = %member_event.state_key(), "Handling display name ambiguity: {change:#?}"); self.add_change(room_id, member_event.event_id().to_owned(), change); diff --git a/crates/matrix-sdk-base/src/store/integration_tests.rs b/crates/matrix-sdk-base/src/store/integration_tests.rs index 0af39b2b8..49dc96e71 100644 --- a/crates/matrix-sdk-base/src/store/integration_tests.rs +++ b/crates/matrix-sdk-base/src/store/integration_tests.rs @@ -17,13 +17,13 @@ /// #[cfg(test)] /// mod tests { /// -/// use super::{MyStore, StoreResult, StateStore}; +/// use super::{MyStore, StateStore, StoreResult}; /// -/// async fn get_store() -> StoreResult { -/// Ok(MyStore::new()) -/// } +/// async fn get_store() -> StoreResult { +/// Ok(MyStore::new()) +/// } /// -/// statestore_integration_tests! { integration } +/// statestore_integration_tests! { integration } /// } /// ``` #[allow(unused_macros, unused_extern_crates)] @@ -775,13 +775,14 @@ macro_rules! statestore_integration_tests { .zip(stored_events.iter()) .enumerate() { - assert_eq!(a.expect("not a value").event_id(), b.event_id(), "pos {} not equal - expected: {:#?}, but found {:#?}", idx, expected, found); - + assert_eq!( + a.expect("not a value").event_id(), + b.event_id(), + "pos {idx} not equal - expected: {expected:#?}, but found {found:#?}", + ); } } - } - )* } } diff --git a/crates/matrix-sdk-base/src/store/memory_store.rs b/crates/matrix-sdk-base/src/store/memory_store.rs index a0321fad8..bcf5315fe 100644 --- a/crates/matrix-sdk-base/src/store/memory_store.rs +++ b/crates/matrix-sdk-base/src/store/memory_store.rs @@ -43,6 +43,7 @@ use ruma::{ serde::Raw, EventId, MxcUri, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId, }; +use tracing::info; #[cfg(feature = "experimental-timeline")] use super::BoxStream; @@ -342,26 +343,25 @@ impl MemoryStore { } #[cfg(feature = "experimental-timeline")] - for (room, timeline) in &changes.timeline { + for (room_id, timeline) in &changes.timeline { + use tracing::warn; + if timeline.sync { - tracing::info!("Save new timeline batch from sync response for {}", room); + info!(%room_id, "Saving new timeline batch from sync response"); } else { - tracing::info!("Save new timeline batch from messages response for {}", room); + info!(%room_id, "Saving new timeline batch from messages response"); } let mut delete_timeline = false; if timeline.limited { - tracing::info!( - "Delete stored timeline for {} because the sync response was limited", - room - ); + info!(%room_id, "Deleting stored timeline because the sync response was limited"); delete_timeline = true; - } else if let Some(mut data) = self.room_timeline.get_mut(room) { + } else if let Some(mut data) = self.room_timeline.get_mut(room_id) { if !timeline.sync && Some(&timeline.start) != data.end.as_ref() { // This should only happen when a developer adds a wrong timeline // batch to the `StateChanges` or the server returns a wrong response // to our request. - tracing::warn!("Drop unexpected timeline batch for {}", room); + warn!(%room_id, "Dropping unexpected timeline batch"); return Ok(()); } @@ -385,12 +385,12 @@ impl MemoryStore { } if delete_timeline { - tracing::info!("Delete stored timeline for {} because of duplicated events", room); - self.room_timeline.remove(room); + info!(%room_id, "Deleting stored timeline because of duplicated events"); + self.room_timeline.remove(room_id); } let mut data = - self.room_timeline.entry(room.to_owned()).or_insert_with(|| TimelineData { + self.room_timeline.entry(room_id.to_owned()).or_insert_with(|| TimelineData { start: timeline.start.clone(), end: timeline.end.clone(), ..Default::default() @@ -398,13 +398,10 @@ impl MemoryStore { let make_room_version = || { self.room_info - .get(room) + .get(room_id) .and_then(|info| info.room_version().cloned()) .unwrap_or_else(|| { - tracing::warn!( - "Unable to find the room version for {}, assume version 9", - room - ); + warn!(%room_id, "Unable to find the room version, assuming version 9"); RoomVersionId::V9 }) }; @@ -455,7 +452,7 @@ impl MemoryStore { } } - tracing::info!("Saved changes in {:?}", now.elapsed()); + info!("Saved changes in {:?}", now.elapsed()); Ok(()) } @@ -676,7 +673,7 @@ impl MemoryStore { let (events, end_token) = if let Some(data) = self.room_timeline.get(room_id) { (data.events.clone(), data.end.clone()) } else { - tracing::info!("No timeline for {} was previously stored", room_id); + info!(%room_id, "Couldn't find a previously stored timeline"); return Ok(None); }; @@ -686,11 +683,7 @@ impl MemoryStore { } }; - tracing::info!( - "Found previously stored timeline for {}, with end token {:?}", - room_id, - end_token - ); + info!(%room_id, ?end_token, "Found previously stored timeline"); Ok(Some((Box::pin(stream), end_token))) } diff --git a/crates/matrix-sdk-crypto/Cargo.toml b/crates/matrix-sdk-crypto/Cargo.toml index e97269045..4c81c5d97 100644 --- a/crates/matrix-sdk-crypto/Cargo.toml +++ b/crates/matrix-sdk-crypto/Cargo.toml @@ -51,12 +51,25 @@ zeroize = { version = "1.3.0", features = ["zeroize_derive"] } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] tokio = { version = "1.18", default-features = false, features = ["time"] } -ruma = { git = "https://github.com/ruma/ruma", rev = "ca8c66c885241a7ba3805399604eda4a38979f6b", features = ["client-api-c", "rand", "canonical-json", "unstable-msc2676", "unstable-msc2677"] } -vodozemac = { git = "https://github.com/matrix-org/vodozemac/", rev = "2404f83f7d3a3779c1f518e4d949f7da9677c3dd" } -[target.'cfg(target_arch = "wasm32")'.dependencies] -ruma = { git = "https://github.com/ruma/ruma", rev = "ca8c66c885241a7ba3805399604eda4a38979f6b", features = ["client-api-c", "js", "rand", "canonical-json", "unstable-msc2676", "unstable-msc2677"] } -vodozemac = { git = "https://github.com/matrix-org/vodozemac/", rev = "2404f83f7d3a3779c1f518e4d949f7da9677c3dd", features = ["js"] } +[target.'cfg(target_arch = "wasm32")'.dependencies.ruma] +git = "https://github.com/ruma/ruma" +rev = "ca8c66c885241a7ba3805399604eda4a38979f6b" +features = ["client-api-c", "js", "rand", "canonical-json", "unstable-msc2676", "unstable-msc2677"] + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies.ruma] +git = "https://github.com/ruma/ruma" +rev = "ca8c66c885241a7ba3805399604eda4a38979f6b" +features = ["client-api-c", "rand", "canonical-json", "unstable-msc2676", "unstable-msc2677"] + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies.vodozemac] +git = "https://github.com/matrix-org/vodozemac/" +rev = "18bcbc3359298894415931547ea41abb75af2d4a" + +[target.'cfg(target_arch = "wasm32")'.dependencies.vodozemac] +git = "https://github.com/matrix-org/vodozemac/" +rev = "18bcbc3359298894415931547ea41abb75af2d4a" +features = ["js"] [dev-dependencies] futures = { version = "0.3.21", default-features = false, features = ["executor"] } diff --git a/crates/matrix-sdk-crypto/README.md b/crates/matrix-sdk-crypto/README.md index 9891105ba..b8dc9ad80 100644 --- a/crates/matrix-sdk-crypto/README.md +++ b/crates/matrix-sdk-crypto/README.md @@ -17,7 +17,7 @@ The state machine works in a push/pull manner: state machine ```rust,no_run -use std::{collections::BTreeMap, convert::TryFrom}; +use std::collections::BTreeMap; use matrix_sdk_crypto::{OlmMachine, OlmError}; use ruma::{ diff --git a/crates/matrix-sdk-crypto/src/backups/keys/recovery.rs b/crates/matrix-sdk-crypto/src/backups/keys/recovery.rs index 42ef34259..4af89fe1b 100644 --- a/crates/matrix-sdk-crypto/src/backups/keys/recovery.rs +++ b/crates/matrix-sdk-crypto/src/backups/keys/recovery.rs @@ -13,7 +13,6 @@ // limitations under the License. use std::{ - convert::TryFrom, io::{Cursor, Read}, ops::DerefMut, }; diff --git a/crates/matrix-sdk-crypto/src/error.rs b/crates/matrix-sdk-crypto/src/error.rs index 9bc6574e1..aa1d05f61 100644 --- a/crates/matrix-sdk-crypto/src/error.rs +++ b/crates/matrix-sdk-crypto/src/error.rs @@ -39,7 +39,7 @@ pub enum OlmError { /// The received room key couldn't be converted into a valid Megolm session. #[error(transparent)] - SessionCreation(#[from] vodozemac::megolm::SessionKeyDecodeError), + SessionCreation(#[from] SessionCreationError), /// The storage layer returned an error. #[error("failed to read or write to the crypto store {0}")] @@ -135,7 +135,7 @@ pub enum EventError { #[error( "the room id of the room key doesn't match the room id of the \ - decrypted event: expected {0}, got {:1}" + decrypted event: expected {0}, got {1:?}" )] MismatchedRoom(OwnedRoomId, Option), } diff --git a/crates/matrix-sdk-crypto/src/file_encryption/attachments.rs b/crates/matrix-sdk-crypto/src/file_encryption/attachments.rs index 5f9facfe2..e49dace77 100644 --- a/crates/matrix-sdk-crypto/src/file_encryption/attachments.rs +++ b/crates/matrix-sdk-crypto/src/file_encryption/attachments.rs @@ -374,6 +374,6 @@ mod tests { let mut decryptor = AttachmentDecryptor::new(&mut cursor, key).unwrap(); let mut decrypted_data = Vec::new(); - assert!(decryptor.read_to_end(&mut decrypted_data).is_err()) + decryptor.read_to_end(&mut decrypted_data).unwrap_err(); } } diff --git a/crates/matrix-sdk-crypto/src/file_encryption/key_export.rs b/crates/matrix-sdk-crypto/src/file_encryption/key_export.rs index f5adde9b0..c7664e88a 100644 --- a/crates/matrix-sdk-crypto/src/file_encryption/key_export.rs +++ b/crates/matrix-sdk-crypto/src/file_encryption/key_export.rs @@ -312,7 +312,7 @@ mod tests { #[test] fn test_decode() { let export = export_without_headers(); - assert!(decode(export).is_ok()); + decode(export).unwrap(); } #[test] @@ -369,7 +369,7 @@ mod tests { )]), ); - assert_eq!(machine.import_keys(export, false, |_, _| {}).await?, keys,); + assert_eq!(machine.import_keys(export, false, |_, _| {}).await?, keys); let export = vec![session.export_at_index(10).await]; assert_eq!( @@ -379,7 +379,7 @@ mod tests { let better_export = vec![session.export().await]; - assert_eq!(machine.import_keys(better_export, false, |_, _| {}).await?, keys,); + assert_eq!(machine.import_keys(better_export, false, |_, _| {}).await?, keys); let another_session = machine.create_inbound_session(room_id).await?; let export = vec![another_session.export_at_index(10).await]; @@ -396,7 +396,7 @@ mod tests { )]), ); - assert_eq!(machine.import_keys(export, false, |_, _| {}).await?, keys,); + assert_eq!(machine.import_keys(export, false, |_, _| {}).await?, keys); Ok(()) } diff --git a/crates/matrix-sdk-crypto/src/gossiping/machine.rs b/crates/matrix-sdk-crypto/src/gossiping/machine.rs index 8b49d86b7..96053debc 100644 --- a/crates/matrix-sdk-crypto/src/gossiping/machine.rs +++ b/crates/matrix-sdk-crypto/src/gossiping/machine.rs @@ -48,7 +48,7 @@ use crate::{ requests::{OutgoingRequest, ToDeviceRequest}, session_manager::GroupSessionCache, store::{Changes, CryptoStoreError, SecretImportError, Store}, - types::events::secret_send::SecretSendEvent, + types::events::{secret_send::SecretSendEvent, EventType}, Device, }; @@ -435,7 +435,8 @@ impl GossipMachine { let request = ToDeviceRequest::new( device.user_id(), device.device_id().to_owned(), - AnyToDeviceEventContent::RoomEncrypted(content), + content.event_type(), + content.cast(), ); let request = OutgoingRequest { @@ -459,7 +460,8 @@ impl GossipMachine { let request = ToDeviceRequest::new( device.user_id(), device.device_id().to_owned(), - AnyToDeviceEventContent::RoomEncrypted(content), + content.event_type(), + content.cast(), ); let request = OutgoingRequest { @@ -589,9 +591,10 @@ impl GossipMachine { room_id: &RoomId, sender_key: &str, session_id: &str, + algorithm: &EventEncryptionAlgorithm, ) -> Result<(Option, OutgoingRequest), CryptoStoreError> { let key_info = RequestedKeyInfo::new( - EventEncryptionAlgorithm::MegolmV1AesSha2, + algorithm.to_owned(), room_id.to_owned(), sender_key.to_owned(), session_id.to_owned(), @@ -666,9 +669,10 @@ impl GossipMachine { room_id: &RoomId, sender_key: &str, session_id: &str, + algorithm: &EventEncryptionAlgorithm, ) -> Result { let key_info = RequestedKeyInfo::new( - EventEncryptionAlgorithm::MegolmV1AesSha2, + algorithm.to_owned(), room_id.to_owned(), sender_key.to_owned(), session_id.to_owned(), @@ -950,10 +954,9 @@ mod tests { device_id, events::{ forwarded_room_key::ToDeviceForwardedRoomKeyEventContent, - room::encrypted::ToDeviceRoomEncryptedEventContent, room_key_request::ToDeviceRoomKeyRequestEventContent, secret::request::{RequestAction, SecretName, ToDeviceSecretRequestEventContent}, - AnyToDeviceEvent, ToDeviceEvent, + AnyToDeviceEvent, ToDeviceEvent as RumaToDeviceEvent, }, room_id, to_device::DeviceIdOrAllDevices, @@ -966,6 +969,8 @@ mod tests { olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, session_manager::GroupSessionCache, store::{Changes, CryptoStore, MemoryStore, Store}, + types::events::{room::encrypted::ToDeviceEncryptedEventContent, ToDeviceEvent}, + utilities::json_convert, verification::VerificationMachine, OutgoingRequests, }; @@ -1064,7 +1069,12 @@ mod tests { assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty()); let (cancel, request) = machine - .request_key(session.room_id(), &session.sender_key, session.session_id()) + .request_key( + session.room_id(), + &session.sender_key, + session.session_id(), + session.algorithm(), + ) .await .unwrap(); @@ -1073,7 +1083,12 @@ mod tests { machine.mark_outgoing_request_as_sent(&request.request_id).await.unwrap(); let (cancel, _) = machine - .request_key(session.room_id(), &session.sender_key, session.session_id()) + .request_key( + session.room_id(), + &session.sender_key, + session.session_id(), + session.algorithm(), + ) .await .unwrap(); @@ -1099,6 +1114,7 @@ mod tests { session.room_id(), &session.sender_key, session.session_id(), + session.algorithm(), ) .await .unwrap(); @@ -1110,6 +1126,7 @@ mod tests { session.room_id(), &session.sender_key, session.session_id(), + session.algorithm(), ) .await .unwrap(); @@ -1141,6 +1158,7 @@ mod tests { session.room_id(), &session.sender_key, session.session_id(), + session.algorithm(), ) .await .unwrap(); @@ -1155,7 +1173,7 @@ mod tests { let content: ToDeviceForwardedRoomKeyEventContent = export.try_into().unwrap(); - let event = ToDeviceEvent { sender: alice_id().to_owned(), content }; + let event = RumaToDeviceEvent { sender: alice_id().to_owned(), content }; assert!( machine @@ -1189,6 +1207,7 @@ mod tests { session.room_id(), &session.sender_key, session.session_id(), + session.algorithm(), ) .await .unwrap(); @@ -1202,7 +1221,7 @@ mod tests { let content: ToDeviceForwardedRoomKeyEventContent = export.try_into().unwrap(); - let event = ToDeviceEvent { sender: alice_id().to_owned(), content }; + let event = RumaToDeviceEvent { sender: alice_id().to_owned(), content }; let second_session = machine.receive_forwarded_room_key(&session.sender_key, &event).await.unwrap(); @@ -1213,7 +1232,7 @@ mod tests { let content: ToDeviceForwardedRoomKeyEventContent = export.try_into().unwrap(); - let event = ToDeviceEvent { sender: alice_id().to_owned(), content }; + let event = RumaToDeviceEvent { sender: alice_id().to_owned(), content }; let second_session = machine.receive_forwarded_room_key(&session.sender_key, &event).await.unwrap(); @@ -1238,7 +1257,7 @@ mod tests { ); own_device.set_trust_state(LocalTrust::Verified); // Now we do want to share the keys. - assert!(machine.should_share_key(&own_device, &inbound).await.is_ok()); + machine.should_share_key(&own_device, &inbound).await.unwrap(); let bob_device = ReadOnlyDevice::from_account(&bob_account()).await; machine.store.save_devices(&[bob_device]).await.unwrap(); @@ -1284,7 +1303,7 @@ mod tests { bob_device.curve25519_key().unwrap(), ) .await; - assert!(machine.should_share_key(&bob_device, &inbound).await.is_ok()); + machine.should_share_key(&bob_device, &inbound).await.unwrap(); let (other_outbound, other_inbound) = account.create_group_session_pair_with_defaults(room_id()).await; @@ -1375,6 +1394,7 @@ mod tests { room_id(), &bob_account.identity_keys.curve25519.to_base64(), group_session.session_id(), + &group_session.settings().algorithm, ) .await .unwrap(); @@ -1406,7 +1426,7 @@ mod tests { alice_machine.mark_outgoing_request_as_sent(id).await.unwrap(); - let event = ToDeviceEvent { sender: alice_id().to_owned(), content }; + let event = RumaToDeviceEvent { sender: alice_id().to_owned(), content }; // Bob doesn't have any outgoing requests. assert!(bob_machine.outgoing_requests.is_empty()); @@ -1431,11 +1451,13 @@ mod tests { .unwrap() .get(&DeviceIdOrAllDevices::DeviceId(alice_device_id().to_owned())) .unwrap(); - let content: ToDeviceRoomEncryptedEventContent = content.deserialize_as().unwrap(); + let content: ToDeviceEncryptedEventContent = content.deserialize_as().unwrap(); bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); - let event = ToDeviceEvent { sender: bob_id().to_owned(), content }; + let event = + ToDeviceEvent { sender: bob_id().to_owned(), content, other: Default::default() }; + let event = json_convert(&event).unwrap(); // Check that alice doesn't have the session. assert!(alice_machine @@ -1488,7 +1510,7 @@ mod tests { alice_machine.store.save_sessions(&[alice_session]).await.unwrap(); - let event = ToDeviceEvent { + let event = RumaToDeviceEvent { sender: bob_account.user_id().to_owned(), content: ToDeviceSecretRequestEventContent::new( RequestAction::Request(SecretName::CrossSigningMasterKey), @@ -1517,7 +1539,7 @@ mod tests { alice_machine.collect_incoming_key_requests().await.unwrap(); assert!(alice_machine.outgoing_requests.is_empty()); - let event = ToDeviceEvent { + let event = RumaToDeviceEvent { sender: alice_id().to_owned(), content: ToDeviceSecretRequestEventContent::new( RequestAction::Request(SecretName::CrossSigningMasterKey), @@ -1577,6 +1599,7 @@ mod tests { room_id(), &bob_account.identity_keys.curve25519.to_base64(), group_session.session_id(), + &group_session.settings().algorithm, ) .await .unwrap(); @@ -1608,7 +1631,7 @@ mod tests { alice_machine.mark_outgoing_request_as_sent(id).await.unwrap(); - let event = ToDeviceEvent { sender: alice_id().to_owned(), content }; + let event = RumaToDeviceEvent { sender: alice_id().to_owned(), content }; // Bob doesn't have any outgoing requests. assert!(bob_machine.outgoing_to_device_requests().await.unwrap().is_empty()); @@ -1653,11 +1676,13 @@ mod tests { .unwrap() .get(&DeviceIdOrAllDevices::DeviceId(alice_device_id().to_owned())) .unwrap(); - let content: ToDeviceRoomEncryptedEventContent = content.deserialize_as().unwrap(); + let content: ToDeviceEncryptedEventContent = content.deserialize_as().unwrap(); bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); - let event = ToDeviceEvent { sender: bob_id().to_owned(), content }; + let event = + ToDeviceEvent { sender: bob_id().to_owned(), content, other: Default::default() }; + let event = json_convert(&event).unwrap(); // Check that alice doesn't have the session. assert!(alice_machine diff --git a/crates/matrix-sdk-crypto/src/gossiping/mod.rs b/crates/matrix-sdk-crypto/src/gossiping/mod.rs index 58b35f9ff..7e381f7b4 100644 --- a/crates/matrix-sdk-crypto/src/gossiping/mod.rs +++ b/crates/matrix-sdk-crypto/src/gossiping/mod.rs @@ -180,10 +180,11 @@ impl GossipRequest { } }; - let request = ToDeviceRequest::new( + let request = ToDeviceRequest::with_id( &self.request_recipient, DeviceIdOrAllDevices::AllDevices, content, + TransactionId::new(), ); OutgoingRequest { request_id: request.txn_id.clone(), request: Arc::new(request.into()) } diff --git a/crates/matrix-sdk-crypto/src/identities/device.rs b/crates/matrix-sdk-crypto/src/identities/device.rs index b4ed0d601..83920a54e 100644 --- a/crates/matrix-sdk-crypto/src/identities/device.rs +++ b/crates/matrix-sdk-crypto/src/identities/device.rs @@ -28,9 +28,9 @@ use ruma::{ api::client::keys::upload_signatures::v3::Request as SignatureUploadRequest, events::{ forwarded_room_key::ToDeviceForwardedRoomKeyEventContent, - key::verification::VerificationMethod, room::encrypted::ToDeviceRoomEncryptedEventContent, - AnyToDeviceEventContent, + key::verification::VerificationMethod, AnyToDeviceEventContent, }, + serde::Raw, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, OwnedDeviceId, OwnedDeviceKeyId, UserId, }; @@ -46,7 +46,10 @@ use crate::{ identities::{ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities}, olm::{InboundGroupSession, Session, SignedJsonObject, VerifyJson}, store::{Changes, CryptoStore, DeviceChanges, Result as StoreResult}, - types::{DeviceKey, DeviceKeys, Signatures, SignedKey}, + types::{ + events::room::encrypted::ToDeviceEncryptedEventContent, DeviceKey, DeviceKeys, Signatures, + SignedKey, + }, verification::VerificationMachine, OutgoingVerificationRequest, ReadOnlyAccount, Sas, ToDeviceRequest, VerificationRequest, }; @@ -255,7 +258,7 @@ impl Device { pub(crate) async fn encrypt( &self, content: AnyToDeviceEventContent, - ) -> OlmResult<(Session, ToDeviceRoomEncryptedEventContent)> { + ) -> OlmResult<(Session, Raw)> { self.inner.encrypt(self.verification_machine.store.inner(), content).await } @@ -265,7 +268,7 @@ impl Device { &self, session: InboundGroupSession, message_index: Option, - ) -> OlmResult<(Session, ToDeviceRoomEncryptedEventContent)> { + ) -> OlmResult<(Session, Raw)> { let export = if let Some(index) = message_index { session.export_at_index(index).await } else { @@ -512,7 +515,7 @@ impl ReadOnlyDevice { &self, store: &dyn CryptoStore, content: AnyToDeviceEventContent, - ) -> OlmResult<(Session, ToDeviceRoomEncryptedEventContent)> { + ) -> OlmResult<(Session, Raw)> { let sender_key = if let Some(k) = self.curve25519_key() { k } else { diff --git a/crates/matrix-sdk-crypto/src/identities/manager.rs b/crates/matrix-sdk-crypto/src/identities/manager.rs index dd566c3b7..e3d6fec86 100644 --- a/crates/matrix-sdk-crypto/src/identities/manager.rs +++ b/crates/matrix-sdk-crypto/src/identities/manager.rs @@ -14,7 +14,6 @@ use std::{ collections::{BTreeMap, BTreeSet, HashSet}, - convert::TryFrom, ops::Deref, sync::Arc, time::Duration, @@ -623,7 +622,7 @@ impl IdentityManager { } if let Err(e) = self.store.update_tracked_user(user, true).await { - warn!("Error storing users for tracking {}", e); + warn!("Error storing users for tracking: {e}"); } } } @@ -861,7 +860,7 @@ pub(crate) mod tests { manager.receive_keys_query_response(&other_key_query()).await.unwrap(); - assert!(task.await.unwrap().is_ok()); + task.await.unwrap().unwrap(); let devices = manager.store.get_user_devices(other_user).await.unwrap(); assert_eq!(devices.devices().count(), 1); @@ -875,7 +874,7 @@ pub(crate) mod tests { let identity = manager.store.get_user_identity(other_user).await.unwrap().unwrap(); let identity = identity.other().unwrap(); - assert!(identity.is_device_signed(&device).is_ok()) + identity.is_device_signed(&device).unwrap(); } #[async_test] @@ -899,7 +898,7 @@ pub(crate) mod tests { let identity = manager.store.get_user_identity(other_user).await.unwrap().unwrap(); let identity = identity.other().unwrap(); - assert!(identity.is_device_signed(&device).is_ok()) + identity.is_device_signed(&device).unwrap(); } #[async_test] diff --git a/crates/matrix-sdk-crypto/src/identities/user.rs b/crates/matrix-sdk-crypto/src/identities/user.rs index 6ca61fd9d..8e871d04b 100644 --- a/crates/matrix-sdk-crypto/src/identities/user.rs +++ b/crates/matrix-sdk-crypto/src/identities/user.rs @@ -1014,8 +1014,8 @@ pub(crate) mod tests { let identity = get_own_identity(); let (first, second) = device(&response); - assert!(identity.is_device_signed(&first).is_err()); - assert!(identity.is_device_signed(&second).is_ok()); + identity.is_device_signed(&first).unwrap_err(); + identity.is_device_signed(&second).unwrap(); let private_identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(second.user_id()))); diff --git a/crates/matrix-sdk-crypto/src/machine.rs b/crates/matrix-sdk-crypto/src/machine.rs index ada12d5b8..5e5f1eced 100644 --- a/crates/matrix-sdk-crypto/src/machine.rs +++ b/crates/matrix-sdk-crypto/src/machine.rs @@ -35,12 +35,7 @@ use ruma::{ }, assign, events::{ - room::encrypted::{ - EncryptedEventScheme, MegolmV1AesSha2Content, OriginalSyncRoomEncryptedEvent, - RoomEncryptedEventContent, ToDeviceRoomEncryptedEvent, - }, - secret::request::SecretName, - AnyMessageLikeEvent, AnyRoomEvent, MessageLikeEventContent, + secret::request::SecretName, AnyMessageLikeEvent, AnyRoomEvent, MessageLikeEventContent, }, serde::Raw, DeviceId, DeviceKeyAlgorithm, OwnedDeviceKeyId, OwnedTransactionId, OwnedUserId, RoomId, @@ -69,6 +64,10 @@ use crate::{ }, types::{ events::{ + room::encrypted::{ + EncryptedEvent, EncryptedToDeviceEvent, RoomEncryptedEventContent, + RoomEventEncryptionScheme, SupportedEventEncryptionSchemes, + }, room_key::{RoomKeyContent, RoomKeyEvent}, ToDeviceEvents, }, @@ -538,7 +537,7 @@ impl OlmMachine { /// * `event` - The to-device event that should be decrypted. async fn decrypt_to_device_event( &self, - event: &ToDeviceRoomEncryptedEvent, + event: &EncryptedToDeviceEvent, ) -> OlmResult { let mut decrypted = self.account.decrypt_to_device_event(event).await?; // Handle the decrypted event, e.g. fetch out Megolm sessions out of @@ -555,6 +554,15 @@ impl OlmMachine { signing_key: &str, event: &RoomKeyEvent, ) -> OlmResult> { + let unsupported_warning = || { + warn!( + sender = %event.sender, + sender_key = sender_key, + algorithm = %event.algorithm(), + "Received room key with unsupported key algorithm", + ); + }; + match &event.content { RoomKeyContent::MegolmV1AesSha2(content) => { let session = InboundGroupSession::new( @@ -562,26 +570,14 @@ impl OlmMachine { signing_key, &content.room_id, &content.session_key, + event.algorithm(), None, ); - info!( - sender = %event.sender, - sender_key = sender_key, - room_id = %content.room_id, - session_id = session.session_id(), - "Received a new room key", - ); - Ok(Some(session)) } - RoomKeyContent::Unknown(content) => { - warn!( - sender = %event.sender, - sender_key = sender_key, - algorithm = ?content.algorithm, - "Received room key with unsupported key algorithm", - ); + RoomKeyContent::Unknown(_) => { + unsupported_warning(); Ok(None) } } @@ -636,7 +632,7 @@ impl OlmMachine { &self, room_id: &RoomId, content: impl MessageLikeEventContent, - ) -> MegolmResult { + ) -> MegolmResult> { let event_type = content.event_type().to_string(); let content = serde_json::to_value(&content)?; self.encrypt_room_event_raw(room_id, content, &event_type).await @@ -666,7 +662,7 @@ impl OlmMachine { room_id: &RoomId, content: Value, event_type: &str, - ) -> MegolmResult { + ) -> MegolmResult> { self.group_session_manager.encrypt(room_id, content, event_type).await } @@ -988,11 +984,13 @@ impl OlmMachine { /// * `session_id` - The id that uniquely identifies the session. pub async fn request_room_key( &self, - event: &OriginalSyncRoomEncryptedEvent, + event: &Raw, room_id: &RoomId, ) -> MegolmResult<(Option, OutgoingRequest)> { - let content = match &event.content.scheme { - EncryptedEventScheme::MegolmV1AesSha2(c) => c, + let event = event.deserialize()?; + + let content: SupportedEventEncryptionSchemes<'_> = match &event.content.scheme { + RoomEventEncryptionScheme::MegolmV1AesSha2(c) => c.into(), _ => return Err(EventError::UnsupportedAlgorithm.into()), }; @@ -1001,8 +999,9 @@ impl OlmMachine { .request_key( room_id, #[allow(deprecated)] - &content.sender_key, - &content.session_id, + &content.sender_key().to_base64(), + content.session_id(), + &content.algorithm(), ) .await?) } @@ -1043,19 +1042,19 @@ impl OlmMachine { }) } - async fn decrypt_megolm_v1_event( + async fn decrypt_megolm_events( &self, room_id: &RoomId, - event: &OriginalSyncRoomEncryptedEvent, - content: &MegolmV1AesSha2Content, + event: &EncryptedEvent, + content: &SupportedEventEncryptionSchemes<'_>, ) -> MegolmResult { if let Some(session) = self .store .get_inbound_group_session( room_id, #[allow(deprecated)] - &content.sender_key, - &content.session_id, + &content.sender_key().to_base64(), + content.session_id(), ) .await? { @@ -1071,6 +1070,7 @@ impl OlmMachine { room_id = room_id.as_str(), session_id = session.session_id(), sender_key = session.sender_key(), + algorithm = %session.algorithm(), "Successfully decrypted a room event" ); @@ -1084,6 +1084,7 @@ impl OlmMachine { room_id = room_id.as_str(), session_id = session.session_id(), sender_key = session.sender_key(), + algorithm = %session.algorithm(), error = ?e, "Event was successfully decrypted but has an invalid format" ); @@ -1095,7 +1096,7 @@ impl OlmMachine { &session, &event.sender, #[allow(deprecated)] - &content.device_id, + content.device_id(), ) .await?; @@ -1105,8 +1106,9 @@ impl OlmMachine { .create_outgoing_key_request( room_id, #[allow(deprecated)] - &content.sender_key, - &content.session_id, + &content.sender_key().to_base64(), + content.session_id(), + &content.algorithm(), ) .await?; @@ -1123,49 +1125,50 @@ impl OlmMachine { /// * `room_id` - The ID of the room where the event was sent to. pub async fn decrypt_room_event( &self, - event: &OriginalSyncRoomEncryptedEvent, + event: &Raw, room_id: &RoomId, ) -> MegolmResult { - match &event.content.scheme { - EncryptedEventScheme::MegolmV1AesSha2(c) => { - match self.decrypt_megolm_v1_event(room_id, event, c).await { - Ok(r) => Ok(r), - Err(e) => { - #[allow(deprecated)] - if let MegolmError::MissingRoomKey = e { - // TODO log the withheld reason if we have one. - debug!( - sender = event.sender.as_str(), - room_id = room_id.as_str(), - sender_key = c.sender_key.as_str(), - session_id = c.session_id.as_str(), - "Failed to decrypt a room event, the room key is missing" - ); - } else { - warn!( - sender = event.sender.as_str(), - room_id = room_id.as_str(), - sender_key = c.sender_key.as_str(), - session_id = c.session_id.as_str(), - error = ?e, - "Failed to decrypt a room event" - ); - } + let event = event.deserialize()?; - Err(e) - } - } - } - algorithm => { + let content = match &event.content.scheme { + RoomEventEncryptionScheme::MegolmV1AesSha2(c) => c.into(), + RoomEventEncryptionScheme::Unknown(c) => { warn!( sender = event.sender.as_str(), room_id = room_id.as_str(), - ?algorithm, + algorithm = %c.algorithm, "Received an encrypted room event with an unsupported algorithm" ); - Err(EventError::UnsupportedAlgorithm.into()) + + return Err(EventError::UnsupportedAlgorithm.into()); } - } + }; + + self.decrypt_megolm_events(room_id, &event, &content).await.map_err(|e| { + if let MegolmError::MissingRoomKey = e { + // TODO log the withheld reason if we have one. + debug!( + sender = event.sender.as_str(), + room_id = room_id.as_str(), + sender_key = content.sender_key().to_base64(), + session_id = content.session_id(), + algorithm = %content.algorithm(), + "Failed to decrypt a room event, the room key is missing" + ); + } else { + warn!( + sender = event.sender.as_str(), + room_id = room_id.as_str(), + sender_key = content.sender_key().to_base64(), + session_id = content.session_id(), + algorithm = %content.algorithm(), + error = ?e, + "Failed to decrypt a room event" + ); + } + + e + }) } /// Update the tracked users. @@ -1211,7 +1214,6 @@ impl OlmMachine { /// # Example /// /// ``` - /// # use std::convert::TryFrom; /// # use matrix_sdk_crypto::OlmMachine; /// # use ruma::{device_id, user_id}; /// # use futures::executor::block_on; @@ -1269,7 +1271,6 @@ impl OlmMachine { /// # Example /// /// ``` - /// # use std::convert::TryFrom; /// # use matrix_sdk_crypto::OlmMachine; /// # use ruma::{device_id, user_id}; /// # use futures::executor::block_on; @@ -1372,24 +1373,35 @@ impl OlmMachine { let mut keys = BTreeMap::new(); for (i, key) in exported_keys.into_iter().enumerate() { - let session = InboundGroupSession::from_export(key); + match InboundGroupSession::from_export(&key) { + Ok(session) => { + // Only import the session if we didn't have this session or if it's + // a better version of the same session, that is the first known + // index is lower. + if !existing_sessions.has_better_session(&session) { + #[cfg(feature = "backups_v1")] + if from_backup { + session.mark_as_backed_up(); + } - // Only import the session if we didn't have this session or if it's - // a better version of the same session, that is the first known - // index is lower. - if !existing_sessions.has_better_session(&session) { - #[cfg(feature = "backups_v1")] - if from_backup { - session.mark_as_backed_up(); + keys.entry(session.room_id().to_owned()) + .or_insert_with(BTreeMap::new) + .entry(session.sender_key().to_owned()) + .or_insert_with(BTreeSet::new) + .insert(session.session_id().to_owned()); + + sessions.push(session); + } + } + Err(e) => { + warn!( + sender_key= key.sender_key, + room_id = %key.room_id, + session_id = key.session_id, + error = ?e, + "Couldn't import a room key from a file export." + ); } - - keys.entry(session.room_id().to_owned()) - .or_insert_with(BTreeMap::new) - .entry(session.sender_key().to_owned()) - .or_insert_with(BTreeSet::new) - .insert(session.session_id().to_owned()); - - sessions.push(session); } progress_listener(i, total_count); @@ -1554,7 +1566,7 @@ pub(crate) mod testing { #[cfg(test)] pub(crate) mod tests { - use std::{collections::BTreeMap, convert::TryInto, iter, sync::Arc}; + use std::{collections::BTreeMap, iter, sync::Arc}; use matrix_sdk_test::{async_test, test_json}; use ruma::{ @@ -1572,26 +1584,29 @@ pub(crate) mod tests { dummy::ToDeviceDummyEventContent, key::verification::VerificationMethod, room::{ - encrypted::ToDeviceRoomEncryptedEventContent, + encrypted::OriginalSyncRoomEncryptedEvent, message::{MessageType, RoomMessageEventContent}, }, AnyMessageLikeEvent, AnyMessageLikeEventContent, AnyRoomEvent, AnyToDeviceEvent, AnyToDeviceEventContent, MessageLikeEvent, MessageLikeUnsigned, - OriginalMessageLikeEvent, OriginalSyncMessageLikeEvent, ToDeviceEvent, + OriginalMessageLikeEvent, }, room_id, serde::Raw, uint, user_id, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceKeyId, UserId, }; - use serde_json::value::to_raw_value; use vodozemac::Ed25519PublicKey; use super::testing::response_from_file; use crate::{ machine::OlmMachine, olm::VerifyJson, - types::{DeviceKeys, SignedKey}, + types::{ + events::{room::encrypted::ToDeviceEncryptedEventContent, ToDeviceEvent}, + DeviceKeys, SignedKey, + }, + utilities::json_convert, verification::tests::{outgoing_request_to_event, request_to_event}, EncryptionSettings, ReadOnlyDevice, ToDeviceRequest, }; @@ -1625,7 +1640,7 @@ pub(crate) mod tests { fn to_device_requests_to_content( requests: Vec>, - ) -> ToDeviceRoomEncryptedEventContent { + ) -> ToDeviceEncryptedEventContent { let to_device_request = &requests[0]; to_device_request @@ -1706,7 +1721,11 @@ pub(crate) mod tests { .unwrap(); alice.store.save_sessions(&[session]).await.unwrap(); - let event = ToDeviceEvent { sender: alice.user_id().to_owned(), content }; + let event = ToDeviceEvent { + sender: alice.user_id().to_owned(), + content: content.deserialize_as().unwrap(), + other: Default::default(), + }; let decrypted = bob.decrypt_to_device_event(&event).await.unwrap(); bob.store.save_sessions(&[decrypted.session.session()]).await.unwrap(); @@ -1749,7 +1768,7 @@ pub(crate) mod tests { &DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()), &device_keys, ); - assert!(ret.is_ok()); + ret.unwrap(); } #[async_test] @@ -1782,7 +1801,7 @@ pub(crate) mod tests { &DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()), &device_keys, ); - assert!(ret.is_err()); + ret.unwrap_err(); } #[async_test] @@ -1832,7 +1851,7 @@ pub(crate) mod tests { &DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()), &one_time_key, ); - assert!(ret.is_ok()); + ret.unwrap(); let device_keys: DeviceKeys = request.device_keys.unwrap().deserialize_as().unwrap(); @@ -1841,7 +1860,7 @@ pub(crate) mod tests { &DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()), &device_keys, ); - assert!(ret.is_ok()); + ret.unwrap(); let mut response = keys_upload_response(); response.one_time_key_counts.insert( @@ -1928,7 +1947,10 @@ pub(crate) mod tests { .encrypt(AnyToDeviceEventContent::Dummy(ToDeviceDummyEventContent::new())) .await .unwrap() - .1, + .1 + .deserialize_as() + .unwrap(), + other: Default::default(), }; let event = bob.decrypt_to_device_event(&event).await.unwrap().event.deserialize().unwrap(); @@ -1954,8 +1976,9 @@ pub(crate) mod tests { let event = ToDeviceEvent { sender: alice.user_id().to_owned(), content: to_device_requests_to_content(to_device_requests), + other: Default::default(), }; - let event = Raw::from_json(to_raw_value(&event).unwrap()); + let event = json_convert(&event).unwrap(); let alice_session = alice.group_session_manager.get_outbound_group_session(room_id).unwrap(); @@ -2002,6 +2025,7 @@ pub(crate) mod tests { let event = ToDeviceEvent { sender: alice.user_id().to_owned(), content: to_device_requests_to_content(to_device_requests), + other: Default::default(), }; let group_session = @@ -2017,14 +2041,16 @@ pub(crate) mod tests { .await .unwrap(); - let event = OriginalSyncMessageLikeEvent { + let event = OriginalSyncRoomEncryptedEvent { event_id: event_id!("$xxxxx:example.org").to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), sender: alice.user_id().to_owned(), - content: encrypted_content, + content: encrypted_content.deserialize_as().unwrap(), unsigned: MessageLikeUnsigned::default(), }; + let event = json_convert(&event).unwrap(); + let decrypted_event = bob.decrypt_room_event(&event, room_id).await.unwrap().event.deserialize().unwrap(); diff --git a/crates/matrix-sdk-crypto/src/olm/account.rs b/crates/matrix-sdk-crypto/src/olm/account.rs index 2c12ef75b..343c41fb5 100644 --- a/crates/matrix-sdk-crypto/src/olm/account.rs +++ b/crates/matrix-sdk-crypto/src/olm/account.rs @@ -14,7 +14,6 @@ use std::{ collections::{BTreeMap, HashMap}, - convert::TryInto, fmt, ops::Deref, sync::{ @@ -29,12 +28,7 @@ use ruma::{ upload_keys, upload_signatures::v3::{Request as SignatureUploadRequest, SignedKeys}, }, - events::{ - room::encrypted::{ - EncryptedEventScheme, OlmV1Curve25519AesSha2Content, ToDeviceRoomEncryptedEvent, - }, - AnyToDeviceEvent, OlmV1Keys, - }, + events::{AnyToDeviceEvent, OlmV1Keys}, serde::Raw, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, OwnedDeviceId, OwnedDeviceKeyId, OwnedUserId, RoomId, SecondsSinceUnixEpoch, UInt, UserId, @@ -57,7 +51,12 @@ use crate::{ identities::{MasterPubkey, ReadOnlyDevice}, requests::UploadSigningKeysRequest, store::{Changes, Store}, - types::{CrossSigningKey, DeviceKeys, OneTimeKey, SignedKey}, + types::{ + events::room::encrypted::{ + EncryptedToDeviceEvent, OlmV1Curve25519AesSha2Content, ToDeviceEncryptedEventContent, + }, + CrossSigningKey, DeviceKeys, OneTimeKey, SignedKey, + }, utilities::encode, CryptoStoreError, OlmError, SignatureError, }; @@ -118,14 +117,17 @@ pub struct OlmMessageHash { } impl OlmMessageHash { - fn new(sender_key: &str, message_type: u8, ciphertext: &str) -> Self { + fn new(sender_key: Curve25519PublicKey, ciphertext: &OlmMessage) -> Self { + let (message_type, ciphertext) = ciphertext.clone().to_parts(); + let sender_key = sender_key.to_base64(); + let sha = Sha256::new() - .chain_update(sender_key) - .chain_update(&[message_type]) + .chain_update(sender_key.as_bytes()) + .chain_update(&[message_type as u8]) .chain_update(&ciphertext) .finalize(); - Self { sender_key: sender_key.to_owned(), hash: encode(sha.as_slice()) } + Self { sender_key, hash: encode(sha.as_slice()) } } } @@ -138,100 +140,86 @@ impl Deref for Account { } impl Account { - fn parse_message( - sender_key: &str, - message_type: UInt, - ciphertext: String, - ) -> Result<(OlmMessage, OlmMessageHash), EventError> { - let message_type: u8 = message_type - .try_into() - .map_err(|_| EventError::UnsupportedOlmType(message_type.into()))?; - - let message_hash = OlmMessageHash::new(sender_key, message_type, &ciphertext); - let message = OlmMessage::from_parts(message_type.into(), &ciphertext) - .map_err(|_| EventError::UnsupportedOlmType(message_type.into()))?; - - Ok((message, message_hash)) - } - pub async fn save(&self) -> Result<(), CryptoStoreError> { self.store.save_account(self.inner.clone()).await } + async fn decrypt_olm_helper( + &self, + sender: &UserId, + sender_key: Curve25519PublicKey, + ciphertext: &OlmMessage, + ) -> OlmResult { + let message_hash = OlmMessageHash::new(sender_key, ciphertext); + + match self.decrypt_olm_message(sender, sender_key, ciphertext).await { + Ok((session, event, signing_key)) => Ok(OlmDecryptionInfo { + sender: sender.to_owned(), + session, + message_hash, + event, + signing_key, + sender_key: sender_key.to_base64(), + inbound_group_session: None, + }), + Err(OlmError::SessionWedged(user_id, sender_key)) => { + if self.store.is_message_known(&message_hash).await? { + info!( + sender = sender.as_str(), + sender_key, "An Olm message got replayed, decryption failed" + ); + + Err(OlmError::ReplayedMessage(user_id, sender_key)) + } else { + Err(OlmError::SessionWedged(user_id, sender_key)) + } + } + Err(e) => Err(e), + } + } + async fn decrypt_olm_v1( &self, sender: &UserId, content: &OlmV1Curve25519AesSha2Content, ) -> OlmResult { - let identity_keys = self.inner.identity_keys(); - - // Try to find a ciphertext that was meant for our device. - if let Some(ciphertext) = content.ciphertext.get(&identity_keys.curve25519.to_base64()) { - let (message, message_hash) = match Self::parse_message( - &content.sender_key, - ciphertext.message_type, - ciphertext.body.clone(), - ) { - Ok(m) => m, - Err(e) => { - warn!(error = ?e, "Encrypted to-device event isn't valid"); - return Err(e.into()); - } - }; - - // Decrypt the OlmMessage and get a Ruma event out of it. - match self.decrypt_olm_message(sender, &content.sender_key, message).await { - Ok((session, event, signing_key)) => Ok(OlmDecryptionInfo { - sender: sender.to_owned(), - session, - message_hash, - event, - signing_key, - sender_key: content.sender_key.clone(), - inbound_group_session: None, - }), - Err(OlmError::SessionWedged(user_id, sender_key)) => { - if self.store.is_message_known(&message_hash).await? { - info!( - sender = sender.as_str(), - sender_key = content.sender_key.as_str(), - "An Olm message got replayed, decryption failed" - ); - - Err(OlmError::ReplayedMessage(user_id, sender_key)) - } else { - Err(OlmError::SessionWedged(user_id, sender_key)) - } - } - Err(e) => Err(e), - } - } else { + if content.recipient_key != self.identity_keys().curve25519 { warn!( sender = sender.as_str(), - sender_key = content.sender_key.as_str(), + sender_key = content.sender_key.to_base64(), "Olm event doesn't contain a ciphertext for our key" ); Err(EventError::MissingCiphertext.into()) + } else { + self.decrypt_olm_helper(sender, content.sender_key, &content.ciphertext).await } } pub(crate) async fn decrypt_to_device_event( &self, - event: &ToDeviceRoomEncryptedEvent, + event: &EncryptedToDeviceEvent, ) -> OlmResult { - trace!(sender = event.sender.as_str(), "Decrypting a to-device event"); + trace!( + sender = event.sender.as_str(), + algorithm = %event.content.algorithm(), + "Decrypting a to-device event" + ); - if let EncryptedEventScheme::OlmV1Curve25519AesSha2(c) = &event.content.scheme { - self.decrypt_olm_v1(&event.sender, c).await - } else { - warn!( - sender = event.sender.as_str(), - algorithm = ?event.content.scheme, - "Error, unsupported encryption algorithm" - ); + match &event.content { + ToDeviceEncryptedEventContent::OlmV1Curve25519AesSha2(c) => { + self.decrypt_olm_v1(&event.sender, c).await + } + ToDeviceEncryptedEventContent::Unknown(_) => { + warn!( + sender = event.sender.as_str(), + algorithm = %event.content.algorithm(), + "Error decrypting an to-device event, unsupported \ + encryption algorithm" + ); - Err(EventError::UnsupportedAlgorithm.into()) + Err(EventError::UnsupportedAlgorithm.into()) + } } } @@ -258,10 +246,10 @@ impl Account { /// with the given sender. async fn decrypt_with_existing_sessions( &self, - sender_key: &str, + sender_key: Curve25519PublicKey, message: &OlmMessage, ) -> OlmResult> { - let s = self.store.get_sessions(sender_key).await?; + let s = self.store.get_sessions(&sender_key.to_base64()).await?; // We don't have any existing sessions, return early. let sessions = if let Some(s) = s { @@ -293,12 +281,12 @@ impl Account { async fn decrypt_olm_message( &self, sender: &UserId, - sender_key: &str, - message: OlmMessage, + sender_key: Curve25519PublicKey, + message: &OlmMessage, ) -> OlmResult<(SessionType, Raw, String)> { // First try to decrypt using an existing session. let (session, plaintext) = if let Some(d) = - self.decrypt_with_existing_sessions(sender_key, &message).await? + self.decrypt_with_existing_sessions(sender_key, message).await? { // Decryption succeeded, de-structure the session/plaintext out of // the Option. @@ -306,17 +294,17 @@ impl Account { } else { // Decryption failed with every known session, let's try to create a // new session. - match &message { + match message { // A new session can only be created using a pre-key message, // return with an error if it isn't one. OlmMessage::Normal(_) => { warn!( sender = sender.as_str(), - sender_key = sender_key, + sender_key = sender_key.to_base64(), "Failed to decrypt a non-pre-key message with all \ available sessions", ); - return Err(OlmError::SessionWedged(sender.to_owned(), sender_key.to_owned())); + return Err(OlmError::SessionWedged(sender.to_owned(), sender_key.to_base64())); } OlmMessage::PreKey(m) => { @@ -326,14 +314,14 @@ impl Account { Err(e) => { warn!( sender = sender.as_str(), - sender_key = sender_key, + sender_key = sender_key.to_base64(), error = ?e, "Failed to create a new Olm session from a \ prekey message", ); return Err(OlmError::SessionWedged( sender.to_owned(), - sender_key.to_owned(), + sender_key.to_base64(), )); } }; @@ -356,7 +344,7 @@ impl Account { trace!( sender = sender.as_str(), - sender_key = sender_key, + sender_key = sender_key.to_base64(), "Successfully decrypted an Olm message" ); @@ -382,7 +370,7 @@ impl Account { warn!( sender = sender.as_str(), - sender_key = sender_key, + sender_key = sender_key.to_base64(), error = ?e, "A to-device message was successfully decrypted but \ parsing and checking the event fields failed" @@ -577,9 +565,8 @@ impl ReadOnlyAccount { // so. if count != old_count { debug!( - "Updated uploaded one-time key count {} -> {}.", + "Updated uploaded one-time key count {} -> {count}.", self.uploaded_key_count(), - count ); } @@ -1017,10 +1004,9 @@ impl ReadOnlyAccount { /// account. pub async fn create_inbound_session( &self, - their_identity_key: &str, + their_identity_key: Curve25519PublicKey, message: &PreKeyMessage, ) -> Result { - let their_identity_key = Curve25519PublicKey::from_base64(their_identity_key)?; let result = self.inner.lock().await.create_inbound_session(their_identity_key, message)?; let now = SecondsSinceUnixEpoch::now(); @@ -1084,6 +1070,7 @@ impl ReadOnlyAccount { &signing_key, room_id, &outbound.session_key().await, + outbound.settings().algorithm.to_owned(), Some(visibility), ); @@ -1122,20 +1109,16 @@ impl ReadOnlyAccount { let message = our_session .encrypt(&device, AnyToDeviceEventContent::Dummy(ToDeviceDummyEventContent::new())) .await + .unwrap() + .deserialize() .unwrap(); - let content = if let EncryptedEventScheme::OlmV1Curve25519AesSha2(c) = message.scheme { + let content = if let ToDeviceEncryptedEventContent::OlmV1Curve25519AesSha2(c) = message { c } else { - panic!("Invalid encrypted event algorithm"); + panic!("Invalid encrypted event algorithm {}", message.algorithm()); }; - let own_ciphertext = - content.ciphertext.get(&other.identity_keys.curve25519.to_base64()).unwrap(); - let message_type: u8 = own_ciphertext.message_type.try_into().unwrap(); - - let message = OlmMessage::from_parts(message_type.into(), &own_ciphertext.body).unwrap(); - - let prekey = if let OlmMessage::PreKey(m) = message.clone() { + let prekey = if let OlmMessage::PreKey(m) = content.ciphertext { m } else { panic!("Wrong Olm message type"); @@ -1143,7 +1126,7 @@ impl ReadOnlyAccount { let our_device = ReadOnlyDevice::from_account(self).await; let other_session = other - .create_inbound_session(&our_device.curve25519_key().unwrap().to_base64(), &prekey) + .create_inbound_session(our_device.curve25519_key().unwrap(), &prekey) .await .unwrap(); diff --git a/crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs b/crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs index fa22f4a02..63666e202 100644 --- a/crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs +++ b/crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![warn(missing_docs)] - use std::{ collections::BTreeMap, fmt, @@ -27,11 +25,7 @@ use matrix_sdk_common::locks::Mutex; use ruma::{ events::{ forwarded_room_key::ToDeviceForwardedRoomKeyEventContent, - room::{ - encrypted::{EncryptedEventScheme, OriginalSyncRoomEncryptedEvent}, - history_visibility::HistoryVisibility, - }, - AnyRoomEvent, + room::history_visibility::HistoryVisibility, AnyRoomEvent, }, serde::Raw, DeviceKeyAlgorithm, EventEncryptionAlgorithm, OwnedRoomId, RoomId, @@ -41,13 +35,16 @@ use serde_json::Value; use vodozemac::{ megolm::{ DecryptedMessage, DecryptionError, ExportedSessionKey, InboundGroupSession as InnerSession, - InboundGroupSessionPickle, MegolmMessage, SessionKeyDecodeError, + InboundGroupSessionPickle, MegolmMessage, }, PickleError, }; -use super::{BackedUpRoomKey, ExportedRoomKey, SessionKey}; -use crate::error::{EventError, MegolmResult}; +use super::{BackedUpRoomKey, ExportedRoomKey, SessionCreationError, SessionKey}; +use crate::{ + error::{EventError, MegolmResult}, + types::events::room::encrypted::{EncryptedEvent, RoomEventEncryptionScheme}, +}; // TODO add creation times to the inbound group sessions so we can export // sessions that were created between some time period, this should only be set @@ -72,6 +69,7 @@ pub struct InboundGroupSession { pub room_id: Arc, forwarding_chains: Arc>, imported: bool, + algorithm: Arc, backed_up: Arc, } @@ -97,6 +95,7 @@ impl InboundGroupSession { signing_key: &str, room_id: &RoomId, session_key: &SessionKey, + encryption_algorithm: EventEncryptionAlgorithm, history_visibility: Option, ) -> Self { let session = InnerSession::new(session_key); @@ -108,14 +107,15 @@ impl InboundGroupSession { InboundGroupSession { inner: Arc::new(Mutex::new(session)), - session_id: session_id.into(), history_visibility: history_visibility.into(), - sender_key: sender_key.to_owned().into(), + session_id: session_id.into(), first_known_index, + sender_key: sender_key.to_owned().into(), signing_keys: keys.into(), room_id: room_id.into(), forwarding_chains: Vec::new().into(), imported: false, + algorithm: encryption_algorithm.into(), backed_up: AtomicBool::new(false).into(), } } @@ -127,16 +127,19 @@ impl InboundGroupSession { /// previous [`export()`] call. /// /// [`export()`]: #method.export - pub fn from_export(exported_session: ExportedRoomKey) -> Self { - Self::from(exported_session) + pub fn from_export(exported_session: &ExportedRoomKey) -> Result { + Self::try_from(exported_session) } #[allow(dead_code)] - fn from_backup(room_id: &RoomId, backup: BackedUpRoomKey) -> Self { + fn from_backup( + room_id: &RoomId, + backup: BackedUpRoomKey, + ) -> Result { let session = InnerSession::import(&backup.session_key); let session_id = session.session_id(); - Self::from_export(ExportedRoomKey { + Self::from_export(&ExportedRoomKey { algorithm: backup.algorithm, room_id: room_id.to_owned(), sender_key: backup.sender_key, @@ -159,10 +162,12 @@ impl InboundGroupSession { pub fn from_forwarded_key( sender_key: &str, content: &ToDeviceForwardedRoomKeyEventContent, - ) -> Result { + ) -> Result { let key = ExportedSessionKey::from_base64(&content.session_key)?; + let algorithm = EventEncryptionAlgorithm::from(content.algorithm.as_str()); let session = InnerSession::import(&key); + let first_known_index = session.first_known_index(); let mut forwarding_chains = content.forwarding_curve25519_key_chain.clone(); forwarding_chains.push(sender_key.to_owned()); @@ -182,6 +187,7 @@ impl InboundGroupSession { forwarding_chains: forwarding_chains.into(), imported: true, backed_up: AtomicBool::new(false).into(), + algorithm: algorithm.into(), }) } @@ -203,6 +209,7 @@ impl InboundGroupSession { imported: self.imported, backed_up: self.backed_up(), history_visibility: self.history_visibility.as_ref().clone(), + algorithm: (*self.algorithm).to_owned(), } } @@ -293,6 +300,7 @@ impl InboundGroupSession { room_id: (*pickle.room_id).into(), forwarding_chains: pickle.forwarding_chains.into(), backed_up: AtomicBool::from(pickle.backed_up).into(), + algorithm: pickle.algorithm.into(), imported: pickle.imported, }) } @@ -307,6 +315,12 @@ impl InboundGroupSession { &self.session_id } + /// The algorithm that this inbound group session is using to decrypt + /// events. + pub fn algorithm(&self) -> &EventEncryptionAlgorithm { + &self.algorithm + } + /// Get the first message index we know how to decrypt. pub fn first_known_index(&self) -> u32 { self.first_known_index @@ -339,18 +353,16 @@ impl InboundGroupSession { /// # Arguments /// /// * `event` - The event that should be decrypted. - pub async fn decrypt( - &self, - event: &OriginalSyncRoomEncryptedEvent, - ) -> MegolmResult<(Raw, u32)> { - let content = match &event.content.scheme { - EncryptedEventScheme::MegolmV1AesSha2(c) => c, - _ => return Err(EventError::UnsupportedAlgorithm.into()), + pub async fn decrypt(&self, event: &EncryptedEvent) -> MegolmResult<(Raw, u32)> { + let decrypted = match &event.content.scheme { + RoomEventEncryptionScheme::MegolmV1AesSha2(c) => { + self.decrypt_helper(&c.ciphertext).await? + } + RoomEventEncryptionScheme::Unknown(_) => { + return Err(EventError::UnsupportedAlgorithm.into()); + } }; - let message = MegolmMessage::from_base64(&content.ciphertext)?; - - let decrypted = self.decrypt_helper(&message).await?; let plaintext = String::from_utf8_lossy(&decrypted.plaintext); let mut decrypted_value = serde_json::from_str::(&plaintext)?; @@ -432,24 +444,34 @@ pub struct PickledInboundGroupSession { pub backed_up: bool, /// History visibility of the room when the session was created. pub history_visibility: Option, + /// The algorithm of this inbound group session. + #[serde(default = "default_algorithm")] + pub algorithm: EventEncryptionAlgorithm, } -impl From for InboundGroupSession { - fn from(key: ExportedRoomKey) -> Self { +fn default_algorithm() -> EventEncryptionAlgorithm { + EventEncryptionAlgorithm::MegolmV1AesSha2 +} + +impl TryFrom<&ExportedRoomKey> for InboundGroupSession { + type Error = SessionCreationError; + + fn try_from(key: &ExportedRoomKey) -> Result { let session = InnerSession::import(&key.session_key); let first_known_index = session.first_known_index(); - InboundGroupSession { + Ok(InboundGroupSession { inner: Mutex::new(session).into(), - session_id: key.session_id.into(), - sender_key: key.sender_key.into(), + session_id: key.session_id.to_owned().into(), + sender_key: key.sender_key.to_owned().into(), history_visibility: None.into(), first_known_index, - signing_keys: key.sender_claimed_keys.into(), - room_id: (*key.room_id).into(), - forwarding_chains: key.forwarding_curve25519_key_chain.into(), + signing_keys: key.sender_claimed_keys.to_owned().into(), + room_id: key.room_id.to_owned().into(), + forwarding_chains: key.forwarding_curve25519_key_chain.to_owned().into(), imported: true, + algorithm: key.algorithm.to_owned().into(), backed_up: AtomicBool::from(false).into(), - } + }) } } diff --git a/crates/matrix-sdk-crypto/src/olm/group_sessions/mod.rs b/crates/matrix-sdk-crypto/src/olm/group_sessions/mod.rs index d5b7b7543..9edf462b3 100644 --- a/crates/matrix-sdk-crypto/src/olm/group_sessions/mod.rs +++ b/crates/matrix-sdk-crypto/src/olm/group_sessions/mod.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::BTreeMap, convert::TryInto}; +use std::collections::BTreeMap; use ruma::{ events::forwarded_room_key::{ @@ -30,10 +30,22 @@ pub(crate) use outbound::ShareState; pub use outbound::{ EncryptionSettings, GroupSession, OutboundGroupSession, PickledOutboundGroupSession, ShareInfo, }; +use thiserror::Error; use vodozemac::megolm::SessionKeyDecodeError; pub use vodozemac::megolm::{ExportedSessionKey, SessionKey}; use zeroize::Zeroize; +/// An error type for the creation of group sessions. +#[derive(Debug, Error)] +pub enum SessionCreationError { + /// The provided algorithm is not supported. + #[error("The provided algorithm is not supported: {0}")] + Algorithm(EventEncryptionAlgorithm), + /// The room key key couldn't be decoded. + #[error(transparent)] + Decode(#[from] SessionKeyDecodeError), +} + /// An exported version of an `InboundGroupSession` /// /// This can be used to share the `InboundGroupSession` in an exported file. diff --git a/crates/matrix-sdk-crypto/src/olm/group_sessions/outbound.rs b/crates/matrix-sdk-crypto/src/olm/group_sessions/outbound.rs index 912356d87..5ac3af7ab 100644 --- a/crates/matrix-sdk-crypto/src/olm/group_sessions/outbound.rs +++ b/crates/matrix-sdk-crypto/src/olm/group_sessions/outbound.rs @@ -27,16 +27,11 @@ use dashmap::DashMap; use matrix_sdk_common::locks::Mutex; use ruma::{ events::{ - room::{ - encrypted::{ - EncryptedEventScheme, MegolmV1AesSha2ContentInit, RoomEncryptedEventContent, - }, - encryption::RoomEncryptionEventContent, - history_visibility::HistoryVisibility, - }, + room::{encryption::RoomEncryptionEventContent, history_visibility::HistoryVisibility}, room_key::ToDeviceRoomKeyEventContent, AnyToDeviceEventContent, }, + serde::Raw, DeviceId, EventEncryptionAlgorithm, OwnedDeviceId, OwnedTransactionId, OwnedUserId, RoomId, SecondsSinceUnixEpoch, TransactionId, UserId, }; @@ -50,7 +45,12 @@ pub use vodozemac::{ PickleError, }; -use crate::{Device, ToDeviceRequest}; +use crate::{ + types::events::room::encrypted::{ + MegolmV1AesSha2Content, RoomEncryptedEventContent, RoomEventEncryptionScheme, + }, + Device, ToDeviceRequest, +}; const ROTATION_PERIOD: Duration = Duration::from_millis(604800000); const ROTATION_MESSAGES: u64 = 100; @@ -272,7 +272,11 @@ impl OutboundGroupSession { /// # Panics /// /// Panics if the content can't be serialized. - pub async fn encrypt(&self, content: Value, event_type: &str) -> RoomEncryptedEventContent { + pub async fn encrypt( + &self, + content: Value, + event_type: &str, + ) -> Raw { let json_content = json!({ "content": content, "room_id": &*self.room_id, @@ -280,22 +284,21 @@ impl OutboundGroupSession { }); let plaintext = json_content.to_string(); - let relation = serde_json::from_value(content).ok(); + let relates_to = content.get("relates_to").cloned(); let ciphertext = self.encrypt_helper(plaintext).await; - let encrypted_content = MegolmV1AesSha2ContentInit { - ciphertext: ciphertext.to_base64(), - sender_key: self.account_identity_keys.curve25519.to_base64(), + let scheme: RoomEventEncryptionScheme = MegolmV1AesSha2Content { + ciphertext, + sender_key: self.account_identity_keys.curve25519, session_id: self.session_id().to_owned(), device_id: (*self.device_id).to_owned(), } .into(); - RoomEncryptedEventContent::new( - EncryptedEventScheme::MegolmV1AesSha2(encrypted_content), - relation, - ) + let content = RoomEncryptedEventContent { scheme, relates_to }; + + Raw::new(&content).expect("m.room.encrypted event content can always be serialized") } fn elapsed(&self) -> bool { diff --git a/crates/matrix-sdk-crypto/src/olm/mod.rs b/crates/matrix-sdk-crypto/src/olm/mod.rs index de4db40e1..387cbad03 100644 --- a/crates/matrix-sdk-crypto/src/olm/mod.rs +++ b/crates/matrix-sdk-crypto/src/olm/mod.rs @@ -28,7 +28,8 @@ pub use account::{OlmMessageHash, PickledAccount, ReadOnlyAccount}; pub(crate) use group_sessions::ShareState; pub use group_sessions::{ EncryptionSettings, ExportedRoomKey, InboundGroupSession, OutboundGroupSession, - PickledInboundGroupSession, PickledOutboundGroupSession, SessionKey, ShareInfo, + PickledInboundGroupSession, PickledOutboundGroupSession, SessionCreationError, SessionKey, + ShareInfo, }; pub use session::{PickledSession, Session}; pub use signing::{CrossSigningStatus, PickledCrossSigningIdentity, PrivateCrossSigningIdentity}; @@ -44,15 +45,17 @@ pub(crate) mod tests { events::{ forwarded_room_key::ToDeviceForwardedRoomKeyEventContent, room::message::{Relation, Replacement, RoomMessageEventContent}, - AnyMessageLikeEvent, AnyRoomEvent, AnySyncMessageLikeEvent, AnySyncRoomEvent, - MessageLikeEvent, SyncMessageLikeEvent, + AnyMessageLikeEvent, AnyRoomEvent, MessageLikeEvent, }, room_id, user_id, DeviceId, UserId, }; use serde_json::json; use vodozemac::olm::OlmMessage; - use crate::olm::{ExportedRoomKey, InboundGroupSession, ReadOnlyAccount, Session}; + use crate::{ + olm::{ExportedRoomKey, InboundGroupSession, ReadOnlyAccount, Session}, + utilities::json_convert, + }; fn alice_id() -> &'static UserId { user_id!("@alice:example.org") @@ -136,10 +139,8 @@ pub(crate) mod tests { }; let bob_keys = bob.identity_keys(); - let result = alice - .create_inbound_session(&bob_keys.curve25519.to_base64(), &prekey_message) - .await - .unwrap(); + let result = + alice.create_inbound_session(bob_keys.curve25519, &prekey_message).await.unwrap(); assert_eq!(bob_session.session_id(), result.session.session_id()); @@ -163,6 +164,7 @@ pub(crate) mod tests { "test_key", room_id, &outbound.session_key().await, + outbound.settings().algorithm.to_owned(), None, ); @@ -203,6 +205,7 @@ pub(crate) mod tests { "test_key", room_id, &outbound.session_key().await, + outbound.settings().algorithm.to_owned(), None, ); @@ -220,20 +223,9 @@ pub(crate) mod tests { "room_id": room_id, "type": "m.room.encrypted", "content": encrypted_content, - }) - .to_string(); - - let event: AnySyncRoomEvent = serde_json::from_str(&event).unwrap(); - - let event = if let AnySyncRoomEvent::MessageLike(AnySyncMessageLikeEvent::RoomEncrypted( - SyncMessageLikeEvent::Original(event), - )) = event - { - event - } else { - panic!("Invalid event type") - }; + }); + let event = json_convert(&event).unwrap(); let decrypted = inbound.decrypt(&event).await.unwrap().0; if let AnyRoomEvent::MessageLike(AnyMessageLikeEvent::RoomMessage( @@ -257,7 +249,8 @@ pub(crate) mod tests { let export: ToDeviceForwardedRoomKeyEventContent = export.try_into().unwrap(); let export = ExportedRoomKey::try_from(export).unwrap(); - let imported = InboundGroupSession::from_export(export); + let imported = InboundGroupSession::from_export(&export) + .expect("We can always import an inbound group session from a fresh export"); assert_eq!(inbound.session_id(), imported.session_id()); } diff --git a/crates/matrix-sdk-crypto/src/olm/session.rs b/crates/matrix-sdk-crypto/src/olm/session.rs index 6120ef5da..ae4d63bf2 100644 --- a/crates/matrix-sdk-crypto/src/olm/session.rs +++ b/crates/matrix-sdk-crypto/src/olm/session.rs @@ -12,17 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::BTreeMap, fmt, sync::Arc}; +use std::{fmt, sync::Arc}; use matrix_sdk_common::locks::Mutex; use ruma::{ - events::{ - room::encrypted::{ - CiphertextInfo, EncryptedEventScheme, OlmV1Curve25519AesSha2Content, - ToDeviceRoomEncryptedEventContent, - }, - AnyToDeviceEventContent, EventContent, - }, + events::{AnyToDeviceEventContent, EventContent}, + serde::Raw, DeviceId, SecondsSinceUnixEpoch, UserId, }; use serde::{Deserialize, Serialize}; @@ -35,6 +30,9 @@ use vodozemac::{ use super::IdentityKeys; use crate::{ error::{EventError, OlmResult}, + types::events::room::encrypted::{ + OlmV1Curve25519AesSha2Content, ToDeviceEncryptedEventContent, + }, ReadOnlyDevice, }; @@ -120,7 +118,7 @@ impl Session { &mut self, recipient_device: &ReadOnlyDevice, content: AnyToDeviceEventContent, - ) -> OlmResult { + ) -> OlmResult> { let recipient_signing_key = recipient_device.ed25519_key().ok_or(EventError::MissingSigningKey)?; @@ -141,19 +139,18 @@ impl Session { }); let plaintext = serde_json::to_string(&payload)?; - let ciphertext = self.encrypt_helper(&plaintext).await.to_parts(); + let ciphertext = self.encrypt_helper(&plaintext).await; - let message_type = ciphertext.0; - let ciphertext = CiphertextInfo::new(ciphertext.1, (message_type as u32).into()); + let content = OlmV1Curve25519AesSha2Content { + ciphertext, + recipient_key: self.sender_key, + sender_key: self.our_identity_keys.curve25519, + } + .into(); - let mut content = BTreeMap::new(); - content.insert(self.sender_key.to_base64(), ciphertext); + let content = Raw::new(&content).expect("A encrypted can always be serialized"); - Ok(EncryptedEventScheme::OlmV1Curve25519AesSha2(OlmV1Curve25519AesSha2Content::new( - content, - self.our_identity_keys.curve25519.to_base64(), - )) - .into()) + Ok(content) } /// Returns the unique identifier for this session. diff --git a/crates/matrix-sdk-crypto/src/olm/signing/mod.rs b/crates/matrix-sdk-crypto/src/olm/signing/mod.rs index 1b6ddb3ea..0d5028c43 100644 --- a/crates/matrix-sdk-crypto/src/olm/signing/mod.rs +++ b/crates/matrix-sdk-crypto/src/olm/signing/mod.rs @@ -715,15 +715,15 @@ mod tests { let master_key = identity.master_key.lock().await; let master_key = master_key.as_ref().unwrap(); - assert!(master_key + master_key .public_key - .verify_subkey(&identity.self_signing_key.lock().await.as_ref().unwrap().public_key,) - .is_ok()); + .verify_subkey(&identity.self_signing_key.lock().await.as_ref().unwrap().public_key) + .unwrap(); - assert!(master_key + master_key .public_key - .verify_subkey(&identity.user_signing_key.lock().await.as_ref().unwrap().public_key,) - .is_ok()); + .verify_subkey(&identity.user_signing_key.lock().await.as_ref().unwrap().public_key) + .unwrap(); } #[async_test] diff --git a/crates/matrix-sdk-crypto/src/olm/utility.rs b/crates/matrix-sdk-crypto/src/olm/utility.rs index dcca398d5..64c23b60c 100644 --- a/crates/matrix-sdk-crypto/src/olm/utility.rs +++ b/crates/matrix-sdk-crypto/src/olm/utility.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::convert::TryInto; - use ruma::{CanonicalJsonValue, DeviceKeyAlgorithm, DeviceKeyId, UserId}; use serde::Serialize; use serde_json::Value; diff --git a/crates/matrix-sdk-crypto/src/requests.rs b/crates/matrix-sdk-crypto/src/requests.rs index 67ca3f9b3..fa382bb52 100644 --- a/crates/matrix-sdk-crypto/src/requests.rs +++ b/crates/matrix-sdk-crypto/src/requests.rs @@ -78,9 +78,14 @@ impl ToDeviceRequest { pub(crate) fn new( recipient: &UserId, recipient_device: impl Into, - content: AnyToDeviceEventContent, + event_type: &str, + content: Raw, ) -> Self { - Self::with_id(recipient, recipient_device, content, TransactionId::new()) + let event_type = ToDeviceEventType::from(event_type); + let user_messages = iter::once((recipient_device.into(), content)).collect(); + let messages = iter::once((recipient.to_owned(), user_messages)).collect(); + + ToDeviceRequest { event_type, txn_id: TransactionId::new(), messages } } pub(crate) fn for_recipients( @@ -89,20 +94,24 @@ impl ToDeviceRequest { content: AnyToDeviceEventContent, txn_id: OwnedTransactionId, ) -> Self { + let event_type = content.event_type(); + let raw_content = Raw::new(&content).expect("Failed to serialize to-device event"); + if recipient_devices.is_empty() { - Self::new(recipient, DeviceIdOrAllDevices::AllDevices, content) + Self::new( + recipient, + DeviceIdOrAllDevices::AllDevices, + &event_type.to_string(), + raw_content, + ) } else { - let event_type = content.event_type(); let device_messages = recipient_devices .into_iter() - .map(|d| { - let raw_content = - Raw::new(&content).expect("Failed to serialize to-device event"); - (DeviceIdOrAllDevices::DeviceId(d), raw_content) - }) + .map(|d| (DeviceIdOrAllDevices::DeviceId(d), raw_content.clone())) .collect(); let messages = iter::once((recipient.to_owned(), device_messages)).collect(); + ToDeviceRequest { event_type, txn_id, messages } } } diff --git a/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs b/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs index 25f24799d..be1483bc5 100644 --- a/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs +++ b/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs @@ -22,10 +22,7 @@ use dashmap::DashMap; use futures_util::future::join_all; use matrix_sdk_common::executor::spawn; use ruma::{ - events::{ - room::{encrypted::RoomEncryptedEventContent, history_visibility::HistoryVisibility}, - AnyToDeviceEventContent, ToDeviceEventType, - }, + events::{AnyToDeviceEventContent, ToDeviceEventType}, serde::Raw, to_device::DeviceIdOrAllDevices, DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId, @@ -38,6 +35,7 @@ use crate::{ error::{EventError, MegolmResult, OlmResult}, olm::{Account, InboundGroupSession, OutboundGroupSession, Session, ShareInfo, ShareState}, store::{Changes, Result as StoreResult, Store}, + types::events::room::encrypted::RoomEncryptedEventContent, Device, EncryptionSettings, OlmError, ToDeviceRequest, }; @@ -160,7 +158,7 @@ impl GroupSessionManager { room_id: &RoomId, content: Value, event_type: &str, - ) -> MegolmResult { + ) -> MegolmResult> { let session = self.sessions.get(room_id).expect("Session wasn't created nor shared"); assert!(!session.expired(), "Session expired"); @@ -252,8 +250,7 @@ impl GroupSessionManager { .or_insert_with(BTreeMap::new) .insert( DeviceIdOrAllDevices::DeviceId(device.device_id().into()), - Raw::new(&AnyToDeviceEventContent::RoomEncrypted(encrypted)) - .expect("Failed to serialize encrypted event"), + encrypted.cast(), ); share_info .entry(device.user_id().to_owned()) @@ -323,7 +320,7 @@ impl GroupSessionManager { pub async fn collect_session_recipients( &self, users: impl Iterator, - history_visibility: HistoryVisibility, + settings: &EncryptionSettings, outbound: &OutboundGroupSession, ) -> OlmResult<(bool, HashMap>)> { let users: HashSet<&UserId> = users.collect(); @@ -331,7 +328,7 @@ impl GroupSessionManager { trace!( ?users, - ?history_visibility, + ?settings, session_id = outbound.session_id(), room_id = outbound.room_id().as_str(), "Calculating group session recipients" @@ -347,16 +344,19 @@ impl GroupSessionManager { // get the session but is in the set of users that received the session. let user_left = !users_shared_with.difference(&users).collect::>().is_empty(); - let visibility_changed = outbound.settings().history_visibility != history_visibility; + let visibility_changed = + outbound.settings().history_visibility != settings.history_visibility; + let algorithm_changed = outbound.settings().algorithm != settings.algorithm; // To protect the room history we need to rotate the session if either: // // 1. Any user left the room. // 2. Any of the users' devices got deleted or blacklisted. // 3. The history visibility changed. + // 4. The encryption algorithm changed. // // This is calculated in the following code and stored in this variable. - let mut should_rotate = user_left || visibility_changed; + let mut should_rotate = user_left || visibility_changed || algorithm_changed; for user_id in users { let user_devices = self.store.get_user_devices_filtered(user_id).await?; @@ -445,22 +445,27 @@ impl GroupSessionManager { users: impl Iterator, encryption_settings: impl Into, ) -> OlmResult>> { - trace!(room_id = room_id.as_str(), "Checking if a room key needs to be shared",); + trace!(room_id = room_id.as_str(), "Checking if a room key needs to be shared"); let encryption_settings = encryption_settings.into(); - let history_visibility = encryption_settings.history_visibility.clone(); let mut changes = Changes::default(); + // Try to get an existing session or create a new one. let (outbound, inbound) = self.get_or_create_outbound_session(room_id, encryption_settings.clone()).await?; + // Having an inbound group session here means that we created a new + // group session pair, which we then need to store. if let Some(inbound) = inbound { changes.outbound_group_sessions.push(outbound.clone()); changes.inbound_group_sessions.push(inbound); } + // Collect the recipient devices and check if either the settings + // or the recipient list changed in a way that requires the + // session to be rotated. let (should_rotate, devices) = - self.collect_session_recipients(users, history_visibility, &outbound).await?; + self.collect_session_recipients(users, &encryption_settings, &outbound).await?; let outbound = if should_rotate { let old_session_id = outbound.session_id(); @@ -475,7 +480,8 @@ impl GroupSessionManager { old_session_id = old_session_id, session_id = outbound.session_id(), "A user or device has left the room since we last sent a \ - message, rotating the room key.", + message, or the encryption settings have changed. Rotating the \ + room key.", ); outbound @@ -483,6 +489,8 @@ impl GroupSessionManager { outbound }; + // Filter out the devices that already received this room key or have a + // to-device message already queued up. let devices: Vec = devices .into_iter() .flat_map(|(_, d)| { @@ -494,12 +502,17 @@ impl GroupSessionManager { let key_content = outbound.as_content().await; let message_index = outbound.message_index().await; + // If we have some recipients, log them here. if !devices.is_empty() { let recipients = devices.iter().fold(BTreeMap::new(), |mut acc, d| { acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id()); acc }); + // If there are new recipients we need to persist the outbound group + // session as the to-device requests are persisted with the session. + changes.outbound_group_sessions = vec![outbound.clone()]; + info!( index = message_index, ?recipients, @@ -509,6 +522,10 @@ impl GroupSessionManager { ); } + // Chunk the recipients out so each to-device request will contain a + // limited amount of to-device messages. + // + // Create concurrent tasks for each chunk of recipients. let tasks: Vec<_> = devices .chunks(Self::MAX_TO_DEVICE_MESSAGES) .map(|chunk| { @@ -522,12 +539,19 @@ impl GroupSessionManager { }) .collect(); + // Wait for all the tasks to finish up and queue up the Olm session that + // was used to encrypt the room key to be persisted again. This is + // needed because each encryption step will mutate the Olm session, + // ratcheting its state forward. for result in join_all(tasks).await { let used_sessions: OlmResult> = result.expect("Encryption task panicked"); changes.sessions.extend(used_sessions?); } + // The to-device requests get added to the outbound group session, this + // way we're making sure that they are persisted and scoped to the + // session. let requests = outbound.pending_requests(); if requests.is_empty() { @@ -545,6 +569,7 @@ impl GroupSessionManager { let mut recipients: BTreeMap<&UserId, BTreeSet<&DeviceIdOrAllDevices>> = BTreeMap::new(); + // We're just collecting the recipients for logging reasons. for request in &requests { for (user_id, device_map) in &request.messages { let devices = device_map.keys(); @@ -565,6 +590,7 @@ impl GroupSessionManager { ); } + // Persist any changes we might have collected. if !changes.is_empty() { let session_count = changes.sessions.len(); @@ -584,19 +610,22 @@ impl GroupSessionManager { #[cfg(test)] mod tests { - use std::ops::Deref; + use std::{collections::HashSet, ops::Deref}; use matrix_sdk_test::{async_test, response_from_file}; use ruma::{ api::{ - client::keys::{claim_keys, get_keys}, + client::{ + keys::{claim_keys, get_keys}, + to_device::send_event_to_device::v3::Response as ToDeviceResponse, + }, IncomingResponse, }, device_id, events::room::history_visibility::HistoryVisibility, - room_id, user_id, DeviceId, TransactionId, UserId, + room_id, user_id, DeviceId, EventEncryptionAlgorithm, TransactionId, UserId, }; - use serde_json::Value; + use serde_json::{json, Value}; use crate::{EncryptionSettings, OlmMachine}; @@ -616,12 +645,67 @@ mod tests { .expect("Can't parse the keys upload response") } + fn bob_keys_query_response() -> get_keys::v3::Response { + let data = json!({ + "device_keys": { + "@bob:localhost": { + "BOBDEVICE": { + "user_id": "@bob:localhost", + "device_id": "BOBDEVICE", + "algorithms": [ + "m.olm.v1.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2", + "m.megolm.v2.aes-sha2" + ], + "keys": { + "curve25519:BOBDEVICE": "QzXDFZj0Pt5xG4r11XGSrqE4mnFOTgRM5pz7n3tzohU", + "ed25519:BOBDEVICE": "T7QMEXcEo/NfiC/8doVHT+2XnMm0pDpRa27bmE8PlPI" + }, + "signatures": { + "@bob:localhost": { + "ed25519:BOBDEVICE": "1Ee9J02KoVf4DKhT+LkurpZJEygiznqpgkT4lqvMTLtZyzShsVTnwmoMPttuGcJkLp9lMK1egveNYCEaYP80Cw" + } + } + } + } + } + }); + let data = response_from_file(&data); + + get_keys::v3::Response::try_from_http_response(data) + .expect("Can't parse the keys upload response") + } + + fn bob_one_time_key() -> claim_keys::v3::Response { + let data = json!({ + "failures": {}, + "one_time_keys":{ + "@bob:localhost":{ + "BOBDEVICE":{ + "signed_curve25519:AAAAAAAAAAA": { + "key":"bm1olfbksjC5SwKxCLLK4XaINCA0FwR/155J85gIpCk", + "signatures":{ + "@bob:localhost":{ + "ed25519:BOBDEVICE":"BKyS/+EV76zdZkWgny2D0svZ0ycS3etfyHCrsDgm7MYe166HqQmSoX29HsjGLvE/5F+Sg2zW7RJileUvquPwDA" + } + } + } + } + } + } + }); + let data = response_from_file(&data); + + claim_keys::v3::Response::try_from_http_response(data) + .expect("Can't parse the keys claim response") + } + fn keys_claim_response() -> claim_keys::v3::Response { let data = include_bytes!("../../../../benchmarks/benches/crypto_bench/keys_claim.json"); let data: Value = serde_json::from_slice(data).unwrap(); let data = response_from_file(&data); claim_keys::v3::Response::try_from_http_response(data) - .expect("Can't parse the keys upload response") + .expect("Can't parse the keys claim response") } async fn machine_with_user(user_id: &UserId, device_id: &DeviceId) -> OlmMachine { @@ -633,6 +717,8 @@ mod tests { machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap(); machine.mark_request_as_sent(&txn_id, &keys_claim).await.unwrap(); + machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap(); + machine.mark_request_as_sent(&txn_id, &bob_one_time_key()).await.unwrap(); machine } @@ -641,6 +727,31 @@ mod tests { machine_with_user(alice_id(), alice_device_id()).await } + async fn machine_with_shared_room_key() -> OlmMachine { + let machine = machine().await; + let room_id = room_id!("!test:localhost"); + let keys_claim = keys_claim_response(); + + let users = keys_claim.one_time_keys.keys().map(Deref::deref); + let requests = + machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap(); + + let outbound = machine.group_session_manager.get_outbound_group_session(room_id).unwrap(); + + assert!(!outbound.pending_requests().is_empty()); + assert!(!outbound.shared()); + + let response = ToDeviceResponse::new(); + for request in requests { + machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap(); + } + + assert!(outbound.shared()); + assert!(outbound.pending_requests().is_empty()); + + machine + } + #[async_test] async fn test_sharing() { let machine = machine().await; @@ -660,6 +771,73 @@ mod tests { assert_eq!(event_count, 148); } + #[async_test] + async fn ratcheted_sharing() { + let machine = machine_with_shared_room_key().await; + + let room_id = room_id!("!test:localhost"); + let late_joiner = user_id!("@bob:localhost"); + let keys_claim = keys_claim_response(); + + let mut users: HashSet<_> = keys_claim.one_time_keys.keys().map(Deref::deref).collect(); + users.insert(late_joiner); + + let requests = machine + .share_room_key(room_id, users.into_iter(), EncryptionSettings::default()) + .await + .unwrap(); + + let event_count: usize = requests.iter().map(|r| r.message_count()).sum(); + let outbound = machine.group_session_manager.get_outbound_group_session(room_id).unwrap(); + + assert_eq!(event_count, 1); + assert!(!outbound.pending_requests().is_empty()); + } + + #[async_test] + async fn changing_encryption_settings() { + let machine = machine_with_shared_room_key().await; + let room_id = room_id!("!test:localhost"); + let keys_claim = keys_claim_response(); + + let users = keys_claim.one_time_keys.keys().map(Deref::deref); + let outbound = machine.group_session_manager.get_outbound_group_session(room_id).unwrap(); + + let (should_rotate, _) = machine + .group_session_manager + .collect_session_recipients(users.clone(), &EncryptionSettings::default(), &outbound) + .await + .unwrap(); + + assert!(!should_rotate); + + let settings = EncryptionSettings { + history_visibility: HistoryVisibility::Invited, + ..Default::default() + }; + + let (should_rotate, _) = machine + .group_session_manager + .collect_session_recipients(users.clone(), &settings, &outbound) + .await + .unwrap(); + + assert!(should_rotate); + + let settings = EncryptionSettings { + algorithm: EventEncryptionAlgorithm::from("m.megolm.v2.aes-sha2"), + ..Default::default() + }; + + let (should_rotate, _) = machine + .group_session_manager + .collect_session_recipients(users, &settings, &outbound) + .await + .unwrap(); + + assert!(should_rotate); + } + #[async_test] async fn key_recipient_collecting() { // The user id comes from the fact that the keys_query.json file uses @@ -676,12 +854,13 @@ mod tests { .await .expect("We should be able to create a new session"); let history_visibility = HistoryVisibility::Joined; + let settings = EncryptionSettings { history_visibility, ..Default::default() }; let users = [user_id].into_iter(); let (_, recipients) = machine .group_session_manager - .collect_session_recipients(users, history_visibility, &outbound) + .collect_session_recipients(users, &settings, &outbound) .await .expect("We should be able to collect the session recipients"); diff --git a/crates/matrix-sdk-crypto/src/session_manager/sessions.rs b/crates/matrix-sdk-crypto/src/session_manager/sessions.rs index 5c275ac4d..87360f338 100644 --- a/crates/matrix-sdk-crypto/src/session_manager/sessions.rs +++ b/crates/matrix-sdk-crypto/src/session_manager/sessions.rs @@ -37,6 +37,7 @@ use crate::{ olm::Account, requests::{OutgoingRequest, ToDeviceRequest}, store::{Changes, Result as StoreResult, Store}, + types::events::EventType, ReadOnlyDevice, }; @@ -145,7 +146,8 @@ impl SessionManager { let request = ToDeviceRequest::new( device.user_id(), device.device_id().to_owned(), - AnyToDeviceEventContent::RoomEncrypted(content), + content.event_type(), + content.cast(), ); let request = OutgoingRequest { @@ -338,10 +340,7 @@ impl SessionManager { self.key_request_machine.retry_keyshare(user_id, device_id); if let Err(e) = self.check_if_unwedged(user_id, device_id).await { - error!( - "Error while treating an unwedged device {} {} {:?}", - user_id, device_id, e - ); + error!(%user_id, %device_id, "Error while treating an unwedged device: {e:?}"); } changes.sessions.push(session); diff --git a/crates/matrix-sdk-crypto/src/store/caches.rs b/crates/matrix-sdk-crypto/src/store/caches.rs index 271868007..e86f2992d 100644 --- a/crates/matrix-sdk-crypto/src/store/caches.rs +++ b/crates/matrix-sdk-crypto/src/store/caches.rs @@ -240,6 +240,7 @@ mod tests { "test_key", room_id, &outbound.session_key().await, + outbound.settings().algorithm.to_owned(), None, ); diff --git a/crates/matrix-sdk-crypto/src/store/integration_tests.rs b/crates/matrix-sdk-crypto/src/store/integration_tests.rs index 9007049dc..96ab40ebb 100644 --- a/crates/matrix-sdk-crypto/src/store/integration_tests.rs +++ b/crates/matrix-sdk-crypto/src/store/integration_tests.rs @@ -288,7 +288,7 @@ macro_rules! cryptostore_integration_tests { export.forwarding_curve25519_key_chain = vec!["some_chain".to_owned()]; - let session = InboundGroupSession::from_export(export); + let session = InboundGroupSession::from_export(&export).unwrap(); let changes = Changes { inbound_group_sessions: vec![session.clone()], ..Default::default() }; diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index a82874627..11ef4fe3c 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -350,6 +350,7 @@ mod tests { "test_key", room_id, &outbound.session_key().await, + outbound.settings().algorithm.to_owned(), None, ); diff --git a/crates/matrix-sdk-crypto/src/store/mod.rs b/crates/matrix-sdk-crypto/src/store/mod.rs index 574416cd4..075f18047 100644 --- a/crates/matrix-sdk-crypto/src/store/mod.rs +++ b/crates/matrix-sdk-crypto/src/store/mod.rs @@ -75,7 +75,7 @@ use crate::{ }, olm::{ InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity, - ReadOnlyAccount, Session, + ReadOnlyAccount, Session, SessionCreationError, }, utilities::encode, verification::VerificationMachine, @@ -603,7 +603,7 @@ pub enum CryptoStoreError { /// The received room key couldn't be converted into a valid Megolm session. #[error(transparent)] - SessionCreation(#[from] vodozemac::megolm::SessionKeyDecodeError), + SessionCreation(#[from] SessionCreationError), /// A Matrix identifier failed to be validated. #[error(transparent)] diff --git a/crates/matrix-sdk-crypto/src/types/events/mod.rs b/crates/matrix-sdk-crypto/src/types/events/mod.rs index a85fecef6..fed94e3e3 100644 --- a/crates/matrix-sdk-crypto/src/types/events/mod.rs +++ b/crates/matrix-sdk-crypto/src/types/events/mod.rs @@ -18,16 +18,30 @@ //! types. Once deserialized they aim to zeroize all the secret material once //! the type is dropped. +pub mod room; pub mod room_key; pub mod secret_send; mod to_device; +use ruma::serde::Raw; pub use to_device::{ToDeviceCustomEvent, ToDeviceEvent, ToDeviceEvents}; /// A trait for event contents to define their event type. pub trait EventType { + /// The event type of the event content. + const EVENT_TYPE: &'static str; + /// Get the event type of the event content. - fn event_type(&self) -> &str; + /// + /// **Note**: This should never be implemented manually, this takes the + /// event type from the constant. + fn event_type(&self) -> &'static str { + Self::EVENT_TYPE + } +} + +impl EventType for Raw { + const EVENT_TYPE: &'static str = T::EVENT_TYPE; } fn from_str<'a, T, E>(string: &'a str) -> Result diff --git a/crates/matrix-sdk-crypto/src/types/events/room/encrypted.rs b/crates/matrix-sdk-crypto/src/types/events/room/encrypted.rs new file mode 100644 index 000000000..e64a1d1cf --- /dev/null +++ b/crates/matrix-sdk-crypto/src/types/events/room/encrypted.rs @@ -0,0 +1,400 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Types for the `m.room.encrypted` room events. + +use std::collections::BTreeMap; + +use ruma::{DeviceId, EventEncryptionAlgorithm, OwnedDeviceId}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use vodozemac::{megolm::MegolmMessage, olm::OlmMessage, Curve25519PublicKey}; + +use super::Event; +use crate::types::{ + deserialize_curve_key, + events::{EventType, ToDeviceEvent}, + serialize_curve_key, +}; + +/// An m.room.encrypted room event. +pub type EncryptedEvent = Event; + +/// An m.room.encrypted to-device event. +pub type EncryptedToDeviceEvent = ToDeviceEvent; + +impl EncryptedToDeviceEvent { + /// Get the algorithm of the encrypted event content. + pub fn algorithm(&self) -> EventEncryptionAlgorithm { + self.content.algorithm() + } +} + +/// The content for `m.room.encrypted` to-device events. +#[derive(Clone, Debug, PartialEq, Eq, Deserialize)] +#[serde(try_from = "Helper")] +pub enum ToDeviceEncryptedEventContent { + /// The event content for events encrypted with the m.megolm.v1.aes-sha2 + /// algorithm. + OlmV1Curve25519AesSha2(Box), + /// An event content that was encrypted with an unknown encryption + /// algorithm. + Unknown(UnknownEncryptedContent), +} + +impl EventType for ToDeviceEncryptedEventContent { + const EVENT_TYPE: &'static str = "m.room.encrypted"; +} + +impl ToDeviceEncryptedEventContent { + /// Get the algorithm of the event content. + pub fn algorithm(&self) -> EventEncryptionAlgorithm { + match self { + ToDeviceEncryptedEventContent::OlmV1Curve25519AesSha2(_) => { + EventEncryptionAlgorithm::OlmV1Curve25519AesSha2 + } + ToDeviceEncryptedEventContent::Unknown(c) => c.algorithm.to_owned(), + } + } +} + +/// The event content for events encrypted with the m.olm.v1.curve25519-aes-sha2 +/// algorithm. +#[derive(Clone, Debug, PartialEq, Eq, Deserialize)] +#[serde(try_from = "OlmHelper")] +pub struct OlmV1Curve25519AesSha2Content { + /// The encrypted content of the event. + pub ciphertext: OlmMessage, + + /// The Curve25519 key of the recipient device. + pub recipient_key: Curve25519PublicKey, + + /// The Curve25519 key of the sender. + pub sender_key: Curve25519PublicKey, +} + +#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] +struct OlmHelper { + #[serde(deserialize_with = "deserialize_curve_key", serialize_with = "serialize_curve_key")] + sender_key: Curve25519PublicKey, + ciphertext: BTreeMap, +} + +impl Serialize for OlmV1Curve25519AesSha2Content { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let ciphertext = + BTreeMap::from([(self.recipient_key.to_base64(), self.ciphertext.clone())]); + + OlmHelper { sender_key: self.sender_key, ciphertext }.serialize(serializer) + } +} + +impl TryFrom for OlmV1Curve25519AesSha2Content { + type Error = serde_json::Error; + + fn try_from(value: OlmHelper) -> Result { + let (recipient_key, ciphertext) = value.ciphertext.into_iter().next().ok_or_else(|| { + serde::de::Error::custom( + "The `m.room.encrypted` event is missing a ciphertext".to_owned(), + ) + })?; + + let recipient_key = + Curve25519PublicKey::from_base64(&recipient_key).map_err(serde::de::Error::custom)?; + + Ok(Self { ciphertext, recipient_key, sender_key: value.sender_key }) + } +} + +/// The content for `m.room.encrypted` room events. +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct RoomEncryptedEventContent { + /// Algorithm-specific fields. + #[serde(flatten)] + pub scheme: RoomEventEncryptionScheme, + + /// Information about related events. + #[serde(flatten, skip_serializing_if = "Option::is_none")] + pub relates_to: Option, +} + +impl RoomEncryptedEventContent { + /// Get the algorithm of the event content. + pub fn algorithm(&self) -> EventEncryptionAlgorithm { + self.scheme.algorithm() + } +} + +impl EventType for RoomEncryptedEventContent { + const EVENT_TYPE: &'static str = "m.room.encrypted"; +} + +/// An enum for per encryption algorithm event contents. +#[derive(Clone, Debug, PartialEq, Eq, Deserialize)] +#[serde(try_from = "Helper")] +pub enum RoomEventEncryptionScheme { + /// The event content for events encrypted with the m.megolm.v1.aes-sha2 + /// algorithm. + MegolmV1AesSha2(MegolmV1AesSha2Content), + /// An event content that was encrypted with an unknown encryption + /// algorithm. + Unknown(UnknownEncryptedContent), +} + +impl RoomEventEncryptionScheme { + /// Get the algorithm of the event content. + pub fn algorithm(&self) -> EventEncryptionAlgorithm { + match self { + RoomEventEncryptionScheme::MegolmV1AesSha2(_) => { + EventEncryptionAlgorithm::MegolmV1AesSha2 + } + RoomEventEncryptionScheme::Unknown(c) => c.algorithm.to_owned(), + } + } +} + +pub(crate) enum SupportedEventEncryptionSchemes<'a> { + MegolmV1AesSha2(&'a MegolmV1AesSha2Content), +} + +impl SupportedEventEncryptionSchemes<'_> { + /// The Curve25519 key of the sender. + pub fn sender_key(&self) -> Curve25519PublicKey { + match self { + SupportedEventEncryptionSchemes::MegolmV1AesSha2(c) => c.sender_key, + } + } + + /// The ID of the session used to encrypt the message. + pub fn session_id(&self) -> &str { + match self { + SupportedEventEncryptionSchemes::MegolmV1AesSha2(c) => &c.session_id, + } + } + + /// The ID of the sending device. + pub fn device_id(&self) -> &DeviceId { + match self { + SupportedEventEncryptionSchemes::MegolmV1AesSha2(c) => &c.device_id, + } + } + + /// The algorithm that was used to encrypt the event content. + pub fn algorithm(&self) -> EventEncryptionAlgorithm { + match self { + SupportedEventEncryptionSchemes::MegolmV1AesSha2(_) => { + EventEncryptionAlgorithm::MegolmV1AesSha2 + } + } + } +} + +impl<'a> From<&'a MegolmV1AesSha2Content> for SupportedEventEncryptionSchemes<'a> { + fn from(c: &'a MegolmV1AesSha2Content) -> Self { + Self::MegolmV1AesSha2(c) + } +} + +/// The event content for events encrypted with the m.megolm.v1.aes-sha2 +/// algorithm. +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct MegolmV1AesSha2Content { + /// The encrypted content of the event. + pub ciphertext: MegolmMessage, + + /// The Curve25519 key of the sender. + #[serde(deserialize_with = "deserialize_curve_key", serialize_with = "serialize_curve_key")] + pub sender_key: Curve25519PublicKey, + + /// The ID of the sending device. + pub device_id: OwnedDeviceId, + + /// The ID of the session used to encrypt the message. + pub session_id: String, +} + +/// An unknown and unsupported `m.room.encrypted` event content. +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct UnknownEncryptedContent { + /// The algorithm that was used to encrypt the given event content. + pub algorithm: EventEncryptionAlgorithm, + /// The other data of the unknown encryped content. + #[serde(flatten)] + other: BTreeMap, +} + +#[derive(Debug, Deserialize, Serialize)] +struct Helper { + algorithm: EventEncryptionAlgorithm, + #[serde(flatten)] + other: Value, +} + +macro_rules! scheme_serialization { + ($something:ident, $($algorithm:ident => $content:ident),+ $(,)?) => { + $( + impl From<$content> for $something { + fn from(c: $content) -> Self { + Self::$algorithm(c.into()) + } + } + )+ + + impl TryFrom for $something { + type Error = serde_json::Error; + + fn try_from(value: Helper) -> Result { + Ok(match value.algorithm { + $( + EventEncryptionAlgorithm::$algorithm => { + let content: $content = serde_json::from_value(value.other)?; + content.into() + } + )+ + _ => Self::Unknown(UnknownEncryptedContent { + algorithm: value.algorithm, + other: serde_json::from_value(value.other)?, + }), + }) + } + } + + impl Serialize for $something { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let helper = match self { + $( + Self::$algorithm(r) => Helper { + algorithm: self.algorithm(), + other: serde_json::to_value(r).map_err(serde::ser::Error::custom)?, + }, + )+ + Self::Unknown(r) => Helper { + algorithm: r.algorithm.clone(), + other: serde_json::to_value(r.other.clone()).map_err(serde::ser::Error::custom)?, + }, + }; + + helper.serialize(serializer) + } + } + }; +} + +scheme_serialization!( + RoomEventEncryptionScheme, + MegolmV1AesSha2 => MegolmV1AesSha2Content, +); + +scheme_serialization!( + ToDeviceEncryptedEventContent, + OlmV1Curve25519AesSha2 => OlmV1Curve25519AesSha2Content, +); + +#[cfg(test)] +pub(crate) mod test { + use matches::assert_matches; + use serde_json::{json, Value}; + use vodozemac::Curve25519PublicKey; + + use super::{ + EncryptedEvent, EncryptedToDeviceEvent, OlmV1Curve25519AesSha2Content, + RoomEventEncryptionScheme, ToDeviceEncryptedEventContent, + }; + + pub fn json() -> Value { + json!({ + "sender": "@alice:example.org", + "event_id": "$Nhl3rsgHMjk-DjMJANawr9HHAhLg4GcoTYrSiYYGqEE", + "content": { + "m.custom": "something custom", + "algorithm": "m.megolm.v1.aes-sha2", + "device_id": "DEWRCMENGS", + "session_id": "ZFD6+OmV7fVCsJ7Gap8UnORH8EnmiAkes8FAvQuCw/I", + "sender_key": "WJ6Ce7U67a6jqkHYHd8o0+5H4bqdi9hInZdk0+swuXs", + "ciphertext": "AwgAEiBQs2LgBD2CcB+RLH2bsgp9VadFUJhBXOtCmcJuttBD\ + OeDNjL21d9z0AcVSfQFAh9huh4or7sWuNrHcvu9/sMbweTgc\ + 0UtdA5xFLheubHouXy4aewze+ShndWAaTbjWJMLsPSQDUMQH\ + BA" + }, + "type": "m.room.encrypted", + "origin_server_ts": 1632491098485u64, + "m.custom.top": "something custom in the top", + }) + } + + pub fn olm_v1_json() -> Value { + json!({ + "algorithm": "m.olm.v1.curve25519-aes-sha2", + "ciphertext": { + "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw": { + "body": "Awogv7Iysf062hV1gZNfG/SdO5TdLYtkRI12em6LxralPxoSIC\ + C/Avnha6NfkaMWSC+5h+khS0wHiUzA2bPmAvVo/iYhGiAfDNh4\ + F0eqPvOc4Hw9wMgd+frzedZgmhUNfKT0UzHQZSJPAwogF8fTdT\ + cPt1ppJ/KAEivFZ4dIyAlRUjzhlqzYsw9C1HoQACIgb9MK/a9T\ + RLtwol9gfy7OeKdpmSe39YhP+5OchhKvX6eO3/aED3X1oA", + "type": 0 + } + }, + "sender_key": "mjkTX0I0Cp44ZfolOVbFe5WYPRmT6AX3J0ZbnGWnnWs" + }) + } + + pub fn to_device_json() -> Value { + json!({ + "content": olm_v1_json(), + "sender": "@example:morpheus.localhost", + "type": "m.room.encrypted" + }) + } + + #[test] + fn deserialization() -> Result<(), serde_json::Error> { + let json = json(); + let event: EncryptedEvent = serde_json::from_value(json.clone())?; + + assert_matches!(event.content.scheme, RoomEventEncryptionScheme::MegolmV1AesSha2(_)); + let serialized = serde_json::to_value(event)?; + assert_eq!(json, serialized); + + let json = olm_v1_json(); + let content: OlmV1Curve25519AesSha2Content = serde_json::from_value(json)?; + + assert_eq!( + content.sender_key, + Curve25519PublicKey::from_base64("mjkTX0I0Cp44ZfolOVbFe5WYPRmT6AX3J0ZbnGWnnWs") + .unwrap() + ); + + assert_eq!( + content.recipient_key, + Curve25519PublicKey::from_base64("Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw") + .unwrap() + ); + + let json = to_device_json(); + let event: EncryptedToDeviceEvent = serde_json::from_value(json.clone())?; + + assert_matches!(event.content, ToDeviceEncryptedEventContent::OlmV1Curve25519AesSha2(_)); + let serialized = serde_json::to_value(event)?; + assert_eq!(json, serialized); + + Ok(()) + } +} diff --git a/crates/matrix-sdk-crypto/src/types/events/room/mod.rs b/crates/matrix-sdk-crypto/src/types/events/room/mod.rs new file mode 100644 index 000000000..ddb59fb97 --- /dev/null +++ b/crates/matrix-sdk-crypto/src/types/events/room/mod.rs @@ -0,0 +1,91 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Types for room events. + +use std::{collections::BTreeMap, fmt::Debug}; + +use ruma::{EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedUserId, UserId}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use super::EventType; + +pub mod encrypted; + +/// Generic room event with a known type and content. +#[derive(Debug, Deserialize)] +pub struct Event +where + C: EventType + Debug + Sized + Serialize, +{ + /// Contains the fully-qualified ID of the user who sent this event. + pub sender: OwnedUserId, + + /// The globally unique identifier for this event. + pub event_id: OwnedEventId, + + /// The body of this event, as created by the client which sent it. + pub content: C, + + /// Timestamp (in milliseconds since the unix epoch) on originating + /// homeserver when this event was sent. + pub origin_server_ts: MilliSecondsSinceUnixEpoch, + + /// Contains optional extra information about the event. + #[serde(default)] + pub unsigned: BTreeMap, + + /// Any other unknown data of the room event. + #[serde(flatten)] + other: BTreeMap, +} + +impl Serialize for Event +where + C: EventType + Debug + Sized + Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + #[derive(Serialize)] + struct Helper<'a, C> { + sender: &'a UserId, + event_id: &'a EventId, + #[serde(rename = "type")] + event_type: &'a str, + content: &'a C, + origin_server_ts: MilliSecondsSinceUnixEpoch, + #[serde(skip_serializing_if = "BTreeMap::is_empty")] + unsigned: &'a BTreeMap, + #[serde(flatten)] + other: &'a BTreeMap, + } + + let event_type = C::EVENT_TYPE; + + let helper = Helper { + sender: &self.sender, + content: &self.content, + event_type, + other: &self.other, + event_id: &self.event_id, + origin_server_ts: self.origin_server_ts, + unsigned: &self.unsigned, + }; + + helper.serialize(serializer) + } +} diff --git a/crates/matrix-sdk-crypto/src/types/events/room_key.rs b/crates/matrix-sdk-crypto/src/types/events/room_key.rs index 5ce6ecc2b..4be9b1d38 100644 --- a/crates/matrix-sdk-crypto/src/types/events/room_key.rs +++ b/crates/matrix-sdk-crypto/src/types/events/room_key.rs @@ -26,12 +26,17 @@ use super::{EventType, ToDeviceEvent}; /// The `m.room_key` to-device event. pub type RoomKeyEvent = ToDeviceEvent; -impl EventType for RoomKeyContent { - fn event_type(&self) -> &str { - "m.room_key" +impl RoomKeyEvent { + /// Get the algorithm of the room key. + pub fn algorithm(&self) -> EventEncryptionAlgorithm { + self.content.algorithm() } } +impl EventType for RoomKeyContent { + const EVENT_TYPE: &'static str = "m.room_key"; +} + /// The `m.room_key` event content. /// /// This is an enum over the different room key algorithms we support. @@ -49,6 +54,14 @@ pub enum RoomKeyContent { } impl RoomKeyContent { + /// Get the algorithm of the room key. + pub fn algorithm(&self) -> EventEncryptionAlgorithm { + match &self { + RoomKeyContent::MegolmV1AesSha2(_) => EventEncryptionAlgorithm::MegolmV1AesSha2, + RoomKeyContent::Unknown(c) => c.algorithm.to_owned(), + } + } + pub(super) fn serialize_zeroized(&self) -> Result, serde_json::Error> { #[derive(Serialize)] struct Helper<'a> { @@ -69,7 +82,7 @@ impl RoomKeyContent { }; let helper = RoomKeyHelper { - algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2, + algorithm: self.algorithm(), other: serde_json::to_value(helper)?, }; diff --git a/crates/matrix-sdk-crypto/src/types/events/secret_send.rs b/crates/matrix-sdk-crypto/src/types/events/secret_send.rs index d288fe09d..649c9655e 100644 --- a/crates/matrix-sdk-crypto/src/types/events/secret_send.rs +++ b/crates/matrix-sdk-crypto/src/types/events/secret_send.rs @@ -69,9 +69,7 @@ impl std::fmt::Debug for SecretSendContent { } impl EventType for SecretSendContent { - fn event_type(&self) -> &str { - "m.secret.send" - } + const EVENT_TYPE: &'static str = "m.secret.send"; } #[cfg(test)] diff --git a/crates/matrix-sdk-crypto/src/types/events/to_device.rs b/crates/matrix-sdk-crypto/src/types/events/to_device.rs index 8918f2bc9..e4c097097 100644 --- a/crates/matrix-sdk-crypto/src/types/events/to_device.rs +++ b/crates/matrix-sdk-crypto/src/types/events/to_device.rs @@ -24,7 +24,6 @@ use ruma::{ mac::ToDeviceKeyVerificationMacEvent, ready::ToDeviceKeyVerificationReadyEvent, request::ToDeviceKeyVerificationRequestEvent, start::ToDeviceKeyVerificationStartEvent, }, - room::encrypted::ToDeviceRoomEncryptedEvent, room_key_request::ToDeviceRoomKeyRequestEvent, secret::request::{SecretName, ToDeviceSecretRequestEvent}, EventContent, ToDeviceEventType, @@ -39,7 +38,10 @@ use serde_json::{ }; use zeroize::Zeroize; -use super::{room_key::RoomKeyEvent, secret_send::SecretSendEvent, EventType}; +use super::{ + room::encrypted::EncryptedToDeviceEvent, room_key::RoomKeyEvent, secret_send::SecretSendEvent, + EventType, +}; use crate::types::events::from_str; /// An enum over the various to-device events we support. @@ -68,7 +70,7 @@ pub enum ToDeviceEvents { KeyVerificationRequest(ToDeviceKeyVerificationRequestEvent), /// The `m.room.encrypted` to-device event. - RoomEncrypted(ToDeviceRoomEncryptedEvent), + RoomEncrypted(EncryptedToDeviceEvent), /// The `m.room_key` to-device event. RoomKey(RoomKeyEvent), /// The `m.room_key_request` to-device event. @@ -122,7 +124,7 @@ impl ToDeviceEvents { ToDeviceEvents::KeyVerificationReady(e) => e.content.event_type(), ToDeviceEvents::KeyVerificationRequest(e) => e.content.event_type(), - ToDeviceEvents::RoomEncrypted(e) => e.content.event_type(), + ToDeviceEvents::RoomEncrypted(_) => ToDeviceEventType::RoomEncrypted, ToDeviceEvents::RoomKey(_) => ToDeviceEventType::RoomKey, ToDeviceEvents::RoomKeyRequest(e) => e.content.event_type(), ToDeviceEvents::ForwardedRoomKey(e) => e.content.event_type(), @@ -241,7 +243,7 @@ where pub content: C, /// Any other unknown data of the to-device event. #[serde(flatten)] - other: BTreeMap, + pub(crate) other: BTreeMap, } impl Serialize for ToDeviceEvent @@ -396,23 +398,6 @@ mod test { }) } - fn room_encrypted_event() -> Value { - json!({ - "sender": "@alice:example.org", - "content": { - "algorithm": "m.olm.v1.curve25519-aes-sha2", - "sender_key": "", - "ciphertext": { - "": { - "type": 0, - "body": "" - } - } - }, - "type": "m.room.encrypted", - }) - } - fn forwarded_room_key_event() -> Value { json!({ "sender": "@alice:example.org", @@ -490,7 +475,7 @@ mod test { dummy_event => Dummy, // `m.room.encrypted` - room_encrypted_event => RoomEncrypted, + crate::types::events::room::encrypted::test::to_device_json => RoomEncrypted, ); Ok(()) diff --git a/crates/matrix-sdk-crypto/src/utilities.rs b/crates/matrix-sdk-crypto/src/utilities.rs index 4c2cc2a6f..b8f2804e3 100644 --- a/crates/matrix-sdk-crypto/src/utilities.rs +++ b/crates/matrix-sdk-crypto/src/utilities.rs @@ -24,3 +24,13 @@ pub fn decode(input: impl AsRef<[u8]>) -> Result, DecodeError> { pub fn encode(input: impl AsRef<[u8]>) -> String { encode_config(input, STANDARD_NO_PAD) } + +#[cfg(test)] +pub(crate) fn json_convert(value: &T) -> serde_json::Result +where + T: serde::Serialize, + U: serde::de::DeserializeOwned, +{ + let json = serde_json::to_string(value)?; + serde_json::from_str(&json) +} diff --git a/crates/matrix-sdk-crypto/src/verification/cache.rs b/crates/matrix-sdk-crypto/src/verification/cache.rs index 0368a5213..bb674b4e4 100644 --- a/crates/matrix-sdk-crypto/src/verification/cache.rs +++ b/crates/matrix-sdk-crypto/src/verification/cache.rs @@ -169,7 +169,12 @@ impl VerificationCache { ) { match content { OutgoingContent::ToDevice(c) => { - let request = ToDeviceRequest::new(recipient, recipient_device.to_owned(), c); + let request = ToDeviceRequest::with_id( + recipient, + recipient_device.to_owned(), + c, + TransactionId::new(), + ); let request_id = request.txn_id.clone(); let request = OutgoingRequest { diff --git a/crates/matrix-sdk-crypto/src/verification/event_enums.rs b/crates/matrix-sdk-crypto/src/verification/event_enums.rs index 340050c9f..aa4bfdc3a 100644 --- a/crates/matrix-sdk-crypto/src/verification/event_enums.rs +++ b/crates/matrix-sdk-crypto/src/verification/event_enums.rs @@ -753,7 +753,7 @@ impl TryFrom for OutgoingContent { serde_json::from_value(json).map_err(|e| e.to_string())?, ) } - e => return Err(format!("Unsupported event type {}", e)), + e => return Err(format!("Unsupported event type {e}")), }; Ok(content.into()) diff --git a/crates/matrix-sdk-crypto/src/verification/machine.rs b/crates/matrix-sdk-crypto/src/verification/machine.rs index b5def1845..b1ec9b5dd 100644 --- a/crates/matrix-sdk-crypto/src/verification/machine.rs +++ b/crates/matrix-sdk-crypto/src/verification/machine.rs @@ -132,8 +132,12 @@ impl VerificationMachine { RoomMessageRequest { room_id: r, txn_id: TransactionId::new(), content: c }.into() } OutgoingContent::ToDevice(c) => { - let request = - ToDeviceRequest::new(device.user_id(), device.device_id().to_owned(), c); + let request = ToDeviceRequest::with_id( + device.user_id(), + device.device_id().to_owned(), + c, + TransactionId::new(), + ); self.verifications.insert_sas(sas.clone()); @@ -506,7 +510,7 @@ impl VerificationMachine { #[cfg(test)] mod tests { - use std::{convert::TryFrom, sync::Arc, time::Duration}; + use std::{sync::Arc, time::Duration}; use matrix_sdk_common::{instant::Instant, locks::Mutex}; use matrix_sdk_test::async_test; diff --git a/crates/matrix-sdk-crypto/src/verification/mod.rs b/crates/matrix-sdk-crypto/src/verification/mod.rs index 6a296a3ec..29c411562 100644 --- a/crates/matrix-sdk-crypto/src/verification/mod.rs +++ b/crates/matrix-sdk-crypto/src/verification/mod.rs @@ -501,10 +501,9 @@ impl IdentitiesBeingVerified { } Err(e) => { error!( - "Error signing device keys for {} {} {:?}", - device.user_id(), - device.device_id(), - e + user_id = %device.user_id(), + device_id = %device.device_id(), + "Error signing device keys: {e:?}", ); None } @@ -527,17 +526,16 @@ impl IdentitiesBeingVerified { Ok(r) => Some(r), Err(SignatureError::MissingSigningKey) => { warn!( - "Can't sign the public cross signing keys for {}, \ - no private user signing key found", - i.user_id() + user_id = %i.user_id(), + "Can't sign the public cross signing keys, \ + no private user signing key found", ); None } Err(e) => { error!( - "Error signing the public cross signing keys for {} {:?}", - i.user_id(), - e + user_id = %i.user_id(), + "Error signing the public cross signing keys: {e:?}", ); None } @@ -707,7 +705,6 @@ impl IdentitiesBeingVerified { #[cfg(test)] pub(crate) mod tests { - use std::convert::TryInto; use ruma::{ events::{AnyToDeviceEventContent, ToDeviceEvent}, diff --git a/crates/matrix-sdk-crypto/src/verification/qrcode.rs b/crates/matrix-sdk-crypto/src/verification/qrcode.rs index 8969df1fa..ee01d5147 100644 --- a/crates/matrix-sdk-crypto/src/verification/qrcode.rs +++ b/crates/matrix-sdk-crypto/src/verification/qrcode.rs @@ -291,10 +291,11 @@ impl QrVerification { OutgoingContent::Room(room_id, content) => { RoomMessageRequest { room_id, txn_id: TransactionId::new(), content }.into() } - OutgoingContent::ToDevice(c) => ToDeviceRequest::new( + OutgoingContent::ToDevice(c) => ToDeviceRequest::with_id( self.identities.other_user_id(), self.identities.other_device_id().to_owned(), c, + TransactionId::new(), ) .into(), } @@ -787,7 +788,7 @@ impl QrState { #[cfg(test)] mod tests { - use std::{convert::TryFrom, sync::Arc}; + use std::sync::Arc; use matrix_sdk_common::locks::Mutex; use matrix_sdk_qrcode::QrVerificationData; diff --git a/crates/matrix-sdk-crypto/src/verification/requests.rs b/crates/matrix-sdk-crypto/src/verification/requests.rs index ceb248a67..fd33342a0 100644 --- a/crates/matrix-sdk-crypto/src/verification/requests.rs +++ b/crates/matrix-sdk-crypto/src/verification/requests.rs @@ -380,9 +380,13 @@ impl VerificationRequest { let mut inner = self.inner.lock().unwrap(); inner.accept(methods).map(|c| match c { - OutgoingContent::ToDevice(content) => { - ToDeviceRequest::new(self.other_user(), inner.other_device_id(), content).into() - } + OutgoingContent::ToDevice(content) => ToDeviceRequest::with_id( + self.other_user(), + inner.other_device_id(), + content, + TransactionId::new(), + ) + .into(), OutgoingContent::Room(room_id, content) => { RoomMessageRequest { room_id, txn_id: TransactionId::new(), content }.into() } @@ -435,7 +439,13 @@ impl VerificationRequest { ) .into() } else { - ToDeviceRequest::new(self.other_user(), other_device, content).into() + ToDeviceRequest::with_id( + self.other_user(), + other_device, + content, + TransactionId::new(), + ) + .into() } } OutgoingContent::Room(room_id, content) => { @@ -627,10 +637,11 @@ impl VerificationRequest { self.verification_cache.insert_sas(sas.clone()); let request = match content { - OutgoingContent::ToDevice(content) => ToDeviceRequest::new( + OutgoingContent::ToDevice(content) => ToDeviceRequest::with_id( self.other_user(), inner.other_device_id(), content, + TransactionId::new(), ) .into(), OutgoingContent::Room(room_id, content) => { diff --git a/crates/matrix-sdk-crypto/src/verification/sas/helpers.rs b/crates/matrix-sdk-crypto/src/verification/sas/helpers.rs index 3ce340747..da617e73a 100644 --- a/crates/matrix-sdk-crypto/src/verification/sas/helpers.rs +++ b/crates/matrix-sdk-crypto/src/verification/sas/helpers.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::BTreeMap, convert::TryInto}; +use std::collections::BTreeMap; use ruma::{ events::{ @@ -197,18 +197,17 @@ pub fn receive_mac_event( let info = extra_mac_info_receive(ids, flow_id); trace!( - "Received a key.verification.mac event from {} {}", - sender, - ids.other_device.device_id() + %sender, + device_id = %ids.other_device.device_id(), + "Received a key.verification.mac event" ); let mut keys = content.mac().keys().map(|k| k.as_str()).collect::>(); keys.sort_unstable(); - let keys = Base64::parse( - sas.calculate_mac_invalid_base64(&keys.join(","), &format!("{}KEY_IDS", &info)), - ) - .expect("Can't base64-decode SAS MAC"); + let keys = + Base64::parse(sas.calculate_mac_invalid_base64(&keys.join(","), &format!("{info}KEY_IDS"))) + .expect("Can't base64-decode SAS MAC"); if keys != *content.keys() { return Err(CancelCode::KeyMismatch); @@ -216,10 +215,10 @@ pub fn receive_mac_event( for (key_id, key_mac) in content.mac() { trace!( - "Checking MAC for the key id {} from {} {}", + %sender, + device_id = %ids.other_device.device_id(), key_id, - sender, - ids.other_device.device_id() + "Checking a SAS MAC", ); let key_id: OwnedDeviceKeyId = match key_id.as_str().try_into() { @@ -234,7 +233,7 @@ pub fn receive_mac_event( .expect("Can't base64-decode SAS MAC"); if *key_mac == calculated_mac { - trace!("Successfully verified the device key {} from {}", key_id, sender); + trace!(%sender, %key_id, "Successfully verified a device key"); verified_devices.push(ids.other_device.clone()); } else { return Err(CancelCode::KeyMismatch); @@ -243,14 +242,13 @@ pub fn receive_mac_event( if let Some(key) = identity.master_key().get_key(&key_id) { // TODO we should check that the master key signs the device, // this way we know the master key also trusts the device - let calculated_mac = Base64::parse(sas.calculate_mac_invalid_base64( - &key.to_base64(), - &format!("{}{}", info, key_id), - )) + let calculated_mac = Base64::parse( + sas.calculate_mac_invalid_base64(&key.to_base64(), &format!("{info}{key_id}")), + ) .expect("Can't base64-decode SAS MAC"); if *key_mac == calculated_mac { - trace!("Successfully verified the master key {} from {}", key_id, sender); + trace!(%sender, %key_id, "Successfully verified a master key"); verified_identities.push(identity.clone()) } else { return Err(CancelCode::KeyMismatch); @@ -258,10 +256,8 @@ pub fn receive_mac_event( } } else { warn!( - "Key ID {} in MAC event from {} {} doesn't belong to any device \ + "Key ID {key_id} in MAC event from {sender} {} doesn't belong to any device \ or user identity", - key_id, - sender, ids.other_device.device_id() ); } @@ -312,7 +308,7 @@ pub fn get_mac_content(sas: &EstablishedSas, ids: &SasIds, flow_id: &FlowId) -> mac.insert( key_id.to_string(), - Base64::parse(sas.calculate_mac_invalid_base64(&key, &format!("{}{}", info, key_id))) + Base64::parse(sas.calculate_mac_invalid_base64(&key, &format!("{info}{key_id}"))) .expect("Can't base64-decode SAS MAC"), ); @@ -321,10 +317,9 @@ pub fn get_mac_content(sas: &EstablishedSas, ids: &SasIds, flow_id: &FlowId) -> if let Some(key) = own_identity.master_key().get_first_key() { let key_id = format!("{}:{}", DeviceKeyAlgorithm::Ed25519, key.to_base64()); - let calculated_mac = Base64::parse(sas.calculate_mac_invalid_base64( - &key.to_base64(), - &format!("{}{}", info, &key_id), - )) + let calculated_mac = Base64::parse( + sas.calculate_mac_invalid_base64(&key.to_base64(), &format!("{info}{key_id}")), + ) .expect("Can't base64-decode SAS Master key MAC"); mac.insert(key_id, calculated_mac); diff --git a/crates/matrix-sdk-crypto/src/verification/sas/mod.rs b/crates/matrix-sdk-crypto/src/verification/sas/mod.rs index 12aacdb2e..ab802967b 100644 --- a/crates/matrix-sdk-crypto/src/verification/sas/mod.rs +++ b/crates/matrix-sdk-crypto/src/verification/sas/mod.rs @@ -473,7 +473,12 @@ impl Sas { } pub(crate) fn content_to_request(&self, content: AnyToDeviceEventContent) -> ToDeviceRequest { - ToDeviceRequest::new(self.other_user_id(), self.other_device_id().to_owned(), content) + ToDeviceRequest::with_id( + self.other_user_id(), + self.other_device_id().to_owned(), + content, + TransactionId::new(), + ) } } @@ -508,7 +513,7 @@ impl AcceptSettings { #[cfg(test)] mod tests { - use std::{convert::TryFrom, sync::Arc}; + use std::sync::Arc; use matrix_sdk_common::locks::Mutex; use matrix_sdk_test::async_test; diff --git a/crates/matrix-sdk-crypto/src/verification/sas/sas_state.rs b/crates/matrix-sdk-crypto/src/verification/sas/sas_state.rs index e9133eb50..6c8f5bdfe 100644 --- a/crates/matrix-sdk-crypto/src/verification/sas/sas_state.rs +++ b/crates/matrix-sdk-crypto/src/verification/sas/sas_state.rs @@ -13,7 +13,6 @@ // limitations under the License. use std::{ - convert::{TryFrom, TryInto}, matches, sync::{Arc, Mutex}, time::Duration, @@ -546,10 +545,10 @@ impl SasState { let commitment = calculate_commitment(our_public_key, content); info!( - "Calculated commitment for pubkey {} and content {:?} {}", - our_public_key.to_base64(), - content, - commitment + public_key = our_public_key.to_base64(), + %commitment, + ?content, + "Calculated SAS commitment", ); if let Ok(accepted_protocols) = AcceptedProtocols::try_from(method_content) { @@ -1224,8 +1223,6 @@ impl SasState { #[cfg(test)] mod tests { - use std::convert::TryFrom; - use matrix_sdk_test::async_test; use ruma::{ device_id, diff --git a/crates/matrix-sdk-indexeddb/Cargo.toml b/crates/matrix-sdk-indexeddb/Cargo.toml index 79e659f95..ba4a1455a 100644 --- a/crates/matrix-sdk-indexeddb/Cargo.toml +++ b/crates/matrix-sdk-indexeddb/Cargo.toml @@ -15,17 +15,19 @@ rustdoc-args = ["--cfg", "docsrs"] [features] default = ["e2e-encryption"] -e2e-encryption = ["matrix-sdk-base/e2e-encryption", "dep:matrix-sdk-crypto"] -experimental-timeline = ["matrix-sdk-base/experimental-timeline"] +e2e-encryption = ["matrix-sdk-base/e2e-encryption", "dep:matrix-sdk-crypto", "dashmap"] +experimental-timeline = ["matrix-sdk-base/experimental-timeline", "dep:futures-util"] [dependencies] anyhow = "1.0.57" async-trait = "0.1.53" base64 = "0.13.0" -dashmap = "5.2.0" -futures-util = { version = "0.3.21", default-features = false } +dashmap = { version = "5.2.0", optional = true } +derive_builder = "0.11.2" +futures-util = { version = " 0.3.21", default-features = false, features = ["alloc"], optional = true } indexed_db_futures = "0.2.3" +js-sys = { version = "0.3.58" } matrix-sdk-base = { version = "0.5.0", path = "../matrix-sdk-base" } matrix-sdk-crypto = { version = "0.5.0", path = "../matrix-sdk-crypto", optional = true } matrix-sdk-store-encryption = { version = "0.1.0", path = "../matrix-sdk-store-encryption" } diff --git a/crates/matrix-sdk-indexeddb/src/lib.rs b/crates/matrix-sdk-indexeddb/src/lib.rs index a11e863e4..ffb31de69 100644 --- a/crates/matrix-sdk-indexeddb/src/lib.rs +++ b/crates/matrix-sdk-indexeddb/src/lib.rs @@ -19,7 +19,7 @@ pub use cryptostore::IndexeddbStore as CryptoStore; #[cfg(feature = "e2e-encryption")] use cryptostore::IndexeddbStoreError; #[cfg(target_arch = "wasm32")] -pub use state_store::IndexeddbStore as StateStore; +pub use state_store::{IndexeddbStore as StateStore, IndexeddbStoreBuilder as StateStoreBuilder}; #[cfg(target_arch = "wasm32")] #[cfg(feature = "e2e-encryption")] @@ -30,17 +30,18 @@ async fn open_stores_with_name( passphrase: Option<&str>, ) -> Result<(StateStore, CryptoStore), OpenStoreError> { let name = name.into(); + let mut builder = StateStore::builder(); + builder.name(name.clone()); if let Some(passphrase) = passphrase { - let state_store = StateStore::open_with_passphrase(name.clone(), passphrase).await?; - let crypto_store = - CryptoStore::open_with_store_cipher(name, state_store.store_cipher.clone()).await?; - Ok((state_store, crypto_store)) - } else { - let state_store = StateStore::open_with_name(name.clone()).await?; - let crypto_store = CryptoStore::open_with_name(name).await?; - Ok((state_store, crypto_store)) + builder.passphrase(passphrase.to_owned()); } + + let state_store = builder.build().await.map_err(StoreError::from)?; + let crypto_store = + CryptoStore::open_with_store_cipher(name, state_store.store_cipher.clone()).await?; + + Ok((state_store, crypto_store)) } #[cfg(target_arch = "wasm32")] @@ -61,11 +62,14 @@ pub async fn make_store_config( #[cfg(not(feature = "e2e-encryption"))] { - let state_store = if let Some(passphrase) = passphrase { - StateStore::open_with_passphrase(name, passphrase).await? - } else { - StateStore::open_with_name(name).await? - }; + let mut builder = StateStore::builder(); + builder.name(name.clone()); + + if let Some(passphrase) = passphrase { + builder.passphrase(passphrase.to_owned()); + } + + let state_store = builder.build().await.map_err(StoreError::from)?; Ok(StoreConfig::new().state_store(state_store)) } diff --git a/crates/matrix-sdk-indexeddb/src/safe_encode.rs b/crates/matrix-sdk-indexeddb/src/safe_encode.rs index 081c1a35a..55bfe4df4 100644 --- a/crates/matrix-sdk-indexeddb/src/safe_encode.rs +++ b/crates/matrix-sdk-indexeddb/src/safe_encode.rs @@ -72,8 +72,7 @@ pub trait SafeEncode { store_cipher: &StoreCipher, i: usize, ) -> JsValue { - format!("{}{}{:016x}", self.as_secure_string(table_name, store_cipher), KEY_SEPARATOR, i,) - .into() + format!("{}{KEY_SEPARATOR}{i:016x}", self.as_secure_string(table_name, store_cipher)).into() } /// Encode self into a IdbKeyRange for searching all keys that are diff --git a/crates/matrix-sdk-indexeddb/src/state_store.rs b/crates/matrix-sdk-indexeddb/src/state_store.rs index 229dc7e53..c6081fbaa 100644 --- a/crates/matrix-sdk-indexeddb/src/state_store.rs +++ b/crates/matrix-sdk-indexeddb/src/state_store.rs @@ -12,13 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::BTreeSet, sync::Arc}; +use std::{ + collections::BTreeSet, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; use anyhow::anyhow; use async_trait::async_trait; +use derive_builder::Builder; #[cfg(feature = "experimental-timeline")] use futures_util::stream; use indexed_db_futures::prelude::*; +use js_sys::Date as JsDate; use matrix_sdk_base::{ deserialized_responses::MemberEvent, media::{MediaRequest, UniqueKey}, @@ -46,6 +54,8 @@ use ruma::{ RoomVersionId, }; use serde::{de::DeserializeOwned, Deserialize, Serialize}; +#[cfg(feature = "experimental-timeline")] +use tracing::{info, warn}; use wasm_bindgen::JsValue; use web_sys::IdbKeyRange; @@ -55,7 +65,7 @@ use crate::safe_encode::SafeEncode; struct StoreKeyWrapper(Vec); #[derive(Debug, thiserror::Error)] -pub enum SerializationError { +pub enum IndexeddbStoreError { #[error(transparent)] Json(#[from] serde_json::Error), #[error(transparent)] @@ -64,11 +74,28 @@ pub enum SerializationError { DomException { name: String, message: String, code: u16 }, #[error(transparent)] StoreError(#[from] StoreError), + #[error("Can't migrate {name} from {old_version} to {new_version} without deleting data. See MigrationConflictStrategy for ways to configure.")] + MigrationConflict { name: String, old_version: f64, new_version: f64 }, } -impl From for SerializationError { - fn from(frm: indexed_db_futures::web_sys::DomException) -> SerializationError { - SerializationError::DomException { +/// Sometimes Migrations can't proceed without having to drop existing +/// data. This allows you to configure, how these cases should be handled. +#[allow(dead_code)] +#[derive(PartialEq, Clone, Debug)] +pub enum MigrationConflictStrategy { + /// Just drop the data, we don't care that we have to sync again + Drop, + /// Raise a `IndexedDBStore::MigrationConflict` error with the path to the + /// DB in question. The caller then has to take care about what they want + /// to do and try again after. + Raise, + /// Default. + BackupAndDrop, +} + +impl From for IndexeddbStoreError { + fn from(frm: indexed_db_futures::web_sys::DomException) -> IndexeddbStoreError { + IndexeddbStoreError::DomException { name: frm.name(), message: frm.message(), code: frm.code(), @@ -76,22 +103,20 @@ impl From for SerializationError { } } -impl From for StoreError { - fn from(e: SerializationError) -> Self { +impl From for StoreError { + fn from(e: IndexeddbStoreError) -> Self { match e { - SerializationError::Json(e) => StoreError::Json(e), - SerializationError::StoreError(e) => e, - SerializationError::Encryption(e) => match e { + IndexeddbStoreError::Json(e) => StoreError::Json(e), + IndexeddbStoreError::StoreError(e) => e, + IndexeddbStoreError::Encryption(e) => match e { EncryptionError::Random(e) => StoreError::Encryption(e.to_string()), EncryptionError::Serialization(e) => StoreError::Json(e), EncryptionError::Encryption(e) => StoreError::Encryption(e.to_string()), EncryptionError::Version(found, expected) => StoreError::Encryption(format!( - "Bad Database Encryption Version: expected {} found {}", - expected, found + "Bad Database Encryption Version: expected {expected}, found {found}", )), EncryptionError::Length(found, expected) => StoreError::Encryption(format!( - "The database key an invalid length: expected {} found {}", - expected, found + "The database key an invalid length: expected {expected}, found {found}", )), }, _ => StoreError::backend(e), @@ -103,6 +128,12 @@ impl From for StoreError { mod KEYS { // STORES + pub const CURRENT_DB_VERSION: f64 = 1.1; + pub const CURRENT_META_DB_VERSION: f64 = 2.0; + + pub const INTERNAL_STATE: &str = "matrix-sdk-state"; + pub const BACKUPS_META: &str = "backups"; + pub const SESSION: &str = "session"; pub const ACCOUNT_DATA: &str = "account_data"; @@ -137,16 +168,247 @@ mod KEYS { pub const CUSTOM: &str = "custom"; + pub const SYNC_TOKEN: &str = "sync_token"; + + /// All names of the state stores for convenience. + pub const ALL_STORES: &[&str] = &[ + SESSION, + ACCOUNT_DATA, + MEMBERS, + PROFILES, + DISPLAY_NAMES, + JOINED_USER_IDS, + INVITED_USER_IDS, + ROOM_STATE, + ROOM_INFOS, + PRESENCE, + ROOM_ACCOUNT_DATA, + STRIPPED_ROOM_INFOS, + STRIPPED_MEMBERS, + STRIPPED_ROOM_STATE, + STRIPPED_JOINED_USER_IDS, + STRIPPED_INVITED_USER_IDS, + ROOM_USER_RECEIPTS, + ROOM_EVENT_RECEIPTS, + MEDIA, + CUSTOM, + SYNC_TOKEN, + #[cfg(feature = "experimental-timeline")] + ROOM_TIMELINE, + #[cfg(feature = "experimental-timeline")] + ROOM_TIMELINE_METADATA, + #[cfg(feature = "experimental-timeline")] + ROOM_EVENT_ID_TO_POSITION, + ]; + // static keys pub const STORE_KEY: &str = "store_key"; pub const FILTER: &str = "filter"; - pub const SYNC_TOKEN: &str = "sync_token"; +} + +pub use KEYS::ALL_STORES; + +fn drop_stores(db: &IdbDatabase) -> Result<(), JsValue> { + for name in ALL_STORES { + db.delete_object_store(name)?; + } + Ok(()) +} + +fn create_stores(db: &IdbDatabase) -> Result<(), JsValue> { + for name in ALL_STORES { + db.create_object_store(name)?; + } + Ok(()) +} + +async fn backup(source: &IdbDatabase, meta: &IdbDatabase) -> Result<()> { + let now = JsDate::now(); + let backup_name = format!("backup-{}-{}", source.name(), now); + + let mut db_req: OpenDbRequest = IdbDatabase::open_f64(&backup_name, source.version())?; + db_req.set_on_upgrade_needed(Some(move |evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { + // migrating to version 1 + let db = evt.db(); + for name in ALL_STORES { + db.create_object_store(name)?; + } + Ok(()) + })); + let target = db_req.into_future().await?; + + for name in ALL_STORES { + let tx = target.transaction_on_one_with_mode(name, IdbTransactionMode::Readwrite)?; + + let obj = tx.object_store(name)?; + + if let Some(curs) = source + .transaction_on_one_with_mode(name, IdbTransactionMode::Readonly)? + .object_store(name)? + .open_cursor()? + .await? + { + while let Some(key) = curs.key() { + obj.put_key_val(&key, &curs.value())?; + + curs.continue_cursor()?.await?; + } + } + + tx.await.into_result()?; + } + + let tx = + meta.transaction_on_one_with_mode(KEYS::BACKUPS_META, IdbTransactionMode::Readwrite)?; + let backup_store = tx.object_store(KEYS::BACKUPS_META)?; + backup_store.put_key_val(&JsValue::from_f64(now), &JsValue::from_str(&backup_name))?; + + tx.await; + + Ok(()) +} + +#[derive(Builder, Debug, PartialEq)] +#[builder(name = "IndexeddbStoreBuilder", build_fn(skip))] +pub struct IndexeddbStoreBuilderConfig { + /// The name for the indexeddb store to use, `state` is none given + name: String, + /// The password the indexeddb should be encrypted with. If not given, the + /// DB is not encrypted + passphrase: String, + /// The strategy to use when a merge conflict is found, see + /// [`MigrationConflictStrategy`] for details + #[builder(default = "MigrationConflictStrategy::BackupAndDrop")] + migration_conflict_strategy: MigrationConflictStrategy, +} + +impl IndexeddbStoreBuilder { + pub async fn build(&mut self) -> Result { + let migration_strategy = self + .migration_conflict_strategy + .clone() + .unwrap_or(MigrationConflictStrategy::BackupAndDrop); + let name = self.name.clone().unwrap_or_else(|| "state".to_owned()); + + let meta_name = format!("{}::{}", name, KEYS::INTERNAL_STATE); + + let mut db_req: OpenDbRequest = + IdbDatabase::open_f64(&meta_name, KEYS::CURRENT_META_DB_VERSION)?; + db_req.set_on_upgrade_needed(Some(|evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { + let db = evt.db(); + if evt.old_version() < 1.0 { + // migrating to version 1 + + db.create_object_store(KEYS::INTERNAL_STATE)?; + db.create_object_store(KEYS::BACKUPS_META)?; + } else if evt.old_version() < 2.0 { + db.create_object_store(KEYS::BACKUPS_META)?; + } + Ok(()) + })); + + let meta_db: IdbDatabase = db_req.into_future().await?; + + let store_cipher = if let Some(passphrase) = &self.passphrase { + let tx: IdbTransaction<'_> = meta_db.transaction_on_one_with_mode( + KEYS::INTERNAL_STATE, + IdbTransactionMode::Readwrite, + )?; + let ob = tx.object_store(KEYS::INTERNAL_STATE)?; + + let cipher = if let Some(StoreKeyWrapper(inner)) = ob + .get(&JsValue::from_str(KEYS::STORE_KEY))? + .await? + .map(|v| v.into_serde()) + .transpose()? + { + StoreCipher::import(passphrase, &inner)? + } else { + let cipher = StoreCipher::new()?; + ob.put_key_val( + &JsValue::from_str(KEYS::STORE_KEY), + &JsValue::from_serde(&StoreKeyWrapper(cipher.export(passphrase)?))?, + )?; + cipher + }; + + tx.await.into_result()?; + Some(Arc::new(cipher)) + } else { + None + }; + + let recreate_stores = { + // checkup up in a separate call, whether we have to backup or do anything else + // to the db. Unfortunately the set_on_upgrade_needed doesn't allow async fn + // which we need to execute the backup. + let has_store_cipher = store_cipher.is_some(); + let mut db_req: OpenDbRequest = IdbDatabase::open_f64(&name, 1.0)?; + let created = Arc::new(AtomicBool::new(false)); + let created_inner = created.clone(); + + db_req.set_on_upgrade_needed(Some( + move |evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { + // in case this is a fresh db, we dont't want to trigger + // further migrations other than just creating the full + // schema. + if evt.old_version() < 1.0 { + create_stores(evt.db())?; + created_inner.store(true, Ordering::Relaxed); + } + Ok(()) + }, + )); + + let pre_db = db_req.into_future().await?; + let old_version = pre_db.version(); + + if created.load(Ordering::Relaxed) { + // this is a fresh DB, return + false + } else if old_version == 1.0 && has_store_cipher { + match migration_strategy { + MigrationConflictStrategy::BackupAndDrop => { + backup(&pre_db, &meta_db).await?; + true + } + MigrationConflictStrategy::Drop => true, + MigrationConflictStrategy::Raise => { + return Err(IndexeddbStoreError::MigrationConflict { + name, + old_version, + new_version: KEYS::CURRENT_DB_VERSION, + }) + } + } + } else { + // Nothing to be done + false + } + }; + + let mut db_req: OpenDbRequest = IdbDatabase::open_f64(&name, KEYS::CURRENT_DB_VERSION)?; + db_req.set_on_upgrade_needed(Some( + move |evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { + // changing the format can only happen in the upgrade procedure + if recreate_stores { + drop_stores(evt.db())?; + create_stores(evt.db())?; + } + Ok(()) + }, + )); + + let db = db_req.into_future().await?; + Ok(IndexeddbStore { name, inner: db, meta: meta_db, store_cipher }) + } } pub struct IndexeddbStore { name: String, pub(crate) inner: IdbDatabase, + pub(crate) meta: IdbDatabase, pub(crate) store_cipher: Option>, } @@ -156,118 +418,65 @@ impl std::fmt::Debug for IndexeddbStore { } } -type Result = std::result::Result; +type Result = std::result::Result; impl IndexeddbStore { - async fn open_helper(name: String, store_cipher: Option>) -> Result { - // Open my_db v1 - let mut db_req: OpenDbRequest = IdbDatabase::open_f64(&name, 1.0)?; - db_req.set_on_upgrade_needed(Some(|evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { - if evt.old_version() < 1.0 { - // migrating to version 1 - let db = evt.db(); + /// Generate a IndexeddbStoreBuilder with default parameters + pub fn builder() -> IndexeddbStoreBuilder { + IndexeddbStoreBuilder::default() + } - db.create_object_store(KEYS::SESSION)?; - db.create_object_store(KEYS::SYNC_TOKEN)?; - db.create_object_store(KEYS::ACCOUNT_DATA)?; + /// Whether this database has any migration backups + pub async fn has_backups(&self) -> Result { + Ok(self + .meta + .transaction_on_one_with_mode(KEYS::BACKUPS_META, IdbTransactionMode::Readonly)? + .object_store(KEYS::BACKUPS_META)? + .count()? + .await? + > 0) + } - db.create_object_store(KEYS::MEMBERS)?; - db.create_object_store(KEYS::PROFILES)?; - db.create_object_store(KEYS::DISPLAY_NAMES)?; - db.create_object_store(KEYS::JOINED_USER_IDS)?; - db.create_object_store(KEYS::INVITED_USER_IDS)?; - - db.create_object_store(KEYS::ROOM_STATE)?; - db.create_object_store(KEYS::ROOM_INFOS)?; - db.create_object_store(KEYS::PRESENCE)?; - db.create_object_store(KEYS::ROOM_ACCOUNT_DATA)?; - - db.create_object_store(KEYS::STRIPPED_ROOM_INFOS)?; - db.create_object_store(KEYS::STRIPPED_MEMBERS)?; - db.create_object_store(KEYS::STRIPPED_ROOM_STATE)?; - db.create_object_store(KEYS::STRIPPED_JOINED_USER_IDS)?; - db.create_object_store(KEYS::STRIPPED_INVITED_USER_IDS)?; - - db.create_object_store(KEYS::ROOM_USER_RECEIPTS)?; - db.create_object_store(KEYS::ROOM_EVENT_RECEIPTS)?; - - #[cfg(feature = "experimental-timeline")] - { - db.create_object_store(KEYS::ROOM_TIMELINE)?; - db.create_object_store(KEYS::ROOM_TIMELINE_METADATA)?; - db.create_object_store(KEYS::ROOM_EVENT_ID_TO_POSITION)?; - } - - db.create_object_store(KEYS::MEDIA)?; - - db.create_object_store(KEYS::CUSTOM)?; - } - Ok(()) - })); - - let db: IdbDatabase = db_req.into_future().await?; - - Ok(Self { name, inner: db, store_cipher }) + /// What's the database name of the latest backup< + pub async fn latest_backup(&self) -> Result> { + Ok(self + .meta + .transaction_on_one_with_mode(KEYS::BACKUPS_META, IdbTransactionMode::Readonly)? + .object_store(KEYS::BACKUPS_META)? + .open_cursor_with_direction(indexed_db_futures::prelude::IdbCursorDirection::Prev)? + .await? + .and_then(|c| c.value().as_string())) } #[allow(dead_code)] + #[deprecated(note = "Use IndexeddbStoreBuilder instead.")] pub async fn open() -> StoreResult { - Ok(IndexeddbStore::open_helper("state".to_owned(), None).await?) + IndexeddbStore::builder() + .name("state".to_owned()) + .build() + .await + .map_err(StoreError::backend) } + #[deprecated(note = "Use IndexeddbStoreBuilder instead.")] pub async fn open_with_passphrase(name: String, passphrase: &str) -> StoreResult { - Ok(Self::inner_open_with_passphrase(name, passphrase).await?) - } - - pub(crate) async fn inner_open_with_passphrase(name: String, passphrase: &str) -> Result { - let name = format!("{:0}::matrix-sdk-state", name); - - let mut db_req: OpenDbRequest = IdbDatabase::open_u32(&name, 1)?; - db_req.set_on_upgrade_needed(Some(|evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { - if evt.old_version() < 1.0 { - // migrating to version 1 - let db = evt.db(); - - db.create_object_store("matrix-sdk-state")?; - } - Ok(()) - })); - - let db: IdbDatabase = db_req.into_future().await?; - - let tx: IdbTransaction<'_> = - db.transaction_on_one_with_mode("matrix-sdk-state", IdbTransactionMode::Readwrite)?; - let ob = tx.object_store("matrix-sdk-state")?; - - let cipher = if let Some(StoreKeyWrapper(inner)) = ob - .get(&JsValue::from_str(KEYS::STORE_KEY))? - .await? - .map(|v| v.into_serde()) - .transpose()? - { - StoreCipher::import(passphrase, &inner)? - } else { - let cipher = StoreCipher::new()?; - ob.put_key_val( - &JsValue::from_str(KEYS::STORE_KEY), - &JsValue::from_serde(&StoreKeyWrapper(cipher.export(passphrase)?))?, - )?; - cipher - }; - - tx.await.into_result()?; - - IndexeddbStore::open_helper(name, Some(cipher.into())).await + IndexeddbStore::builder() + .name(name) + .passphrase(passphrase.to_owned()) + .build() + .await + .map_err(StoreError::backend) } + #[deprecated(note = "Use IndexeddbStoreBuilder instead.")] pub async fn open_with_name(name: String) -> StoreResult { - Ok(IndexeddbStore::open_helper(name, None).await?) + IndexeddbStore::builder().name(name).build().await.map_err(StoreError::backend) } fn serialize_event( &self, event: &impl Serialize, - ) -> std::result::Result { + ) -> std::result::Result { Ok(match &self.store_cipher { Some(cipher) => JsValue::from_serde(&cipher.encrypt_value_typed(event)?)?, None => JsValue::from_serde(event)?, @@ -277,7 +486,7 @@ impl IndexeddbStore { fn deserialize_event( &self, event: JsValue, - ) -> std::result::Result { + ) -> std::result::Result { match &self.store_cipher { Some(cipher) => Ok(cipher.decrypt_value_typed(event.into_serde()?)?), None => Ok(event.into_serde()?), @@ -298,7 +507,7 @@ impl IndexeddbStore { &self, table_name: &str, key: T, - ) -> Result + ) -> Result where T: SafeEncode, { @@ -306,7 +515,7 @@ impl IndexeddbStore { Some(cipher) => key.encode_to_range_secure(table_name, cipher), None => key.encode_to_range(), } - .map_err(|e| SerializationError::StoreError(StoreError::Backend(anyhow!(e).into()))) + .map_err(|e| IndexeddbStoreError::StoreError(StoreError::Backend(anyhow!(e).into()))) } #[cfg(feature = "experimental-timeline")] @@ -644,17 +853,14 @@ impl IndexeddbStore { for (room_id, timeline) in &changes.timeline { if timeline.sync { - tracing::info!("Save new timeline batch from sync response for {}", room_id); + info!(%room_id, "Saving new timeline batch from sync response"); } else { - tracing::info!( - "Save new timeline batch from messages response for {}", - room_id - ); + info!(%room_id, "Saving new timeline batch from messages response"); } let metadata: Option = if timeline.limited { - tracing::info!( - "Delete stored timeline for {} because the sync response was limited", - room_id + info!( + %room_id, + "Deleting stored timeline because the sync response was limited", ); let stores = &[ @@ -681,7 +887,7 @@ impl IndexeddbStore { // This should only happen when a developer adds a wrong timeline // batch to the `StateChanges` or the server returns a wrong response // to our request. - tracing::warn!("Drop unexpected timeline batch for {}", room_id); + warn!(%room_id, "Dropping unexpected timeline batch"); return Ok(()); } @@ -705,9 +911,9 @@ impl IndexeddbStore { } if delete_timeline { - tracing::info!( - "Delete stored timeline for {} because of duplicated events", - room_id + info!( + %room_id, + "Deleting stored timeline because of duplicated events", ); let stores = &[ @@ -754,9 +960,8 @@ impl IndexeddbStore { .transpose()? .and_then(|info| info.room_version().cloned()) .unwrap_or_else(|| { - tracing::warn!( - "Unable to find the room version for {}, assume version 9", - room_id + warn!( + "Unable to find the room version for {room_id}, assume version 9", ); RoomVersionId::V9 }); @@ -838,7 +1043,7 @@ impl IndexeddbStore { } } - tx.await.into_result().map_err::(|e| e.into()) + tx.await.into_result().map_err::(|e| e.into()) } pub async fn get_presence_event(&self, user_id: &UserId) -> Result>> { @@ -1163,7 +1368,7 @@ impl IndexeddbStore { tx.object_store(KEYS::CUSTOM)?.put_key_val(&jskey, &self.serialize_event(&value)?)?; - tx.await.into_result().map_err::(|e| e.into())?; + tx.await.into_result().map_err::(|e| e.into())?; Ok(prev) } @@ -1236,7 +1441,7 @@ impl IndexeddbStore { store.delete(&key)?; } } - tx.await.into_result().map_err::(|e| e.into()) + tx.await.into_result().map_err::(|e| e.into()) } #[cfg(feature = "experimental-timeline")] @@ -1259,7 +1464,7 @@ impl IndexeddbStore { { Some(tl) => tl, _ => { - tracing::info!("No timeline for {} was previously stored", room_id); + info!(%room_id, "Couldn't find a previously stored timeline"); return Ok(None); } }; @@ -1275,11 +1480,7 @@ impl IndexeddbStore { let stream = Box::pin(stream::iter(timeline.into_iter())); - tracing::info!( - "Found previously stored timeline for {}, with end token {:?}", - room_id, - end_token - ); + info!(%room_id, ?end_token, "Found previously stored timeline"); Ok(Some((stream, end_token))) } @@ -1475,8 +1676,8 @@ mod tests { use super::{IndexeddbStore, Result}; async fn get_store() -> Result { - let db_name = format!("test-state-plain-{}", Uuid::new_v4().as_hyphenated().to_string()); - Ok(IndexeddbStore::open_helper(db_name, None).await?) + let db_name = format!("test-state-plain-{}", Uuid::new_v4().as_hyphenated()); + Ok(IndexeddbStore::builder().name(db_name).build().await?) } statestore_integration_tests! { integration } @@ -1487,19 +1688,135 @@ mod encrypted_tests { #[cfg(target_arch = "wasm32")] wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); - use std::sync::Arc; - use matrix_sdk_base::statestore_integration_tests; use uuid::Uuid; - use super::{IndexeddbStore, Result, StoreCipher}; + use super::{IndexeddbStore, Result}; async fn get_store() -> Result { - let db_name = - format!("test-state-encrypted-{}", Uuid::new_v4().as_hyphenated().to_string()); - let key = StoreCipher::new()?; - Ok(IndexeddbStore::open_helper(db_name, Some(Arc::new(key))).await?) + let db_name = format!("test-state-encrypted-{}", Uuid::new_v4().as_hyphenated()); + let passphrase = format!("some_passphrase-{}", Uuid::new_v4().as_hyphenated()); + Ok(IndexeddbStore::builder().name(db_name).passphrase(passphrase).build().await?) } statestore_integration_tests! { integration } } + +#[cfg(test)] +mod migration_tests { + #[cfg(target_arch = "wasm32")] + wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); + use indexed_db_futures::prelude::*; + use matrix_sdk_test::async_test; + use uuid::Uuid; + use wasm_bindgen::JsValue; + + use super::{ + IndexeddbStore, IndexeddbStoreError, MigrationConflictStrategy, Result, ALL_STORES, + }; + + pub async fn create_fake_db(name: &str, version: f64) -> Result<()> { + let mut db_req: OpenDbRequest = IdbDatabase::open_f64(name, version)?; + db_req.set_on_upgrade_needed(Some( + move |evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { + // migrating to version 1 + let db = evt.db(); + for name in ALL_STORES { + db.create_object_store(name)?; + } + Ok(()) + }, + )); + db_req.into_future().await?; + Ok(()) + } + #[async_test] + pub async fn test_no_upgrade() -> Result<()> { + let name = format!("simple-1.1-no-cipher-{}", Uuid::new_v4().as_hyphenated().to_string()); + + // this transparently migrates to the latest version + let store = IndexeddbStore::builder().name(name).build().await?; + // this didn't create any backup + assert_eq!(store.has_backups().await?, false); + // simple check that the layout exists. + assert_eq!(store.get_sync_token().await?, None); + Ok(()) + } + + #[async_test] + pub async fn test_migrating_v1_to_1_1_plain() -> Result<()> { + let name = + format!("migrating-1.1-no-cipher-{}", Uuid::new_v4().as_hyphenated().to_string()); + create_fake_db(&name, 1.0).await?; + + // this transparently migrates to the latest version + let store = IndexeddbStore::builder().name(name).build().await?; + // this didn't create any backup + assert_eq!(store.has_backups().await?, false); + assert_eq!(store.get_sync_token().await?, None); + Ok(()) + } + + #[async_test] + pub async fn test_migrating_v1_to_1_1_with_pw() -> Result<()> { + let name = + format!("migrating-1.1-with-cipher-{}", Uuid::new_v4().as_hyphenated().to_string()); + let passphrase = "somepassphrase".to_owned(); + create_fake_db(&name, 1.0).await?; + + // this transparently migrates to the latest version + let store = IndexeddbStore::builder().name(name).passphrase(passphrase).build().await?; + // this creates a backup by default + assert_eq!(store.has_backups().await?, true); + assert!(store.latest_backup().await?.is_some(), "No backup_found"); + assert_eq!(store.get_sync_token().await?, None); + Ok(()) + } + + #[async_test] + pub async fn test_migrating_v1_to_1_1_with_pw_drops() -> Result<()> { + let name = format!( + "migrating-1.1-with-cipher-drops-{}", + Uuid::new_v4().as_hyphenated().to_string() + ); + let passphrase = "some-other-passphrase".to_owned(); + create_fake_db(&name, 1.0).await?; + + // this transparently migrates to the latest version + let store = IndexeddbStore::builder() + .name(name) + .passphrase(passphrase) + .migration_conflict_strategy(MigrationConflictStrategy::Drop) + .build() + .await?; + // this creates a backup by default + assert_eq!(store.has_backups().await?, false); + assert_eq!(store.get_sync_token().await?, None); + Ok(()) + } + + #[async_test] + pub async fn test_migrating_v1_to_1_1_with_pw_raise() -> Result<()> { + let name = format!( + "migrating-1.1-with-cipher-raises-{}", + Uuid::new_v4().as_hyphenated().to_string() + ); + let passphrase = "some-other-passphrase".to_owned(); + create_fake_db(&name, 1.0).await?; + + // this transparently migrates to the latest version + let store_res = IndexeddbStore::builder() + .name(name) + .passphrase(passphrase) + .migration_conflict_strategy(MigrationConflictStrategy::Raise) + .build() + .await; + + if let Err(IndexeddbStoreError::MigrationConflict { .. }) = store_res { + // all fine! + } else { + assert!(false, "Conflict didn't raise: {:?}", store_res) + } + Ok(()) + } +} diff --git a/crates/matrix-sdk-qrcode/Cargo.toml b/crates/matrix-sdk-qrcode/Cargo.toml index 4a05c2e75..b26375e68 100644 --- a/crates/matrix-sdk-qrcode/Cargo.toml +++ b/crates/matrix-sdk-qrcode/Cargo.toml @@ -30,4 +30,4 @@ thiserror = "1.0.30" [dependencies.vodozemac] git = "https://github.com/matrix-org/vodozemac/" -rev = "2404f83f7d3a3779c1f518e4d949f7da9677c3dd" +rev = "18bcbc3359298894415931547ea41abb75af2d4a" diff --git a/crates/matrix-sdk-qrcode/src/lib.rs b/crates/matrix-sdk-qrcode/src/lib.rs index 28fbe9d59..a9f261f60 100644 --- a/crates/matrix-sdk-qrcode/src/lib.rs +++ b/crates/matrix-sdk-qrcode/src/lib.rs @@ -33,7 +33,7 @@ pub use types::{ #[cfg(test)] mod tests { #[cfg(feature = "decode_image")] - use std::{convert::TryFrom, io::Cursor}; + use std::io::Cursor; #[cfg(feature = "decode_image")] use image::{ImageFormat, Luma}; diff --git a/crates/matrix-sdk-qrcode/src/utils.rs b/crates/matrix-sdk-qrcode/src/utils.rs index 14812cabc..5c819b259 100644 --- a/crates/matrix-sdk-qrcode/src/utils.rs +++ b/crates/matrix-sdk-qrcode/src/utils.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::convert::TryInto; - #[cfg(feature = "decode_image")] use image::{GenericImage, GenericImageView, Luma}; use qrcode::{bits::Bits, EcLevel, QrCode, Version}; diff --git a/crates/matrix-sdk-sled/Cargo.toml b/crates/matrix-sdk-sled/Cargo.toml index 357479128..63ae738a3 100644 --- a/crates/matrix-sdk-sled/Cargo.toml +++ b/crates/matrix-sdk-sled/Cargo.toml @@ -31,6 +31,8 @@ experimental-timeline = [ async-stream = "0.3.3" async-trait = "0.1.53" dashmap = "5.2.0" +derive_builder = "0.11.2" +fs_extra = "1.2.0" futures-core = "0.3.21" futures-util = { version = "0.3.21", default-features = false } matrix-sdk-base = { version = "0.5.0", path = "../matrix-sdk-base", optional = true } @@ -46,6 +48,7 @@ tokio = { version = "1.17.0", default-features = false, features = ["sync", "fs" tracing = "0.1.34" [dev-dependencies] +glob = "0.3.0" matrix-sdk-base = { path = "../matrix-sdk-base", features = ["testing"] } matrix-sdk-crypto = { path = "../matrix-sdk-crypto", features = ["testing"] } matrix-sdk-test = { path = "../../testing/matrix-sdk-test" } diff --git a/crates/matrix-sdk-sled/src/cryptostore.rs b/crates/matrix-sdk-sled/src/cryptostore.rs index 871c2d2e0..2b936e76d 100644 --- a/crates/matrix-sdk-sled/src/cryptostore.rs +++ b/crates/matrix-sdk-sled/src/cryptostore.rs @@ -14,7 +14,6 @@ use std::{ collections::{HashMap, HashSet}, - convert::TryInto, path::{Path, PathBuf}, sync::{Arc, RwLock}, }; diff --git a/crates/matrix-sdk-sled/src/lib.rs b/crates/matrix-sdk-sled/src/lib.rs index a10cd73b7..95108ffc7 100644 --- a/crates/matrix-sdk-sled/src/lib.rs +++ b/crates/matrix-sdk-sled/src/lib.rs @@ -16,7 +16,9 @@ mod state_store; #[cfg(feature = "crypto-store")] pub use cryptostore::SledStore as CryptoStore; #[cfg(feature = "state-store")] -pub use state_store::SledStore as StateStore; +pub use state_store::{ + MigrationConflictStrategy, SledStore as StateStore, SledStoreBuilder as StateStoreBuilder, +}; /// All the errors that can occur when opening a sled store. #[derive(Error, Debug)] @@ -38,8 +40,10 @@ pub enum OpenStoreError { } /// Create a [`StoreConfig`] with an opened sled [`StateStore`] that uses the -/// given path and passphrase. If `encryption` is enabled, a [`CryptoStore`] -/// with the same parameters is also opened. +/// given path and passphrase. +/// +/// If the `e2e-encryption` Cargo feature is enabled, a [`CryptoStore`] with the +/// same parameters is also opened. /// /// [`StoreConfig`]: #StoreConfig #[cfg(any(feature = "state-store", feature = "crypto-store"))] @@ -61,11 +65,13 @@ pub fn make_store_config( #[cfg(not(feature = "crypto-store"))] { - let state_store = if let Some(passphrase) = passphrase { - StateStore::open_with_passphrase(path, passphrase)? - } else { - StateStore::open_with_path(path)? + let mut store_builder = StateStore::builder(); + store_builder.path(path.as_ref().to_path_buf()); + + if let Some(passphrase) = passphrase { + store_builder.passphrase(passphrase.to_owned()); }; + let state_store = store_builder.build().map_err(StoreError::backend)?; Ok(StoreConfig::new().state_store(state_store)) } @@ -78,13 +84,14 @@ fn open_stores_with_path( path: impl AsRef, passphrase: Option<&str>, ) -> Result<(StateStore, CryptoStore), OpenStoreError> { + let mut store_builder = StateStore::builder(); + store_builder.path(path.as_ref().to_path_buf()); + if let Some(passphrase) = passphrase { - let state_store = StateStore::open_with_passphrase(path, passphrase)?; - let crypto_store = state_store.open_crypto_store()?; - Ok((state_store, crypto_store)) - } else { - let state_store = StateStore::open_with_path(path)?; - let crypto_store = state_store.open_crypto_store()?; - Ok((state_store, crypto_store)) - } + store_builder.passphrase(passphrase.to_owned()); + }; + + let state_store = store_builder.build().map_err(StoreError::backend)?; + let crypto_store = state_store.open_crypto_store()?; + Ok((state_store, crypto_store)) } diff --git a/crates/matrix-sdk-sled/src/state_store.rs b/crates/matrix-sdk-sled/src/state_store.rs index c7db5fb63..0d1797550 100644 --- a/crates/matrix-sdk-sled/src/state_store.rs +++ b/crates/matrix-sdk-sled/src/state_store.rs @@ -16,12 +16,13 @@ use std::{ collections::BTreeSet, path::{Path, PathBuf}, sync::Arc, - time::Instant, + time::{Instant, SystemTime, UNIX_EPOCH}, }; #[cfg(feature = "experimental-timeline")] use async_stream::stream; use async_trait::async_trait; +use derive_builder::Builder; use futures_core::stream::Stream; use futures_util::stream::{self, StreamExt, TryStreamExt}; use matrix_sdk_base::{ @@ -58,6 +59,7 @@ use sled::{ Config, Db, Transactional, Tree, }; use tokio::task::spawn_blocking; +use tracing::{debug, info}; #[cfg(feature = "crypto-store")] use super::OpenStoreError; @@ -81,6 +83,28 @@ pub enum SledStoreError { Identifier(#[from] IdParseError), #[error(transparent)] Task(#[from] tokio::task::JoinError), + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + FsExtra(#[from] fs_extra::error::Error), + #[error("Can't migrate {path} from {old_version} to {new_version} without deleting data. See MigrationConflictStrategy for ways to configure.")] + MigrationConflict { path: PathBuf, old_version: usize, new_version: usize }, +} + +/// Sometimes Migrations can't proceed without having to drop existing +/// data. This allows you to configure, how these cases should be handled. +#[derive(PartialEq, Eq, Clone, Debug)] +pub enum MigrationConflictStrategy { + /// Just drop the data, we don't care that we have to sync again + Drop, + /// Raise a `SledStoreError::MigrationConflict` error with the path to the + /// DB in question. The caller then has to take care about what they want + /// to do and try again after. + Raise, + /// _Default_: The _entire_ database is backed up under + /// `$path.$timestamp.backup` (this includes the crypto store if they + /// are linked), before the state tables are dropped. + BackupAndDrop, } impl From> for SledStoreError { @@ -103,12 +127,10 @@ impl Into for SledStoreError { KeyEncryptionError::Serialization(e) => StoreError::Json(e), KeyEncryptionError::Encryption(e) => StoreError::Encryption(e.to_string()), KeyEncryptionError::Version(found, expected) => StoreError::Encryption(format!( - "Bad Database Encryption Version: expected {} found {}", - expected, found + "Bad Database Encryption Version: expected {expected}, found {found}", )), KeyEncryptionError::Length(found, expected) => StoreError::Encryption(format!( - "The database key an invalid length: expected {} found {}", - expected, found + "The database key an invalid length: expected {expected}, found {found}", )), }, SledStoreError::StoreError(e) => e, @@ -116,7 +138,7 @@ impl Into for SledStoreError { } } } -const DATABASE_VERSION: u8 = 1; +const DATABASE_VERSION: u8 = 2; const VERSION_KEY: &str = "state-store-version"; @@ -149,8 +171,135 @@ const TIMELINE_METADATA: &str = "timeline-metadata"; #[cfg(feature = "experimental-timeline")] const TIMELINE: &str = "timeline"; +const ALL_DB_STORES: &[&str] = &[ + ACCOUNT_DATA, + SYNC_TOKEN, + DISPLAY_NAME, + INVITED_USER_ID, + JOINED_USER_ID, + MEDIA, + MEMBER, + PRESENCE, + PROFILE, + ROOM_ACCOUNT_DATA, + #[cfg(feature = "experimental-timeline")] + ROOM_EVENT_ID_POSITION, + ROOM_EVENT_RECEIPT, + ROOM_INFO, + ROOM_STATE, + ROOM_USER_RECEIPT, + ROOM, + SESSION, + STRIPPED_INVITED_USER_ID, + STRIPPED_JOINED_USER_ID, + STRIPPED_ROOM_INFO, + STRIPPED_ROOM_MEMBER, + STRIPPED_ROOM_STATE, + CUSTOM, + #[cfg(feature = "experimental-timeline")] + ROOM_EVENT_ID_POSITION, + #[cfg(feature = "experimental-timeline")] + TIMELINE_METADATA, + #[cfg(feature = "experimental-timeline")] + TIMELINE, +]; +const ALL_GLOBAL_KEYS: &[&str] = &[VERSION_KEY]; + type Result = std::result::Result; +#[derive(Builder, Debug, PartialEq, Eq)] +#[builder(name = "SledStoreBuilder", build_fn(skip))] +pub struct SledStoreBuilderConfig { + /// Path to the sled store files, created if not yet existing + path: PathBuf, + /// Set the password the sled store is encrypted with (if any) + passphrase: String, + /// The strategy to use when a merge conflict is found, see + /// [`MigrationConflictStrategy`] for details + #[builder(default = "MigrationConflictStrategy::BackupAndDrop")] + migration_conflict_strategy: MigrationConflictStrategy, +} + +impl SledStoreBuilder { + pub fn build(&mut self) -> Result { + let is_temp = self.path.is_none(); + + let mut cfg = Config::new().temporary(is_temp); + + let path = if let Some(path) = &self.path { + let path = path.join("matrix-sdk-state"); + + cfg = cfg.path(&path); + Some(path) + } else { + None + }; + + let db = cfg.open().map_err(StoreError::backend)?; + + let store_cipher = if let Some(passphrase) = &self.passphrase { + if let Some(inner) = db.get("store_cipher".encode())? { + Some(StoreCipher::import(passphrase, &inner)?.into()) + } else { + let cipher = StoreCipher::new()?; + db.insert("store_cipher".encode(), cipher.export(passphrase)?)?; + Some(cipher.into()) + } + } else { + None + }; + + let mut store = SledStore::open_helper(db, path, store_cipher)?; + + let migration_res = store.upgrade(); + if let Err(SledStoreError::MigrationConflict { path, .. }) = &migration_res { + // how are supposed to react about this? + match self + .migration_conflict_strategy + .as_ref() + .unwrap_or(&MigrationConflictStrategy::BackupAndDrop) + { + MigrationConflictStrategy::BackupAndDrop => { + let mut new_path = path.clone(); + new_path.set_extension(format!( + "{}.backup", + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time doesn't go backwards") + .as_secs() + )); + fs_extra::dir::create_all(&new_path, false)?; + fs_extra::dir::copy(path, new_path, &fs_extra::dir::CopyOptions::new())?; + store.drop_tables()?; + return self.build(); + } + MigrationConflictStrategy::Drop => { + store.drop_tables()?; + return self.build(); + } + MigrationConflictStrategy::Raise => migration_res?, + } + } else { + migration_res?; + } + + Ok(store) + } + + // testing only + #[cfg(test)] + fn build_encrypted() -> StoreResult { + let db = Config::new().temporary(true).open().map_err(StoreError::backend)?; + + SledStore::open_helper( + db, + None, + Some(StoreCipher::new().expect("can't create store cipher").into()), + ) + .map_err(|e| e.into()) + } +} + #[derive(Clone)] pub struct SledStore { path: Option, @@ -194,6 +343,7 @@ impl std::fmt::Debug for SledStore { } } +#[allow(deprecated)] impl SledStore { fn open_helper( db: Db, @@ -234,7 +384,7 @@ impl SledStore { #[cfg(feature = "experimental-timeline")] let room_event_id_to_position = db.open_tree(ROOM_EVENT_ID_POSITION)?; - let database = Self { + Ok(Self { path, inner: db, store_cipher, @@ -264,64 +414,52 @@ impl SledStore { room_timeline_metadata, #[cfg(feature = "experimental-timeline")] room_event_id_to_position, - }; - - database.upgrade()?; - Ok(database) + }) } + /// Generate a SledStoreBuilder with default parameters + pub fn builder() -> SledStoreBuilder { + SledStoreBuilder::default() + } + + #[deprecated(note = "Use SledStoreBuilder instead.")] pub fn open() -> StoreResult { - let db = Config::new().temporary(true).open().map_err(StoreError::backend)?; - - SledStore::open_helper(db, None, None).map_err(|e| e.into()) - } - - // testing only - #[cfg(test)] - fn open_encrypted() -> StoreResult { - let db = Config::new().temporary(true).open().map_err(StoreError::backend)?; - - SledStore::open_helper( - db, - None, - Some(StoreCipher::new().expect("can't create store cipher").into()), - ) - .map_err(|e| e.into()) + SledStore::builder().build().map_err(StoreError::backend) } + #[deprecated(note = "Use SledStoreBuilder instead.")] pub fn open_with_passphrase(path: impl AsRef, passphrase: &str) -> StoreResult { - Self::inner_open_with_passphrase(path, passphrase).map_err(|e| e.into()) - } - - fn inner_open_with_passphrase(path: impl AsRef, passphrase: &str) -> Result { - let path = path.as_ref().join("matrix-sdk-state"); - let db = Config::new().temporary(false).path(&path).open()?; - - let store_cipher = if let Some(inner) = db.get("store_cipher".encode())? { - StoreCipher::import(passphrase, &inner)? - } else { - let cipher = StoreCipher::new()?; - db.insert("store_cipher".encode(), cipher.export(passphrase)?)?; - cipher - } - .into(); - - SledStore::open_helper(db, Some(path), Some(store_cipher)) + SledStore::builder() + .path(path.as_ref().into()) + .passphrase(passphrase.to_owned()) + .build() + .map_err(StoreError::backend) } + #[deprecated(note = "Use SledStoreBuilder instead.")] pub fn open_with_path(path: impl AsRef) -> StoreResult { - Self::inner_open_with_path(path).map_err(|e| e.into()) + SledStore::builder().path(path.as_ref().into()).build().map_err(StoreError::backend) } - fn inner_open_with_path(path: impl AsRef) -> Result { - let path = path.as_ref().join("matrix-sdk-state"); - let db = Config::new().temporary(false).path(&path).open()?; + fn drop_tables(self) -> StoreResult<()> { + for name in ALL_DB_STORES { + self.inner.drop_tree(name).map_err(StoreError::backend)?; + } + for name in ALL_GLOBAL_KEYS { + self.inner.remove(name).map_err(StoreError::backend)?; + } - SledStore::open_helper(db, Some(path), None) + Ok(()) } - fn upgrade(&self) -> StoreResult<()> { - let db_version = self.inner.get(VERSION_KEY).map_err(StoreError::backend)?.map(|v| { + fn set_db_version(&self, version: u8) -> Result<()> { + self.inner.insert(VERSION_KEY, version.to_be_bytes().as_ref())?; + self.inner.flush()?; + Ok(()) + } + + fn upgrade(&mut self) -> Result<()> { + let db_version = self.inner.get(VERSION_KEY)?.map(|v| { let (version_bytes, _) = v.split_at(std::mem::size_of::()); u8::from_be_bytes(version_bytes.try_into().unwrap_or_default()) }); @@ -329,11 +467,7 @@ impl SledStore { let old_version = match db_version { None => { // we are fresh, let's write the current version - self.inner - .insert(VERSION_KEY, DATABASE_VERSION.to_be_bytes().as_ref()) - .map_err(StoreError::backend)?; - self.inner.flush().map_err(StoreError::backend)?; - return Ok(()); + return self.set_db_version(DATABASE_VERSION); } Some(version) if version == DATABASE_VERSION => { // current, we don't have to do anything @@ -342,16 +476,30 @@ impl SledStore { Some(version) => version, }; - tracing::debug!( - old_version, - new_version = DATABASE_VERSION, - "Upgrading the Sled state store" - ); + debug!(old_version, new_version = DATABASE_VERSION, "Upgrading the Sled state store"); + + if old_version == 1 { + if self.store_cipher.is_some() { + // we stored some fields un-encrypted. Drop them to force re-creation + return Err(SledStoreError::MigrationConflict { + path: self.path.take().expect("Path must exist for a migration to fail"), + old_version: old_version.into(), + new_version: DATABASE_VERSION.into(), + }); + } + // no migration to handle + self.set_db_version(2u8)?; + return Ok(()); + } // FUTURE UPGRADE CODE GOES HERE // can't upgrade from that version to the new one - Err(StoreError::UnsupportedDatabaseVersion(old_version.into(), DATABASE_VERSION.into())) + Err(SledStoreError::MigrationConflict { + path: self.path.take().expect("Path must exist for a migration to fail"), + old_version: old_version.into(), + new_version: DATABASE_VERSION.into(), + }) } /// Open a `CryptoStore` that uses the same database as this store. @@ -687,7 +835,7 @@ impl SledStore { self.inner.flush_async().await?; - tracing::info!("Saved changes in {:?}", now.elapsed()); + info!("Saved changes in {:?}", now.elapsed()); Ok(()) } @@ -1147,7 +1295,6 @@ impl SledStore { ) -> Result>, Option)>> { let db = self.clone(); let key = self.encode_key(TIMELINE_METADATA, room_id); - let r_id = room_id.to_owned(); let metadata: Option = db .room_timeline_metadata .get(key.as_slice())? @@ -1156,7 +1303,7 @@ impl SledStore { let metadata = match metadata { Some(m) => m, None => { - tracing::info!("No timeline for {} was previously stored", r_id); + info!(%room_id, "Couldn't find a previously stored timeline"); return Ok(None); } }; @@ -1164,16 +1311,15 @@ impl SledStore { let mut position = metadata.start_position; let end_token = metadata.end; - tracing::info!( - "Found previously stored timeline for {}, with end token {:?}", - r_id, - end_token - ); + info!(%room_id, ?end_token, "Found previously stored timeline"); + let room_id = room_id.to_owned(); let stream = stream! { - while let Ok(Some(item)) = db.room_timeline.get(&db.encode_key_with_counter(TIMELINE, &r_id, position)) { + while let Ok(Some(item)) = + db.room_timeline.get(&db.encode_key_with_counter(TIMELINE, &room_id, position)) + { position += 1; - yield db.deserialize_value(&item).map_err(SledStoreError::from).map_err(|e| e.into()); + yield db.deserialize_value(&item).map_err(|e| SledStoreError::from(e).into()); } }; @@ -1182,7 +1328,7 @@ impl SledStore { #[cfg(feature = "experimental-timeline")] async fn remove_room_timeline(&self, room_id: &RoomId) -> Result<()> { - tracing::info!("Remove stored timeline for {}", room_id); + info!(%room_id, "Removing stored timeline"); let mut timeline_batch = sled::Batch::default(); for key in self.room_timeline.scan_prefix(self.encode_key(TIMELINE, &room_id)).keys() { @@ -1219,22 +1365,21 @@ impl SledStore { #[cfg(feature = "experimental-timeline")] async fn save_room_timeline(&self, changes: &StateChanges) -> Result<()> { + use tracing::warn; + let mut timeline_batch = sled::Batch::default(); let mut event_id_to_position_batch = sled::Batch::default(); let mut timeline_metadata_batch = sled::Batch::default(); for (room_id, timeline) in &changes.timeline { if timeline.sync { - tracing::info!("Save new timeline batch from sync response for {}", room_id); + info!(%room_id, "Saving new timeline batch from sync response"); } else { - tracing::info!("Save new timeline batch from messages response for {}", room_id); + info!(%room_id, "Saving new timeline batch from messages response"); } let metadata: Option = if timeline.limited { - tracing::info!( - "Delete stored timeline for {} because the sync response was limited", - room_id - ); + info!(%room_id, "Deleting stored timeline because the sync response was limited"); self.remove_room_timeline(room_id).await?; None } else { @@ -1248,7 +1393,7 @@ impl SledStore { // This should only happen when a developer adds a wrong timeline // batch to the `StateChanges` or the server returns a wrong response // to our request. - tracing::warn!("Drop unexpected timeline batch for {}", room_id); + warn!(%room_id, "Dropping unexpected timeline batch"); return Ok(()); } @@ -1266,10 +1411,7 @@ impl SledStore { } if delete_timeline { - tracing::info!( - "Delete stored timeline for {} because of duplicated events", - room_id - ); + info!(%room_id, "Deleting stored timeline because of duplicated events"); self.remove_room_timeline(room_id).await?; None } else if timeline.sync { @@ -1301,10 +1443,7 @@ impl SledStore { .transpose()? .and_then(|info| info.room_version().cloned()) .unwrap_or_else(|| { - tracing::warn!( - "Unable to find the room version for {}, assume version 9", - room_id - ); + warn!(%room_id, "Unable to find the room version, assume version 9"); RoomVersionId::V9 }); @@ -1589,7 +1728,7 @@ mod tests { use super::{SledStore, StateStore, StoreResult}; async fn get_store() -> StoreResult { - SledStore::open().map_err(Into::into) + SledStore::builder().build().map_err(Into::into) } statestore_integration_tests! { integration } @@ -1599,11 +1738,103 @@ mod tests { mod encrypted_tests { use matrix_sdk_base::statestore_integration_tests; - use super::{SledStore, StateStore, StoreResult}; + use super::{SledStoreBuilder, StateStore, StoreResult}; async fn get_store() -> StoreResult { - SledStore::open_encrypted().map_err(Into::into) + SledStoreBuilder::build_encrypted().map_err(Into::into) } statestore_integration_tests! { integration } } + +#[cfg(test)] +mod migration { + use matrix_sdk_test::async_test; + use tempfile::TempDir; + + use super::{MigrationConflictStrategy, Result, SledStore, SledStoreError}; + + #[async_test] + pub async fn migrating_v1_to_2_plain() -> Result<()> { + let folder = TempDir::new()?; + + let store = SledStore::builder().path(folder.path().to_path_buf()).build()?; + + store.set_db_version(1u8)?; + drop(store); + + // this transparently migrates to the latest version + let _store = SledStore::builder().path(folder.path().to_path_buf()).build()?; + Ok(()) + } + + #[async_test] + pub async fn migrating_v1_to_2_with_pw_backed_up() -> Result<()> { + let folder = TempDir::new()?; + + let store = SledStore::builder() + .path(folder.path().to_path_buf()) + .passphrase("something".to_owned()) + .build()?; + + store.set_db_version(1u8)?; + drop(store); + + // this transparently creates a backup and a fresh db + let _store = SledStore::builder() + .path(folder.path().to_path_buf()) + .passphrase("something".to_owned()) + .build()?; + assert_eq!(std::fs::read_dir(folder.path())?.count(), 2); + Ok(()) + } + + #[async_test] + pub async fn migrating_v1_to_2_with_pw_drop() -> Result<()> { + let folder = TempDir::new()?; + + let store = SledStore::builder() + .path(folder.path().to_path_buf()) + .passphrase("other thing".to_owned()) + .build()?; + + store.set_db_version(1u8)?; + drop(store); + + // this transparently creates a backup and a fresh db + let _store = SledStore::builder() + .path(folder.path().to_path_buf()) + .passphrase("other thing".to_owned()) + .migration_conflict_strategy(MigrationConflictStrategy::Drop) + .build()?; + assert_eq!(std::fs::read_dir(folder.path())?.count(), 1); + Ok(()) + } + + #[async_test] + pub async fn migrating_v1_to_2_with_pw_raises() -> Result<()> { + let folder = TempDir::new()?; + + let store = SledStore::builder() + .path(folder.path().to_path_buf()) + .passphrase("secret".to_owned()) + .build()?; + + store.set_db_version(1u8)?; + drop(store); + + // this transparently creates a backup and a fresh db + let res = SledStore::builder() + .path(folder.path().to_path_buf()) + .passphrase("secret".to_owned()) + .migration_conflict_strategy(MigrationConflictStrategy::Raise) + .build(); + if let Err(SledStoreError::MigrationConflict { .. }) = res { + // all good + } else { + panic!("Didn't raise the expected error: {:?}", res); + } + assert_eq!(std::fs::read_dir(folder.path())?.count(), 1); + Ok(()) + } +} diff --git a/crates/matrix-sdk/Cargo.toml b/crates/matrix-sdk/Cargo.toml index 9da698a17..27fa828a9 100644 --- a/crates/matrix-sdk/Cargo.toml +++ b/crates/matrix-sdk/Cargo.toml @@ -19,7 +19,7 @@ rustdoc-args = ["--cfg", "docsrs"] default = [ "e2e-encryption", "sled", - "native-tls" + "native-tls", ] e2e-encryption = [ diff --git a/crates/matrix-sdk/README.md b/crates/matrix-sdk/README.md index f8e99151c..fa2a81d34 100644 --- a/crates/matrix-sdk/README.md +++ b/crates/matrix-sdk/README.md @@ -25,7 +25,6 @@ some event handlers and then syncing. This is demonstrated in the example below. ```rust,no_run -use std::convert::TryFrom; use matrix_sdk::{ Client, config::SyncSettings, ruma::{user_id, events::room::message::SyncRoomMessageEvent}, @@ -40,7 +39,7 @@ async fn main() -> anyhow::Result<()> { client.login_username(alice, "password").send().await?; client - .register_event_handler(|ev: SyncRoomMessageEvent| async move { + .add_event_handler(|ev: SyncRoomMessageEvent| async move { println!("Received a message {:?}", ev); }) .await; diff --git a/crates/matrix-sdk/examples/autojoin.rs b/crates/matrix-sdk/examples/autojoin.rs index 4968f638e..a047f246e 100644 --- a/crates/matrix-sdk/examples/autojoin.rs +++ b/crates/matrix-sdk/examples/autojoin.rs @@ -22,13 +22,13 @@ async fn on_stripped_state_member( // retry autojoin due to synapse sending invites, before the // invited user can join for more information see // https://github.com/matrix-org/synapse/issues/4345 - eprintln!("Failed to join room {} ({:?}), retrying in {}s", room.room_id(), err, delay); + eprintln!("Failed to join room {} ({err:?}), retrying in {delay}s", room.room_id()); sleep(Duration::from_secs(delay)).await; delay *= 2; if delay > 3600 { - eprintln!("Can't join room {} ({:?})", room.room_id(), err); + eprintln!("Can't join room {} ({err:?})", room.room_id()); break; } } @@ -47,16 +47,13 @@ async fn login_and_sync( #[cfg(feature = "sled")] { // The location to save files to - let mut home = dirs::home_dir().expect("no home directory found"); - home.push("autojoin_bot"); - let state_store = matrix_sdk_sled::StateStore::open_with_path(home)?; - client_builder = client_builder.state_store(state_store); + let home = dirs::home_dir().expect("no home directory found").join("autojoin_bot"); + client_builder = client_builder.sled_store(home, None)?; } #[cfg(feature = "indexeddb")] { - let state_store = matrix_sdk_indexeddb::StateStore::open(); - client_builder = client_builder.state_store(state_store); + client_builder = client_builder.indexeddb_store("autojoin_bot", None).await?; } let client = client_builder.build().await?; @@ -67,9 +64,9 @@ async fn login_and_sync( .send() .await?; - println!("logged in as {}", username); + println!("logged in as {username}"); - client.register_event_handler(on_stripped_state_member).await; + client.add_event_handler(on_stripped_state_member).await; client.sync(SyncSettings::default()).await; diff --git a/crates/matrix-sdk/examples/command_bot.rs b/crates/matrix-sdk/examples/command_bot.rs index 7308f02dc..ea0ebfa7a 100644 --- a/crates/matrix-sdk/examples/command_bot.rs +++ b/crates/matrix-sdk/examples/command_bot.rs @@ -42,16 +42,13 @@ async fn login_and_sync( #[cfg(feature = "sled")] { // The location to save files to - let mut home = dirs::home_dir().expect("no home directory found"); - home.push("party_bot"); - let state_store = matrix_sdk_sled::StateStore::open_with_path(home)?; - client_builder = client_builder.state_store(state_store); + let home = dirs::home_dir().expect("no home directory found").join("party_bot"); + client_builder = client_builder.sled_store(home, None)?; } #[cfg(feature = "indexeddb")] { - let state_store = matrix_sdk_indexeddb::StateStore::open(); - client_builder = client_builder.state_store(state_store); + client_builder = client_builder.indexeddb_store("party_bot", None).await?; } let client = client_builder.build().await.unwrap(); @@ -61,7 +58,7 @@ async fn login_and_sync( .send() .await?; - println!("logged in as {}", username); + println!("logged in as {username}"); // An initial sync to set up state and so our bot doesn't respond to old // messages. If the `StateStore` finds saved state in the location given the @@ -69,7 +66,7 @@ async fn login_and_sync( client.sync_once(SyncSettings::default()).await.unwrap(); // add our CommandBot to be notified of incoming messages, we do this after the // initial sync to avoid responding to messages before the bot was running. - client.register_event_handler(on_room_message).await; + client.add_event_handler(on_room_message).await; // since we called `sync_once` before we entered our sync loop we must pass // that sync token to `sync` diff --git a/crates/matrix-sdk/examples/image_bot.rs b/crates/matrix-sdk/examples/image_bot.rs index b6aab0931..3453adff3 100644 --- a/crates/matrix-sdk/examples/image_bot.rs +++ b/crates/matrix-sdk/examples/image_bot.rs @@ -69,7 +69,7 @@ async fn login_and_sync( client.sync_once(SyncSettings::default()).await.unwrap(); let image = Arc::new(Mutex::new(image)); - client.register_event_handler(move |ev, room| on_room_message(ev, room, image.clone())).await; + client.add_event_handler(move |ev, room| on_room_message(ev, room, image.clone())).await; let settings = SyncSettings::default().token(client.sync_token().await.unwrap()); client.sync(settings).await; @@ -92,7 +92,7 @@ async fn main() -> anyhow::Result<()> { } }; - println!("helloooo {} {} {} {:#?}", homeserver_url, username, password, image_path); + println!("helloooo {homeserver_url} {username} {password} {image_path:#?}"); let path = PathBuf::from(image_path); let image = File::open(path).expect("Can't open image file."); diff --git a/crates/matrix-sdk/examples/login.rs b/crates/matrix-sdk/examples/login.rs index c935dc206..c21a8acb4 100644 --- a/crates/matrix-sdk/examples/login.rs +++ b/crates/matrix-sdk/examples/login.rs @@ -25,7 +25,7 @@ async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) { { let member = room.get_member(&sender).await.unwrap().unwrap(); let name = member.display_name().unwrap_or_else(|| member.user_id().as_str()); - println!("{}: {}", name, msg_body); + println!("{name}: {msg_body}"); } } } @@ -34,7 +34,7 @@ async fn login(homeserver_url: String, username: &str, password: &str) -> matrix let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL"); let client = Client::new(homeserver_url).await.unwrap(); - client.register_event_handler(on_room_message).await; + client.add_event_handler(on_room_message).await; client .login_username(username, password) diff --git a/crates/matrix-sdk/examples/timeline.rs b/crates/matrix-sdk/examples/timeline.rs index eb1c7b39d..ccb759d0b 100644 --- a/crates/matrix-sdk/examples/timeline.rs +++ b/crates/matrix-sdk/examples/timeline.rs @@ -54,7 +54,7 @@ async fn print_timeline(room: Room) { while let Some(event) = backward_stream.next().await { let event = event.unwrap(); if let Some(content) = event_content(event.event.deserialize().unwrap()) { - println!("{}", content); + println!("{content}"); } } } diff --git a/crates/matrix-sdk/src/account.rs b/crates/matrix-sdk/src/account.rs index 97a805381..a40cf799d 100644 --- a/crates/matrix-sdk/src/account.rs +++ b/crates/matrix-sdk/src/account.rs @@ -65,7 +65,7 @@ impl Account { /// client.login(user, "password", None, None).await?; /// /// if let Some(name) = client.account().get_display_name().await? { - /// println!("Logged in as user '{}' with display name '{}'", user, name); + /// println!("Logged in as user '{user}' with display name '{name}'"); /// } /// # anyhow::Ok(()) }); /// ``` @@ -113,7 +113,7 @@ impl Account { /// client.login(user, "password", None, None).await?; /// /// if let Some(url) = client.account().get_avatar_url().await? { - /// println!("Your avatar's mxc url is {}", url); + /// println!("Your avatar's mxc url is {url}"); /// } /// # anyhow::Ok(()) }); /// ``` @@ -161,7 +161,8 @@ impl Account { /// let client = Client::new(homeserver).await?; /// client.login(user, "password", None, None).await?; /// - /// if let Some(avatar) = client.account().get_avatar(MediaFormat::File).await? { + /// if let Some(avatar) = client.account().get_avatar(MediaFormat::File).await? + /// { /// std::fs::write("avatar.png", avatar); /// } /// # anyhow::Ok(()) }); @@ -226,8 +227,7 @@ impl Account { /// if let profile = client.account().get_profile().await? { /// println!( /// "You are '{:?}' with avatar '{:?}'", - /// profile.displayname, - /// profile.avatar_url + /// profile.displayname, profile.avatar_url /// ); /// } /// # anyhow::Ok(()) }); @@ -258,7 +258,6 @@ impl Account { /// /// # Example /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::Client; /// # use matrix_sdk::ruma::{ /// # api::client::{ @@ -307,7 +306,6 @@ impl Account { /// /// # Example /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::Client; /// # use matrix_sdk::ruma::{ /// # api::client::{ @@ -360,7 +358,10 @@ impl Account { /// let threepids = client.account().get_3pids().await?.threepids; /// /// for threepid in threepids { - /// println!("Found 3PID '{}' of type '{}'", threepid.address, threepid.medium); + /// println!( + /// "Found 3PID '{}' of type '{}'", + /// threepid.address, threepid.medium + /// ); /// } /// # anyhow::Ok(()) }); /// ``` @@ -412,20 +413,15 @@ impl Account { /// # let client = Client::new(homeserver).await?; /// # let account = client.account(); /// # let secret = ClientSecret::parse("secret")?; - /// let token_response = account.request_3pid_email_token( - /// &secret, - /// "john@matrix.org", - /// uint!(0), - /// ).await?; + /// let token_response = account + /// .request_3pid_email_token(&secret, "john@matrix.org", uint!(0)) + /// .await?; /// /// // Wait for the user to confirm that the token was submitted or prompt /// // the user for the token and send it to submit_url. /// - /// let uiaa_response = account.add_3pid( - /// &secret, - /// &token_response.sid, - /// None - /// ).await; + /// let uiaa_response = + /// account.add_3pid(&secret, &token_response.sid, None).await; /// /// // Proceed with UIAA. /// @@ -493,21 +489,15 @@ impl Account { /// # let client = Client::new(homeserver).await?; /// # let account = client.account(); /// # let secret = ClientSecret::parse("secret")?; - /// let token_response = account.request_3pid_msisdn_token( - /// &secret, - /// "FR", - /// "0123456789", - /// uint!(0), - /// ).await?; + /// let token_response = account + /// .request_3pid_msisdn_token(&secret, "FR", "0123456789", uint!(0)) + /// .await?; /// /// // Wait for the user to confirm that the token was submitted or prompt /// // the user for the token and send it to submit_url. /// - /// let uiaa_response = account.add_3pid( - /// &secret, - /// &token_response.sid, - /// None - /// ).await; + /// let uiaa_response = + /// account.add_3pid(&secret, &token_response.sid, None).await; /// /// // Proceed with UIAA. /// diff --git a/crates/matrix-sdk/src/client/builder.rs b/crates/matrix-sdk/src/client/builder.rs index fa9b17c1f..f8650a277 100644 --- a/crates/matrix-sdk/src/client/builder.rs +++ b/crates/matrix-sdk/src/client/builder.rs @@ -44,17 +44,18 @@ use crate::{ /// them. /// /// ``` -/// use matrix_sdk::Client; /// use std::sync::Arc; /// +/// use matrix_sdk::Client; +/// /// // setting up a custom http client /// let reqwest_builder = reqwest::ClientBuilder::new() /// .https_only(true) /// .no_proxy() /// .user_agent("MyApp/v3.0"); /// -/// let client_builder = Client::builder() -/// .http_client(Arc::new(reqwest_builder.build()?)); +/// let client_builder = +/// Client::builder().http_client(Arc::new(reqwest_builder.build()?)); /// # anyhow::Ok(()) /// ``` #[must_use] @@ -113,7 +114,35 @@ impl ClientBuilder { self } - /// Create a new `ClientBuilder` with the given [`StoreConfig`]. + /// Set up the store configuration for a sled store. + /// + /// This is a shorthand for + /// .[store_config](Self::store_config)([matrix_sdk_sled]::[make_store_config](matrix_sdk_sled::make_store_config)(path, passphrase)?). + #[cfg(feature = "sled")] + pub fn sled_store( + self, + path: impl AsRef, + passphrase: Option<&str>, + ) -> Result { + let config = matrix_sdk_sled::make_store_config(path, passphrase)?; + Ok(self.store_config(config)) + } + + /// Set up the store configuration for a IndexedDB store. + /// + /// This is a shorthand for + /// .[store_config](Self::store_config)([matrix_sdk_indexeddb]::[make_store_config](matrix_sdk_indexeddb::make_store_config)(path, passphrase).await?). + #[cfg(feature = "indexeddb")] + pub async fn indexeddb_store( + self, + name: impl Into, + passphrase: Option<&str>, + ) -> Result { + let config = matrix_sdk_indexeddb::make_store_config(name, passphrase).await?; + Ok(self.store_config(config)) + } + + /// Set up the store configuration. /// /// The easiest way to get a [`StoreConfig`] is to use the /// [`make_store_config`] method from the [`store`] module or directly from @@ -128,7 +157,7 @@ impl ClientBuilder { /// ``` /// # use matrix_sdk_base::store::MemoryStore; /// # let custom_state_store = MemoryStore::new(); - /// use matrix_sdk::{Client, config::StoreConfig}; + /// use matrix_sdk::{config::StoreConfig, Client}; /// /// let store_config = StoreConfig::new().state_store(custom_state_store); /// let client_builder = Client::builder().store_config(store_config); @@ -143,6 +172,11 @@ impl ClientBuilder { /// Set a custom implementation of a `StateStore`. /// /// The state store should be opened before being set. + #[deprecated = "\ + Use [`store_config`](#method.store_config), \ + [`sled_store`](#method.sled_store) or \ + [`indexeddb_store`](#method.indexeddb_store) instead + "] pub fn state_store(mut self, store: impl StateStore + 'static) -> Self { self.store_config = self.store_config.state_store(store); self @@ -151,6 +185,11 @@ impl ClientBuilder { /// Set a custom implementation of a `CryptoStore`. /// /// The crypto store should be opened before being set. + #[deprecated = "\ + Use [`store_config`](#method.store_config), \ + [`sled_store`](#method.sled_store) or \ + [`indexeddb_store`](#method.indexeddb_store) instead + "] #[cfg(feature = "e2e-encryption")] pub fn crypto_store( mut self, @@ -187,8 +226,7 @@ impl ClientBuilder { /// # futures::executor::block_on(async { /// use matrix_sdk::Client; /// - /// let client_config = Client::builder() - /// .proxy("http://localhost:8080"); + /// let client_config = Client::builder().proxy("http://localhost:8080"); /// /// # anyhow::Ok(()) /// # }); @@ -306,7 +344,7 @@ impl ClientBuilder { let well_known = http_client .send( discover_homeserver::Request::new(), - None, + Some(RequestConfig::short_retry()), homeserver, None, &[MatrixVersion::V1_0], @@ -342,6 +380,7 @@ impl ClientBuilder { typing_notice_times: Default::default(), event_handlers: Default::default(), event_handler_data: Default::default(), + event_handler_counter: Default::default(), notification_handlers: Default::default(), appservice_mode: self.appservice_mode, respect_login_well_known: self.respect_login_well_known, @@ -354,12 +393,12 @@ impl ClientBuilder { fn homeserver_from_name(server_name: &ServerName) -> String { #[cfg(not(test))] - return format!("https://{}", server_name); + return format!("https://{server_name}"); // Wiremock only knows how to test http endpoints: // https://github.com/LukeMathWalker/wiremock-rs/issues/58 #[cfg(test)] - return format!("http://{}", server_name); + return format!("http://{server_name}"); } #[derive(Clone, Debug)] diff --git a/crates/matrix-sdk/src/client/login_builder.rs b/crates/matrix-sdk/src/client/login_builder.rs index e01b8d90d..91c042b47 100644 --- a/crates/matrix-sdk/src/client/login_builder.rs +++ b/crates/matrix-sdk/src/client/login_builder.rs @@ -224,7 +224,7 @@ where const SSO_SERVER_BIND_TRIES: u8 = 10; let homeserver = self.client.homeserver().await; - info!("Logging in to {}", homeserver); + info!(%homeserver, "Logging in"); let (signal_tx, signal_rx) = oneshot::channel(); let (data_tx, data_rx) = oneshot::channel(); diff --git a/crates/matrix-sdk/src/client/mod.rs b/crates/matrix-sdk/src/client/mod.rs index 61840003c..fa8ccc07f 100644 --- a/crates/matrix-sdk/src/client/mod.rs +++ b/crates/matrix-sdk/src/client/mod.rs @@ -1,5 +1,6 @@ // Copyright 2020 Damir Jelić // Copyright 2020 The Matrix.org Foundation C.I.C. +// Copyright 2022 Famedly GmbH // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,7 +20,10 @@ use std::{ future::Future, io::Read, pin::Pin, - sync::{Arc, RwLock as StdRwLock}, + sync::{ + atomic::{AtomicU64, Ordering::SeqCst}, + Arc, RwLock as StdRwLock, + }, }; use anymap2::any::CloneAnySendSync; @@ -80,7 +84,10 @@ use crate::{ attachment::{AttachmentInfo, Thumbnail}, config::RequestConfig, error::{HttpError, HttpResult}, - event_handler::{EventHandler, EventHandlerData, EventHandlerResult, EventKind, SyncEvent}, + event_handler::{ + EventHandler, EventHandlerData, EventHandlerHandle, EventHandlerResult, + EventHandlerWrapper, EventKind, SyncEvent, + }, http_client::HttpClient, room, Account, Error, Result, }; @@ -101,8 +108,8 @@ const DEFAULT_UPLOAD_SPEED: u64 = 125_000; const MIN_UPLOAD_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 5); type EventHandlerFut = Pin + Send>>; -type EventHandlerFn = Box) -> EventHandlerFut + Send + Sync>; -type EventHandlerMap = BTreeMap<(EventKind, &'static str), Vec>; +pub(crate) type EventHandlerFn = dyn Fn(EventHandlerData<'_>) -> EventHandlerFut + Send + Sync; +type EventHandlerMap = BTreeMap<(EventKind, &'static str), Vec>; type NotificationHandlerFut = EventHandlerFut; type NotificationHandlerFn = @@ -152,10 +159,13 @@ pub(crate) struct ClientInner { pub(crate) key_claim_lock: Mutex<()>, pub(crate) members_request_locks: DashMap>>, pub(crate) typing_notice_times: DashMap, - /// Event handlers. See `register_event_handler`. + /// Event handlers. See `add_event_handler`. event_handlers: RwLock, - /// Custom event handler context. See `register_event_handler_context`. + /// Custom event handler context. See `add_event_handler_context`. event_handler_data: StdRwLock, + /// When registering a event handler, the current value is used for the + /// handlers identification, then the counter is incremented. + event_handler_counter: AtomicU64, /// Notification handlers. See `register_notification_handler`. notification_handlers: RwLock>, /// Whether the client should operate in application service style mode. @@ -346,9 +356,11 @@ impl Client { /// "context" arguments: They have to implement [`EventHandlerContext`]. /// This trait is named that way because most of the types implementing it /// give additional context about an event: The room it was in, its raw form - /// and other similar things. As an exception to this, - /// [`Client`] also implements the `EventHandlerContext` trait - /// so you don't have to clone your client into the event handler manually. + /// and other similar things. As two exceptions to this, + /// [`Client`] and [`EventHandlerHandle`] also implement the + /// `EventHandlerContext` trait so you don't have to clone your client + /// into the event handler manually and a handler can decide to remove + /// itself. /// /// Some context arguments are not universally applicable. A context /// argument that isn't available for the given event type will result in @@ -388,14 +400,16 @@ impl Client { /// # .build() /// # .await /// # .unwrap(); + /// /// client - /// .register_event_handler( + /// .add_event_handler( /// |ev: SyncRoomMessageEvent, room: Room, client: Client| async move { /// // Common usage: Room event plus room and client. /// }, /// ) - /// .await - /// .register_event_handler( + /// .await; + /// client + /// .add_event_handler( /// |ev: SyncRoomMessageEvent, room: Room, encryption_info: Option| { /// async move { /// // An `Option` parameter lets you distinguish between @@ -403,8 +417,9 @@ impl Client { /// } /// }, /// ) - /// .await - /// .register_event_handler(|ev: SyncRoomTopicEvent| async move { + /// .await; + /// client + /// .add_event_handler(|ev: SyncRoomTopicEvent| async move { /// // You can omit any or all arguments after the first. /// }) /// .await; @@ -419,59 +434,82 @@ impl Client { /// expires_at: MilliSecondsSinceUnixEpoch, /// } /// - /// client.register_event_handler(|ev: SyncTokenEvent, room: Room| async move { + /// client.add_event_handler(|ev: SyncTokenEvent, room: Room| async move { /// todo!("Display the token"); /// }).await; /// /// // Adding your custom data to the handler can be done as well /// let data = "MyCustomIdentifier".to_owned(); /// - /// client.register_event_handler({ + /// client.add_event_handler({ /// let data = data.clone(); /// move |ev: SyncRoomMessageEvent | { /// let data = data.clone(); /// async move { - /// println!("Calling the handler with identifier {}", data); + /// println!("Calling the handler with identifier {data}"); /// } /// } /// }).await; /// # }); /// ``` + pub async fn add_event_handler(&self, handler: H) -> EventHandlerHandle + where + Ev: SyncEvent + DeserializeOwned + Send + 'static, + H: EventHandler, + ::Output: EventHandlerResult, + { + let key = (Ev::KIND, Ev::TYPE); + + let handler_fn: Box = Box::new(move |data| { + let maybe_fut = serde_json::from_str(data.raw.get()) + .map(|ev| handler.clone().handle_event(ev, data)); + + Box::pin(async move { + match maybe_fut { + Ok(Some(fut)) => { + fut.await.print_error(Ev::TYPE); + } + Ok(None) => { + error!( + event_type = Ev::TYPE, event_kind = ?Ev::KIND, + "Event handler has an invalid context argument", + ); + } + Err(e) => { + warn!( + event_type = Ev::TYPE, event_kind = ?Ev::KIND, + "Failed to deserialize event, skipping event handler.\n + Deserialization error: {e}", + ); + } + } + }) + }); + + let handler_id = self.inner.event_handler_counter.fetch_add(1, SeqCst); + + let handle = EventHandlerHandle { handler_id, ev_id: key }; + + self.inner + .event_handlers + .write() + .await + .entry(key) + .or_default() + .push(EventHandlerWrapper { handler_fn, handle }); + + handle + } + + #[allow(missing_docs)] + #[deprecated = "Use [`Client::add_event_handler`](#method.add_event_handler) instead"] pub async fn register_event_handler(&self, handler: H) -> &Self where Ev: SyncEvent + DeserializeOwned + Send + 'static, H: EventHandler, ::Output: EventHandlerResult, { - let event_type = H::ID.1; - self.inner.event_handlers.write().await.entry(H::ID).or_default().push(Box::new( - move |data| { - let maybe_fut = serde_json::from_str(data.raw.get()) - .map(|ev| handler.clone().handle_event(ev, data)); - - Box::pin(async move { - match maybe_fut { - Ok(Some(fut)) => { - fut.await.print_error(event_type); - } - Ok(None) => { - error!( - "Event handler for {} has an invalid context argument", - event_type - ); - } - Err(e) => { - warn!( - "Failed to deserialize `{}` event, skipping event handler.\n\ - Deserialization error: {}", - event_type, e, - ); - } - } - }) - }, - )); - + self.add_event_handler(handler).await; self } @@ -479,6 +517,63 @@ impl Client { self.inner.event_handlers.read().await } + /// Remove the event handler associated with the handle. + /// + /// Note that handlers that remove themselves will still execute + /// with events received in the same sync cycle. + /// + /// # Arguments + /// + /// `handle` - The [`EventHandlerHandle`] that is returned when + /// registering the event handler with [`Client::add_event_handler`]. + /// + /// # Examples + /// + /// ``` + /// # use futures::executor::block_on; + /// # use url::Url; + /// # use tokio::sync::mpsc; + /// # + /// # let homeserver = Url::parse("http://localhost:8080").unwrap(); + /// # + /// use matrix_sdk::{ + /// event_handler::EventHandlerHandle, + /// ruma::events::room::member::SyncRoomMemberEvent, Client, + /// }; + /// # + /// # block_on(async { + /// # let client = matrix_sdk::Client::builder() + /// # .homeserver_url(homeserver) + /// # .server_versions([ruma::api::MatrixVersion::V1_0]) + /// # .build() + /// # .await + /// # .unwrap(); + /// + /// client + /// .add_event_handler( + /// |ev: SyncRoomMemberEvent, + /// client: Client, + /// handle: EventHandlerHandle| async move { + /// // Common usage: Check arriving Event is the expected one + /// println!("Expected RoomMemberEvent received!"); + /// client.remove_event_handler(handle); + /// }, + /// ) + /// .await; + /// # }); + /// ``` + pub async fn remove_event_handler(&self, handle: EventHandlerHandle) { + let mut event_handlers = self.inner.event_handlers.write().await; + + if let Some(v) = event_handlers.get_mut(&handle.ev_id) { + v.retain(|e| e.handle.handler_id != handle.handler_id); + + if v.is_empty() { + event_handlers.remove(&handle.ev_id); + } + } + } + /// Add an arbitrary value for use as event handler context. /// /// The value can be obtained in an event handler by adding an argument of @@ -492,8 +587,7 @@ impl Client { /// ``` /// # use futures::executor::block_on; /// use matrix_sdk::{ - /// event_handler::Ctx, - /// room::Room, + /// event_handler::Ctx, room::Room, /// ruma::events::room::message::SyncRoomMessageEvent, /// }; /// # #[derive(Clone)] @@ -511,21 +605,32 @@ impl Client { /// // Handle used to send messages to the UI part of the app /// let my_gui_handle: SomeType = obtain_gui_handle(); /// + /// client.add_event_handler_context(my_gui_handle.clone()); /// client - /// .register_event_handler_context(my_gui_handle.clone()) - /// .register_event_handler( - /// |ev: SyncRoomMessageEvent, room: Room, gui_handle: Ctx| async move { + /// .add_event_handler( + /// |ev: SyncRoomMessageEvent, + /// room: Room, + /// gui_handle: Ctx| async move { /// // gui_handle.send(DisplayMessage { message: ev }); /// }, /// ) /// .await; /// # }); /// ``` - pub fn register_event_handler_context(&self, ctx: T) -> &Self + pub fn add_event_handler_context(&self, ctx: T) where T: Clone + Send + Sync + 'static, { self.inner.event_handler_data.write().unwrap().insert(ctx); + } + + #[allow(missing_docs)] + #[deprecated = "Use [`Client::add_event_handler_context`](#method.add_event_handler_context) instead"] + pub fn register_event_handler_context(&self, ctx: T) -> &Self + where + T: Clone + Send + Sync + 'static, + { + self.add_event_handler_context(ctx); self } @@ -539,7 +644,7 @@ impl Client { /// Register a handler for a notification. /// - /// Similar to [`Client::register_event_handler`], but only allows functions + /// Similar to [`Client::add_event_handler`], but only allows functions /// or closures with exactly the three arguments [`Notification`], /// [`room::Room`], [`Client`] for now. pub async fn register_notification_handler(&self, handler: H) -> &Self @@ -725,7 +830,6 @@ impl Client { /// # Example /// /// ```no_run - /// # use std::convert::TryFrom; /// # use futures::executor::block_on; /// # use url::Url; /// # let homeserver = Url::parse("http://example.com").unwrap(); @@ -742,8 +846,8 @@ impl Client { /// .await?; /// /// println!( - /// "Logged in as {}, got device_id {} and access_token {}", - /// user, response.device_id, response.access_token, + /// "Logged in as {user}, got device_id {} and access_token {}", + /// response.device_id, response.access_token, /// ); /// # anyhow::Ok(()) }); /// ``` @@ -792,7 +896,6 @@ impl Client { /// # Example /// /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::Client; /// # use matrix_sdk::ruma::{assign, DeviceId}; /// # use futures::executor::block_on; @@ -875,8 +978,10 @@ impl Client { /// .await /// .unwrap(); /// - /// println!("Logged in as {}, got device_id {} and access_token {}", - /// response.user_id, response.device_id, response.access_token); + /// println!( + /// "Logged in as {}, got device_id {} and access_token {}", + /// response.user_id, response.device_id, response.access_token + /// ); /// # }) /// ``` /// @@ -1006,7 +1111,10 @@ impl Client { /// # Examples /// /// ```no_run - /// use matrix_sdk::{Client, Session, ruma::{device_id, user_id}}; + /// use matrix_sdk::{ + /// ruma::{device_id, user_id}, + /// Client, Session, + /// }; /// # use url::Url; /// # use futures::executor::block_on; /// # block_on(async { @@ -1036,10 +1144,8 @@ impl Client { /// let homeserver = Url::parse("http://example.com")?; /// let client = Client::new(homeserver).await?; /// - /// let session: Session = client - /// .login("example", "my-password", None, None) - /// .await? - /// .into(); + /// let session: Session = + /// client.login("example", "my-password", None, None).await?.into(); /// /// // Persist the `Session` so it can later be used to restore the login. /// client.restore_login(session).await?; @@ -1061,7 +1167,6 @@ impl Client { /// # Examples /// /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::Client; /// # use matrix_sdk::ruma::{ /// # api::client::{ @@ -1092,7 +1197,7 @@ impl Client { registration: impl Into>, ) -> HttpResult { let homeserver = self.homeserver().await; - info!("Registering to {}", homeserver); + info!("Registering to {homeserver}"); let config = if self.inner.appservice_mode { Some(RequestConfig::short_retry().force_auth()) @@ -1224,7 +1329,6 @@ impl Client { /// # Examples /// ```no_run /// use matrix_sdk::Client; - /// # use std::convert::TryInto; /// # use url::Url; /// # let homeserver = Url::parse("http://example.com").unwrap(); /// # let limit = Some(10); @@ -1300,15 +1404,13 @@ impl Client { /// # Examples /// /// ```no_run - /// # use std::convert::TryFrom; /// # use url::Url; /// # use matrix_sdk::Client; /// # use futures::executor::block_on; /// # block_on(async { /// # let homeserver = Url::parse("http://example.com")?; /// use matrix_sdk::ruma::{ - /// api::client::directory::get_public_rooms_filtered, - /// directory::Filter, + /// api::client::directory::get_public_rooms_filtered, directory::Filter, /// }; /// # let mut client = Client::new(homeserver).await?; /// @@ -1356,9 +1458,7 @@ impl Client { /// let path = PathBuf::from("/home/example/my-cat.jpg"); /// let mut image = File::open(path)?; /// - /// let response = client - /// .upload(&mime::IMAGE_JPEG, &mut image) - /// .await?; + /// let response = client.upload(&mime::IMAGE_JPEG, &mut image).await?; /// /// println!("Cat URI: {}", response.content_uri); /// # anyhow::Ok(()) }); @@ -1415,7 +1515,6 @@ impl Client { /// # use matrix_sdk::{Client, config::SyncSettings}; /// # use futures::executor::block_on; /// # use url::Url; - /// # use std::convert::TryFrom; /// # block_on(async { /// # let homeserver = Url::parse("http://localhost:8080")?; /// # let mut client = Client::new(homeserver).await?; @@ -1498,7 +1597,6 @@ impl Client { /// # use matrix_sdk::{Client, config::SyncSettings}; /// # use futures::executor::block_on; /// # use url::Url; - /// # use std::convert::TryFrom; /// # block_on(async { /// # let homeserver = Url::parse("http://localhost:8080")?; /// # let mut client = Client::new(homeserver).await?; @@ -1546,7 +1644,7 @@ impl Client { /// # use futures::executor::block_on; /// # use serde_json::json; /// # use url::Url; - /// # use std::{collections::BTreeMap, convert::TryFrom}; + /// # use std::collections::BTreeMap; /// # block_on(async { /// # let homeserver = Url::parse("http://localhost:8080")?; /// # let mut client = Client::new(homeserver).await?; @@ -1635,8 +1733,8 @@ impl Client { /// # let username = ""; /// # let password = ""; /// use matrix_sdk::{ - /// Client, config::SyncSettings, - /// ruma::events::room::message::OriginalSyncRoomMessageEvent, + /// config::SyncSettings, + /// ruma::events::room::message::OriginalSyncRoomMessageEvent, Client, /// }; /// /// let client = Client::new(homeserver).await?; @@ -1647,9 +1745,11 @@ impl Client { /// /// // Register our handler so we start responding once we receive a new /// // event. - /// client.register_event_handler(|ev: OriginalSyncRoomMessageEvent| async move { - /// println!("Received event {}: {:?}", ev.sender, ev.content); - /// }).await; + /// client + /// .add_event_handler(|ev: OriginalSyncRoomMessageEvent| async move { + /// println!("Received event {}: {:?}", ev.sender, ev.content); + /// }) + /// .await; /// /// // Now keep on syncing forever. `sync()` will use the stored sync token /// // from our `sync_once()` call automatically. @@ -1718,7 +1818,7 @@ impl Client { /// /// This method will internally call [`Client::sync_once`] in a loop. /// - /// This method can be used with the [`Client::register_event_handler`] + /// This method can be used with the [`Client::add_event_handler`] /// method to react to individual events. If you instead wish to handle /// events in a bulk manner the [`Client::sync_with_callback`] and /// [`Client::sync_stream`] methods can be used instead. Those two methods @@ -1740,8 +1840,8 @@ impl Client { /// # let username = ""; /// # let password = ""; /// use matrix_sdk::{ - /// Client, config::SyncSettings, - /// ruma::events::room::message::OriginalSyncRoomMessageEvent, + /// config::SyncSettings, + /// ruma::events::room::message::OriginalSyncRoomMessageEvent, Client, /// }; /// /// let client = Client::new(homeserver).await?; @@ -1749,9 +1849,11 @@ impl Client { /// /// // Register our handler so we start responding once we receive a new /// // event. - /// client.register_event_handler(|ev: OriginalSyncRoomMessageEvent| async move { - /// println!("Received event {}: {:?}", ev.sender, ev.content); - /// }).await; + /// client + /// .add_event_handler(|ev: OriginalSyncRoomMessageEvent| async move { + /// println!("Received event {}: {:?}", ev.sender, ev.content); + /// }) + /// .await; /// /// // Now keep on syncing forever. `sync()` will use the latest sync token /// // automatically. @@ -1838,8 +1940,6 @@ impl Client { if callback(r).await == LoopCtrl::Break { return; } - } else { - continue; } Client::delay_sync(&mut last_sync_time).await @@ -1868,12 +1968,13 @@ impl Client { /// # let username = ""; /// # let password = ""; /// use futures::StreamExt; - /// use matrix_sdk::{Client, config::SyncSettings}; + /// use matrix_sdk::{config::SyncSettings, Client}; /// /// let client = Client::new(homeserver).await?; /// client.login(&username, &password, None, None).await?; /// - /// let mut sync_stream = Box::pin(client.sync_stream(SyncSettings::default()).await); + /// let mut sync_stream = + /// Box::pin(client.sync_stream(SyncSettings::default()).await); /// /// while let Some(Ok(response)) = sync_stream.next().await { /// for room in response.rooms.join.values() { diff --git a/crates/matrix-sdk/src/encryption/identities/devices.rs b/crates/matrix-sdk/src/encryption/identities/devices.rs index e03427b5b..a090bc622 100644 --- a/crates/matrix-sdk/src/encryption/identities/devices.rs +++ b/crates/matrix-sdk/src/encryption/identities/devices.rs @@ -81,7 +81,6 @@ impl Device { /// # Examples /// /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::{Client, ruma::{device_id, user_id}}; /// # use url::Url; /// # use futures::executor::block_on; @@ -89,7 +88,8 @@ impl Device { /// # let alice = user_id!("@alice:example.org"); /// # let homeserver = Url::parse("http://example.com")?; /// # let client = Client::new(homeserver).await?; - /// let device = client.encryption().get_device(alice, device_id!("DEVICEID")).await?; + /// let device = + /// client.encryption().get_device(alice, device_id!("DEVICEID")).await?; /// /// if let Some(device) = device { /// let verification = device.request_verification().await?; @@ -123,7 +123,6 @@ impl Device { /// # Examples /// /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::{ /// # Client, /// # ruma::{ @@ -137,14 +136,16 @@ impl Device { /// # let alice = user_id!("@alice:example.org"); /// # let homeserver = Url::parse("http://example.com")?; /// # let client = Client::new(homeserver).await?; - /// let device = client.encryption().get_device(alice, device_id!("DEVICEID")).await?; + /// let device = + /// client.encryption().get_device(alice, device_id!("DEVICEID")).await?; /// /// // We don't want to support showing a QR code, we only support SAS /// // verification /// let methods = vec![VerificationMethod::SasV1]; /// /// if let Some(device) = device { - /// let verification = device.request_verification_with_methods(methods).await?; + /// let verification = + /// device.request_verification_with_methods(methods).await?; /// } /// # anyhow::Ok(()) }); /// ``` @@ -171,7 +172,6 @@ impl Device { /// # Examples /// /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::{Client, ruma::{device_id, user_id}}; /// # use url::Url; /// # use futures::executor::block_on; @@ -179,7 +179,8 @@ impl Device { /// # let alice = user_id!("@alice:example.org"); /// # let homeserver = Url::parse("http://example.com")?; /// # let client = Client::new(homeserver).await?; - /// let device = client.encryption().get_device(alice, device_id!("DEVICEID")).await?; + /// let device = + /// client.encryption().get_device(alice, device_id!("DEVICEID")).await?; /// /// if let Some(device) = device { /// let verification = device.start_verification().await?; @@ -229,7 +230,6 @@ impl Device { /// # Examples /// /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::{ /// # Client, /// # ruma::{ @@ -243,7 +243,8 @@ impl Device { /// # let alice = user_id!("@alice:example.org"); /// # let homeserver = Url::parse("http://example.com")?; /// # let client = Client::new(homeserver).await?; - /// let device = client.encryption().get_device(alice, device_id!("DEVICEID")).await?; + /// let device = + /// client.encryption().get_device(alice, device_id!("DEVICEID")).await?; /// /// if let Some(device) = device { /// device.verify().await?; @@ -344,7 +345,6 @@ impl Device { /// Let's check if a device is verified: /// /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::{ /// # Client, /// # ruma::{ @@ -358,18 +358,21 @@ impl Device { /// # let alice = user_id!("@alice:example.org"); /// # let homeserver = Url::parse("http://example.com")?; /// # let client = Client::new(homeserver).await?; - /// let device = client.encryption().get_device(alice, device_id!("DEVICEID")).await?; + /// let device = + /// client.encryption().get_device(alice, device_id!("DEVICEID")).await?; /// /// if let Some(device) = device { /// if device.verified() { /// println!( /// "Device {} of user {} is verified", - /// device.device_id().as_str(), device.user_id().as_str() + /// device.device_id().as_str(), + /// device.user_id().as_str() /// ); /// } else { /// println!( /// "Device {} of user {} is not verified", - /// device.device_id().as_str(), device.user_id().as_str() + /// device.device_id().as_str(), + /// device.user_id().as_str() /// ); /// } /// } diff --git a/crates/matrix-sdk/src/encryption/identities/mod.rs b/crates/matrix-sdk/src/encryption/identities/mod.rs index 220c3c3d3..bb154f883 100644 --- a/crates/matrix-sdk/src/encryption/identities/mod.rs +++ b/crates/matrix-sdk/src/encryption/identities/mod.rs @@ -35,7 +35,6 @@ //! Verifying a device is pretty straightforward: //! //! ```no_run -//! # use std::convert::TryFrom; //! # use matrix_sdk::{Client, ruma::{device_id, user_id}}; //! # use url::Url; //! # use futures::executor::block_on; @@ -43,7 +42,8 @@ //! # let homeserver = Url::parse("http://example.com").unwrap(); //! # block_on(async { //! # let client = Client::new(homeserver).await.unwrap(); -//! let device = client.encryption().get_device(alice, device_id!("DEVICEID")).await?; +//! let device = +//! client.encryption().get_device(alice, device_id!("DEVICEID")).await?; //! //! if let Some(device) = device { //! // Let's request the device to be verified. @@ -61,7 +61,6 @@ //! Verifying a user identity works largely the same: //! //! ```no_run -//! # use std::convert::TryFrom; //! # use matrix_sdk::{Client, ruma::user_id}; //! # use url::Url; //! # use futures::executor::block_on; diff --git a/crates/matrix-sdk/src/encryption/identities/users.rs b/crates/matrix-sdk/src/encryption/identities/users.rs index ae85e5054..f3e8eb11f 100644 --- a/crates/matrix-sdk/src/encryption/identities/users.rs +++ b/crates/matrix-sdk/src/encryption/identities/users.rs @@ -91,7 +91,6 @@ impl UserIdentity { /// # Examples /// /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::{Client, ruma::user_id}; /// # use url::Url; /// # let alice = user_id!("@alice:example.org"); @@ -143,7 +142,6 @@ impl UserIdentity { /// # Examples /// /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::{Client, ruma::user_id}; /// # use url::Url; /// # let alice = user_id!("@alice:example.org"); @@ -195,7 +193,6 @@ impl UserIdentity { /// # Examples /// /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::{ /// # Client, /// # ruma::{ @@ -216,7 +213,8 @@ impl UserIdentity { /// let methods = vec![VerificationMethod::SasV1]; /// /// if let Some(user) = user { - /// let verification = user.request_verification_with_methods(methods).await?; + /// let verification = + /// user.request_verification_with_methods(methods).await?; /// } /// # anyhow::Ok(()) }); /// ``` @@ -274,7 +272,6 @@ impl UserIdentity { /// # Examples /// /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::{ /// # Client, /// # ruma::{ @@ -318,7 +315,6 @@ impl UserIdentity { /// # Examples /// /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::{ /// # Client, /// # ruma::{ @@ -358,7 +354,6 @@ impl UserIdentity { /// # Examples /// /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::{ /// # Client, /// # ruma::{ @@ -379,14 +374,19 @@ impl UserIdentity { /// // matches what we expect, for this we fetch the first public key we /// // can find, there's currently only a single key allowed so this is /// // fine. - /// if user.master_key().get_first_key().map(|k| k.to_base64()) == Some("MyMasterKey".to_string()) { + /// if user.master_key().get_first_key().map(|k| k.to_base64()) + /// == Some("MyMasterKey".to_string()) + /// { /// println!( /// "Master keys match for user {}, marking the user as verified", /// user.user_id().as_str(), /// ); /// user.verify().await?; /// } else { - /// println!("Master keys don't match for user {}", user.user_id().as_str()); + /// println!( + /// "Master keys don't match for user {}", + /// user.user_id().as_str() + /// ); /// } /// } /// # anyhow::Ok(()) }); diff --git a/crates/matrix-sdk/src/encryption/mod.rs b/crates/matrix-sdk/src/encryption/mod.rs index fd597437b..005cd5864 100644 --- a/crates/matrix-sdk/src/encryption/mod.rs +++ b/crates/matrix-sdk/src/encryption/mod.rs @@ -555,7 +555,6 @@ impl Encryption { /// # Example /// /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::{Client, ruma::{device_id, user_id}}; /// # use url::Url; /// # use futures::executor::block_on; @@ -563,15 +562,14 @@ impl Encryption { /// # let alice = user_id!("@alice:example.org"); /// # let homeserver = Url::parse("http://example.com")?; /// # let client = Client::new(homeserver).await?; - /// if let Some(device) = client - /// .encryption() - /// .get_device(alice, device_id!("DEVICEID")) - /// .await? { - /// println!("{:?}", device.verified()); + /// if let Some(device) = + /// client.encryption().get_device(alice, device_id!("DEVICEID")).await? + /// { + /// println!("{:?}", device.verified()); /// - /// if !device.verified() { - /// let verification = device.request_verification().await?; - /// } + /// if !device.verified() { + /// let verification = device.request_verification().await?; + /// } /// } /// # anyhow::Ok(()) }); /// ``` @@ -600,7 +598,6 @@ impl Encryption { /// # Example /// /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::{Client, ruma::user_id}; /// # use url::Url; /// # use futures::executor::block_on; @@ -640,7 +637,6 @@ impl Encryption { /// # Example /// /// ```no_run - /// # use std::convert::TryFrom; /// # use matrix_sdk::{Client, ruma::user_id}; /// # use url::Url; /// # use futures::executor::block_on; @@ -691,7 +687,7 @@ impl Encryption { /// /// # Examples /// ```no_run - /// # use std::{convert::TryFrom, collections::BTreeMap}; + /// # use std::collections::BTreeMap; /// # use matrix_sdk::{ruma::api::client::uiaa, Client}; /// # use url::Url; /// # use futures::executor::block_on; @@ -841,7 +837,8 @@ impl Encryption { /// # let homeserver = Url::parse("http://localhost:8080")?; /// # let mut client = Client::new(homeserver).await?; /// let path = PathBuf::from("/home/example/e2e-keys.txt"); - /// let result = client.encryption().import_keys(path, "secret-passphrase").await?; + /// let result = + /// client.encryption().import_keys(path, "secret-passphrase").await?; /// /// println!( /// "Imported {} room keys out of {}", diff --git a/crates/matrix-sdk/src/encryption/verification/sas.rs b/crates/matrix-sdk/src/encryption/verification/sas.rs index c9dc3928a..2991e6c69 100644 --- a/crates/matrix-sdk/src/encryption/verification/sas.rs +++ b/crates/matrix-sdk/src/encryption/verification/sas.rs @@ -44,7 +44,7 @@ impl SasVerification { /// # use url::Url; /// # use ruma::user_id; /// use matrix_sdk::{ - /// encryption::verification::{SasVerification, AcceptSettings}, + /// encryption::verification::{AcceptSettings, SasVerification}, /// ruma::events::key::verification::ShortAuthenticationString, /// }; /// @@ -60,9 +60,9 @@ impl SasVerification { /// .and_then(|v| v.sas()); /// /// if let Some(sas) = sas { - /// let only_decimal = AcceptSettings::with_allowed_methods( - /// vec![ShortAuthenticationString::Decimal] - /// ); + /// let only_decimal = AcceptSettings::with_allowed_methods(vec![ + /// ShortAuthenticationString::Decimal, + /// ]); /// /// sas.accept_with_settings(only_decimal).await?; /// } @@ -119,7 +119,7 @@ impl SasVerification { /// # use url::Url; /// # use ruma::user_id; /// use matrix_sdk::{ - /// encryption::verification::{SasVerification, AcceptSettings}, + /// encryption::verification::{AcceptSettings, SasVerification}, /// ruma::events::key::verification::ShortAuthenticationString, /// }; /// @@ -147,7 +147,7 @@ impl SasVerification { /// .collect::>() /// .join(""); /// - /// println!("Do the emojis match?\n{}\n{}", emoji_string, description); + /// println!("Do the emojis match?\n{emoji_string}\n{description}"); /// } /// # anyhow::Ok(()) }); /// ``` diff --git a/crates/matrix-sdk/src/event_handler.rs b/crates/matrix-sdk/src/event_handler.rs index fec80128d..bb38e3231 100644 --- a/crates/matrix-sdk/src/event_handler.rs +++ b/crates/matrix-sdk/src/event_handler.rs @@ -1,4 +1,5 @@ // Copyright 2021 Jonas Platte +// Copyright 2022 Famedly GmbH // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,11 +14,11 @@ // limitations under the License. //! Types and traits related for event handlers. For usage, see -//! [`Client::register_event_handler`]. +//! [`Client::add_event_handler`]. //! //! ### How it works //! -//! The `register_event_handler` method registers event handlers of different +//! The `add_event_handler` method registers event handlers of different //! signatures by actually storing boxed closures that all have the same //! signature of `async (EventHandlerData) -> ()` where `EventHandlerData` is a //! private type that contains all of the data an event handler *might* need. @@ -38,8 +39,9 @@ use matrix_sdk_base::deserialized_responses::{EncryptionInfo, SyncRoomEvent}; use ruma::{events::AnySyncStateEvent, serde::Raw}; use serde::Deserialize; use serde_json::value::RawValue as RawJsonValue; +use tracing::error; -use crate::{room, Client}; +use crate::{client::EventHandlerFn, room, Client}; #[doc(hidden)] #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] @@ -80,7 +82,28 @@ impl EventKind { /// A statically-known event kind/type that can be retrieved from an event sync. pub trait SyncEvent { #[doc(hidden)] - const ID: (EventKind, &'static str); + const KIND: EventKind; + #[doc(hidden)] + const TYPE: &'static str; +} + +pub(crate) struct EventHandlerWrapper { + pub handler_fn: Box, + pub handle: EventHandlerHandle, +} + +/// Handle to remove a registered event handler by passing it to +/// [`Client::remove_event_handler`]. +#[derive(Clone, Copy, Debug)] +pub struct EventHandlerHandle { + pub(crate) ev_id: (EventKind, &'static str), + pub(crate) handler_id: u64, +} + +impl EventHandlerContext for EventHandlerHandle { + fn from_data(data: &EventHandlerData<'_>) -> Option { + Some(data.handle) + } } /// Interface for event handlers. @@ -107,7 +130,7 @@ pub trait SyncEvent { /// `Ev` and `Ctx` are generic parameters rather than associated types because /// the argument list is a generic parameter for the `Fn` traits too, so a /// single type could implement `Fn` multiple times with different argument -/// lists¹. Luckily, when calling [`Client::register_event_handler`] with a +/// lists¹. Luckily, when calling [`Client::add_event_handler`] with a /// closure argument the trait solver takes into account that only a single one /// of the implementations applies (even though this could theoretically change /// through a dependency upgrade) and uses that rather than raising an ambiguity @@ -121,11 +144,6 @@ pub trait EventHandler: Clone + Send + Sync + 'static { #[doc(hidden)] type Future: Future + Send + 'static; - /// The event type being handled, for example a message event of type - /// `m.room.message`. - #[doc(hidden)] - const ID: (EventKind, &'static str); - /// Create a future for handling the given event. /// /// `data` provides additional data about the event, for example the room it @@ -143,6 +161,7 @@ pub struct EventHandlerData<'a> { pub room: Option, pub raw: &'a RawJsonValue, pub encryption_info: Option<&'a EncryptionInfo>, + pub handle: EventHandlerHandle, } /// Context for an event handler. @@ -175,7 +194,7 @@ impl EventHandlerContext for room::Room { /// The raw JSON form of an event. /// /// Used as a context argument for event handlers (see -/// [`Client::register_event_handler`]). +/// [`Client::add_event_handler`]). // FIXME: This could be made to not own the raw JSON value with some changes to // the traits above, but only with GATs. #[derive(Clone, Debug)] @@ -202,7 +221,7 @@ impl EventHandlerContext for Option { } /// A custom value registered with -/// [`.register_event_handler_context`][Client::register_event_handler_context]. +/// [`.add_event_handler_context`][Client::add_event_handler_context]. #[derive(Debug)] pub struct Ctx(pub T); @@ -237,14 +256,14 @@ impl EventHandlerResult for Result<(), E match self { #[cfg(feature = "anyhow")] Err(e) if TypeId::of::() == TypeId::of::() => { - tracing::error!("Event handler for `{}` failed: {:?}", event_type, e); + error!("Event handler for `{event_type}` failed: {e:?}"); } #[cfg(feature = "eyre")] Err(e) if TypeId::of::() == TypeId::of::() => { - tracing::error!("Event handler for `{}` failed: {:?}", event_type, e); + error!("Event handler for `{event_type}` failed: {e:?}"); } Err(e) => { - tracing::error!("Event handler for `{}` failed: {}", event_type, e); + error!("Event handler for `{event_type}` failed: {e}"); } Ok(_) => {} } @@ -383,14 +402,15 @@ impl Client { .get(&event_handler_id) .into_iter() .flatten() - .map(|handler| { + .map(|handler_wrapper| { let data = EventHandlerData { client: self.clone(), room: room.clone(), raw: raw_event.json(), encryption_info, + handle: handler_wrapper.handle, }; - (handler)(data) + (handler_wrapper.handler_fn)(data) }) .collect(); @@ -416,7 +436,6 @@ macro_rules! impl_event_handler { $($ty: EventHandlerContext),* { type Future = Fut; - const ID: (EventKind, &'static str) = Ev::ID; fn handle_event(&self, ev: Ev, _d: EventHandlerData<'_>) -> Option { Some((self)(ev, $($ty::from_data(&_d)?),*)) @@ -450,21 +469,24 @@ mod static_events { where C: StaticEventContent + GlobalAccountDataEventContent, { - const ID: (EventKind, &'static str) = (EventKind::GlobalAccountData, C::TYPE); + const KIND: EventKind = EventKind::GlobalAccountData; + const TYPE: &'static str = C::TYPE; } impl SyncEvent for events::RoomAccountDataEvent where C: StaticEventContent + RoomAccountDataEventContent, { - const ID: (EventKind, &'static str) = (EventKind::RoomAccountData, C::TYPE); + const KIND: EventKind = EventKind::RoomAccountData; + const TYPE: &'static str = C::TYPE; } impl SyncEvent for events::SyncEphemeralRoomEvent where C: StaticEventContent + EphemeralRoomEventContent, { - const ID: (EventKind, &'static str) = (EventKind::EphemeralRoomData, C::TYPE); + const KIND: EventKind = EventKind::EphemeralRoomData; + const TYPE: &'static str = C::TYPE; } impl SyncEvent for events::SyncMessageLikeEvent @@ -472,40 +494,39 @@ mod static_events { C: StaticEventContent + MessageLikeEventContent + RedactContent, C::Redacted: MessageLikeEventContent + RedactedEventContent, { - const ID: (EventKind, &'static str) = (EventKind::MessageLike, C::TYPE); + const KIND: EventKind = EventKind::MessageLike; + const TYPE: &'static str = C::TYPE; } impl SyncEvent for events::OriginalSyncMessageLikeEvent where C: StaticEventContent + MessageLikeEventContent, { - const ID: (EventKind, &'static str) = (EventKind::OriginalMessageLike, C::TYPE); + const KIND: EventKind = EventKind::OriginalMessageLike; + const TYPE: &'static str = C::TYPE; } impl SyncEvent for events::RedactedSyncMessageLikeEvent where C: StaticEventContent + MessageLikeEventContent + RedactedEventContent, { - const ID: (EventKind, &'static str) = (EventKind::RedactedMessageLike, C::TYPE); + const KIND: EventKind = EventKind::RedactedMessageLike; + const TYPE: &'static str = C::TYPE; } impl SyncEvent for events::room::redaction::SyncRoomRedactionEvent { - const ID: (EventKind, &'static str) = - (EventKind::MessageLike, events::room::redaction::RoomRedactionEventContent::TYPE); + const KIND: EventKind = EventKind::MessageLike; + const TYPE: &'static str = events::room::redaction::RoomRedactionEventContent::TYPE; } impl SyncEvent for events::room::redaction::OriginalSyncRoomRedactionEvent { - const ID: (EventKind, &'static str) = ( - EventKind::OriginalMessageLike, - events::room::redaction::RoomRedactionEventContent::TYPE, - ); + const KIND: EventKind = EventKind::OriginalMessageLike; + const TYPE: &'static str = events::room::redaction::RoomRedactionEventContent::TYPE; } impl SyncEvent for events::room::redaction::RedactedSyncRoomRedactionEvent { - const ID: (EventKind, &'static str) = ( - EventKind::RedactedMessageLike, - events::room::redaction::RoomRedactionEventContent::TYPE, - ); + const KIND: EventKind = EventKind::RedactedMessageLike; + const TYPE: &'static str = events::room::redaction::RoomRedactionEventContent::TYPE; } impl SyncEvent for events::SyncStateEvent @@ -513,46 +534,53 @@ mod static_events { C: StaticEventContent + StateEventContent + RedactContent, C::Redacted: StateEventContent + RedactedEventContent, { - const ID: (EventKind, &'static str) = (EventKind::State, C::TYPE); + const KIND: EventKind = EventKind::State; + const TYPE: &'static str = C::TYPE; } impl SyncEvent for events::OriginalSyncStateEvent where C: StaticEventContent + StateEventContent, { - const ID: (EventKind, &'static str) = (EventKind::OriginalState, C::TYPE); + const KIND: EventKind = EventKind::OriginalState; + const TYPE: &'static str = C::TYPE; } impl SyncEvent for events::RedactedSyncStateEvent where C: StaticEventContent + StateEventContent + RedactedEventContent, { - const ID: (EventKind, &'static str) = (EventKind::RedactedState, C::TYPE); + const KIND: EventKind = EventKind::RedactedState; + const TYPE: &'static str = C::TYPE; } impl SyncEvent for events::StrippedStateEvent where C: StaticEventContent + StateEventContent, { - const ID: (EventKind, &'static str) = (EventKind::StrippedState, C::TYPE); + const KIND: EventKind = EventKind::StrippedState; + const TYPE: &'static str = C::TYPE; } impl SyncEvent for events::InitialStateEvent where C: StaticEventContent + StateEventContent, { - const ID: (EventKind, &'static str) = (EventKind::InitialState, C::TYPE); + const KIND: EventKind = EventKind::InitialState; + const TYPE: &'static str = C::TYPE; } impl SyncEvent for events::ToDeviceEvent where C: StaticEventContent + ToDeviceEventContent, { - const ID: (EventKind, &'static str) = (EventKind::ToDevice, C::TYPE); + const KIND: EventKind = EventKind::ToDevice; + const TYPE: &'static str = C::TYPE; } impl SyncEvent for PresenceEvent { - const ID: (EventKind, &'static str) = (EventKind::Presence, PresenceEventContent::TYPE); + const KIND: EventKind = EventKind::Presence; + const TYPE: &'static str = PresenceEventContent::TYPE; } } @@ -567,7 +595,13 @@ mod tests { EphemeralTestEvent, EventBuilder, StateTestEvent, StrippedStateTestEvent, TimelineTestEvent, }; use ruma::{ - events::room::member::{OriginalSyncRoomMemberEvent, StrippedRoomMemberEvent}, + events::{ + room::{ + member::{OriginalSyncRoomMemberEvent, StrippedRoomMemberEvent}, + power_levels::OriginalSyncRoomPowerLevelsEvent, + }, + typing::SyncTypingEvent, + }, room_id, }; use serde_json::json; @@ -575,7 +609,7 @@ mod tests { use crate::{room, Client}; #[async_test] - async fn event_handler() -> crate::Result<()> { + async fn add_event_handler() -> crate::Result<()> { use std::sync::atomic::{AtomicU8, Ordering::SeqCst}; let client = crate::client::tests::logged_in_client(None).await; @@ -586,31 +620,34 @@ mod tests { let invited_member_count = Arc::new(AtomicU8::new(0)); client - .register_event_handler({ + .add_event_handler({ let member_count = member_count.clone(); move |_ev: OriginalSyncRoomMemberEvent, _room: room::Room| { member_count.fetch_add(1, SeqCst); future::ready(()) } }) - .await - .register_event_handler({ + .await; + client + .add_event_handler({ let typing_count = typing_count.clone(); - move |_ev: OriginalSyncRoomMemberEvent| { + move |_ev: SyncTypingEvent| { typing_count.fetch_add(1, SeqCst); future::ready(()) } }) - .await - .register_event_handler({ + .await; + client + .add_event_handler({ let power_levels_count = power_levels_count.clone(); - move |_ev: OriginalSyncRoomMemberEvent, _client: Client, _room: room::Room| { + move |_ev: OriginalSyncRoomPowerLevelsEvent, _client: Client, _room: room::Room| { power_levels_count.fetch_add(1, SeqCst); future::ready(()) } }) - .await - .register_event_handler({ + .await; + client + .add_event_handler({ let invited_member_count = invited_member_count.clone(); move |_ev: StrippedRoomMemberEvent| { invited_member_count.fetch_add(1, SeqCst); @@ -674,4 +711,57 @@ mod tests { Ok(()) } + + #[async_test] + async fn remove_event_handler() -> crate::Result<()> { + use std::sync::atomic::{AtomicU8, Ordering::SeqCst}; + + let client = crate::client::tests::logged_in_client(None).await; + + let member_count = Arc::new(AtomicU8::new(0)); + + client + .add_event_handler({ + let member_count = member_count.clone(); + move |_ev: OriginalSyncRoomMemberEvent| { + member_count.fetch_add(1, SeqCst); + future::ready(()) + } + }) + .await; + + let handle = client + .add_event_handler({ + move |_ev: OriginalSyncRoomMemberEvent| { + panic!("handler should have been removed"); + #[allow(unreachable_code)] + future::ready(()) + } + }) + .await; + + client + .add_event_handler({ + let member_count = member_count.clone(); + move |_ev: OriginalSyncRoomMemberEvent| { + member_count.fetch_add(1, SeqCst); + future::ready(()) + } + }) + .await; + + let response = EventBuilder::default() + .add_joined_room( + JoinedRoomBuilder::default().add_timeline_event(TimelineTestEvent::Member), + ) + .build_sync_response(); + + client.remove_event_handler(handle).await; + + client.process_sync(response).await?; + + assert_eq!(member_count.load(SeqCst), 2); + + Ok(()) + } } diff --git a/crates/matrix-sdk/src/http_client.rs b/crates/matrix-sdk/src/http_client.rs index 13b4fe25f..b03ad5400 100644 --- a/crates/matrix-sdk/src/http_client.rs +++ b/crates/matrix-sdk/src/http_client.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{any::type_name, convert::TryFrom, fmt::Debug, sync::Arc, time::Duration}; +use std::{any::type_name, fmt::Debug, sync::Arc, time::Duration}; use async_trait::async_trait; use bytes::{Bytes, BytesMut}; @@ -50,8 +50,9 @@ pub trait HttpSend: AsyncTraitDeps { /// # Examples /// /// ``` - /// use std::convert::TryFrom; - /// use matrix_sdk::{HttpSend, async_trait, HttpError, config::RequestConfig, bytes::Bytes}; + /// use matrix_sdk::{ + /// async_trait, bytes::Bytes, config::RequestConfig, HttpError, HttpSend, + /// }; /// /// #[derive(Debug)] /// struct Client(reqwest::Client); diff --git a/crates/matrix-sdk/src/lib.rs b/crates/matrix-sdk/src/lib.rs index 5303dfe01..34fb3010a 100644 --- a/crates/matrix-sdk/src/lib.rs +++ b/crates/matrix-sdk/src/lib.rs @@ -17,10 +17,10 @@ #![warn(missing_debug_implementations, missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] -#[cfg(not(any(feature = "native-tls", feature = "rustls-tls",)))] +#[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))] compile_error!("one of 'native-tls' or 'rustls-tls' features must be enabled"); -#[cfg(all(feature = "native-tls", feature = "rustls-tls",))] +#[cfg(all(feature = "native-tls", feature = "rustls-tls"))] compile_error!("only one of 'native-tls' or 'rustls-tls' features can be enabled"); #[cfg(all(feature = "sso-login", target_arch = "wasm32"))] diff --git a/crates/matrix-sdk/src/room/common.rs b/crates/matrix-sdk/src/room/common.rs index 9e740a9e2..8d4e600fe 100644 --- a/crates/matrix-sdk/src/room/common.rs +++ b/crates/matrix-sdk/src/room/common.rs @@ -138,9 +138,7 @@ impl Common { /// let client = Client::new(homeserver).await.unwrap(); /// client.login(user, "password", None, None).await.unwrap(); /// let room_id = room_id!("!roomid:example.com"); - /// let room = client - /// .get_joined_room(&room_id) - /// .unwrap(); + /// let room = client.get_joined_room(&room_id).unwrap(); /// if let Some(avatar) = room.avatar(MediaFormat::File).await.unwrap() { /// std::fs::write("avatar.png", avatar); /// } @@ -165,7 +163,6 @@ impl Common { /// /// # Examples /// ```no_run - /// # use std::convert::TryFrom; /// use matrix_sdk::{room::MessagesOptions, Client}; /// # use matrix_sdk::ruma::{ /// # api::client::filter::RoomEventFilter, @@ -176,12 +173,11 @@ impl Common { /// # let homeserver = Url::parse("http://example.com").unwrap(); /// # use futures::executor::block_on; /// # block_on(async { - /// let options = MessagesOptions::backward().from("t47429-4392820_219380_26003_2265"); + /// let options = + /// MessagesOptions::backward().from("t47429-4392820_219380_26003_2265"); /// /// let mut client = Client::new(homeserver).await.unwrap(); - /// let room = client - /// .get_joined_room(room_id!("!roomid:example.com")) - /// .unwrap(); + /// let room = client.get_joined_room(room_id!("!roomid:example.com")).unwrap(); /// assert!(room.messages(options).await.is_ok()); /// # }); /// ``` @@ -209,12 +205,10 @@ impl Common { if let Some(machine) = self.client.olm_machine() { for event in http_response.chunk { let decrypted_event = if let Ok(AnySyncRoomEvent::MessageLike( - AnySyncMessageLikeEvent::RoomEncrypted(SyncMessageLikeEvent::Original( - encrypted_event, - )), + AnySyncMessageLikeEvent::RoomEncrypted(SyncMessageLikeEvent::Original(_)), )) = event.deserialize_as::() { - if let Ok(event) = machine.decrypt_room_event(&encrypted_event, room_id).await { + if let Ok(event) = machine.decrypt_room_event(event.cast_ref(), room_id).await { event } else { RoomEvent { event, encryption_info: None } @@ -263,7 +257,6 @@ impl Common { /// /// # Examples /// ```no_run - /// # use std::convert::TryFrom; /// use matrix_sdk::Client; /// # use matrix_sdk::ruma::{ /// # api::client::filter::RoomEventFilter, @@ -279,25 +272,26 @@ impl Common { /// /// let mut client = Client::new(homeserver).await?; /// - /// if let Some(room) = client.get_joined_room(room_id!("!roomid:example.com")) { - /// let (forward_stream, backward_stream) = room.timeline().await?; + /// if let Some(room) = client.get_joined_room(room_id!("!roomid:example.com")) + /// { + /// let (forward_stream, backward_stream) = room.timeline().await?; /// - /// tokio::spawn(async move { - /// pin_mut!(backward_stream); + /// tokio::spawn(async move { + /// pin_mut!(backward_stream); /// - /// while let Some(item) = backward_stream.next().await { - /// match item { - /// Ok(event) => println!("{:?}", event), - /// Err(_) => println!("Some error occurred!"), - /// } - /// } - /// }); + /// while let Some(item) = backward_stream.next().await { + /// match item { + /// Ok(event) => println!("{:?}", event), + /// Err(_) => println!("Some error occurred!"), + /// } + /// } + /// }); /// - /// pin_mut!(forward_stream); + /// pin_mut!(forward_stream); /// - /// while let Some(event) = forward_stream.next().await { - /// println!("{:?}", event); - /// } + /// while let Some(event) = forward_stream.next().await { + /// println!("{:?}", event); + /// } /// } /// /// # anyhow::Ok(()) @@ -315,8 +309,10 @@ impl Common { for await item in backward_store { match item { Ok(event) => yield Ok(event), - Err(TimelineStreamError::EndCache { fetch_more_token }) => if let Err(error) = room.request_messages(&fetch_more_token).await { - yield Err(error); + Err(TimelineStreamError::EndCache { fetch_more_token }) => { + if let Err(error) = room.request_messages(&fetch_more_token).await { + yield Err(error); + } }, Err(TimelineStreamError::Store(error)) => yield Err(error.into()), } @@ -343,7 +339,6 @@ impl Common { /// /// # Examples /// ```no_run - /// # use std::convert::TryFrom; /// use matrix_sdk::Client; /// # use matrix_sdk::ruma::{ /// # api::client::filter::RoomEventFilter, @@ -359,14 +354,15 @@ impl Common { /// /// let mut client = Client::new(homeserver).await?; /// - /// if let Some(room) = client.get_joined_room(room_id!("!roomid:example.com")) { - /// let forward_stream = room.timeline_forward().await?; + /// if let Some(room) = client.get_joined_room(room_id!("!roomid:example.com")) + /// { + /// let forward_stream = room.timeline_forward().await?; /// - /// pin_mut!(forward_stream); + /// pin_mut!(forward_stream); /// - /// while let Some(event) = forward_stream.next().await { - /// println!("{:?}", event); - /// } + /// while let Some(event) = forward_stream.next().await { + /// println!("{:?}", event); + /// } /// } /// /// # anyhow::Ok(()) @@ -402,7 +398,6 @@ impl Common { /// /// # Examples /// ```no_run - /// # use std::convert::TryFrom; /// use matrix_sdk::Client; /// # use matrix_sdk::ruma::{ /// # api::client::filter::RoomEventFilter, @@ -418,19 +413,20 @@ impl Common { /// /// let mut client = Client::new(homeserver).await?; /// - /// if let Some(room) = client.get_joined_room(room_id!("!roomid:example.com")) { - /// let backward_stream = room.timeline_backward().await?; + /// if let Some(room) = client.get_joined_room(room_id!("!roomid:example.com")) + /// { + /// let backward_stream = room.timeline_backward().await?; /// - /// tokio::spawn(async move { - /// pin_mut!(backward_stream); + /// tokio::spawn(async move { + /// pin_mut!(backward_stream); /// - /// while let Some(item) = backward_stream.next().await { - /// match item { - /// Ok(event) => println!("{:?}", event), - /// Err(_) => println!("Some error occurred!"), - /// } - /// } - /// }); + /// while let Some(item) = backward_stream.next().await { + /// match item { + /// Ok(event) => println!("{:?}", event), + /// Err(_) => println!("Some error occurred!"), + /// } + /// } + /// }); /// } /// /// # anyhow::Ok(()) @@ -445,8 +441,10 @@ impl Common { for await item in backward_store { match item { Ok(event) => yield Ok(event), - Err(TimelineStreamError::EndCache { fetch_more_token }) => if let Err(error) = room.request_messages(&fetch_more_token).await { - yield Err(error); + Err(TimelineStreamError::EndCache { fetch_more_token }) => { + if let Err(error) = room.request_messages(&fetch_more_token).await { + yield Err(error); + } }, Err(TimelineStreamError::Store(error)) => yield Err(error.into()), } @@ -489,10 +487,10 @@ impl Common { #[cfg(feature = "e2e-encryption")] { if let Ok(AnySyncRoomEvent::MessageLike(AnySyncMessageLikeEvent::RoomEncrypted( - SyncMessageLikeEvent::Original(encrypted_event), + SyncMessageLikeEvent::Original(_), ))) = event.deserialize_as::() { - if let Ok(event) = self.decrypt_event(&encrypted_event).await { + if let Ok(event) = self.decrypt_event(event.cast_ref()).await { return Ok(event); } } @@ -711,9 +709,12 @@ impl Common { /// ```no_run /// # async { /// # let room: matrix_sdk::room::Common = todo!(); - /// use matrix_sdk::ruma::{events::room::member::SyncRoomMemberEvent, serde::Raw}; + /// use matrix_sdk::ruma::{ + /// events::room::member::SyncRoomMemberEvent, serde::Raw, + /// }; /// - /// let room_members: Vec> = room.get_state_events_static().await?; + /// let room_members: Vec> = + /// room.get_state_events_static().await?; /// # anyhow::Ok(()) /// # }; /// ``` @@ -749,7 +750,8 @@ impl Common { /// use matrix_sdk::ruma::events::room::power_levels::SyncRoomPowerLevelsEvent; /// /// let power_levels: SyncRoomPowerLevelsEvent = room - /// .get_state_event_static("").await? + /// .get_state_event_static("") + /// .await? /// .expect("every room has a power_levels event") /// .deserialize()?; /// # anyhow::Ok(()) @@ -788,7 +790,9 @@ impl Common { /// use matrix_sdk::ruma::events::fully_read::FullyReadEventContent; /// /// match room.account_data_static::().await? { - /// Some(fully_read) => println!("Found read marker: {:?}", fully_read.deserialize()?), + /// Some(fully_read) => { + /// println!("Found read marker: {:?}", fully_read.deserialize()?) + /// } /// None => println!("No read marker for this room"), /// } /// # anyhow::Ok(()) @@ -847,7 +851,7 @@ impl Common { /// tag_info.order = Some(0.9); /// let user_tag = UserTagName::from_str("u.work")?; /// - /// room.set_tag(TagName::User(user_tag), tag_info ).await?; + /// room.set_tag(TagName::User(user_tag), tag_info).await?; /// } /// # anyhow::Ok(()) }); /// ``` @@ -928,9 +932,12 @@ impl Common { /// /// Returns the decrypted event. #[cfg(feature = "e2e-encryption")] - pub async fn decrypt_event(&self, event: &OriginalSyncRoomEncryptedEvent) -> Result { + pub async fn decrypt_event( + &self, + event: &Raw, + ) -> Result { if let Some(machine) = self.client.olm_machine() { - Ok(machine.decrypt_room_event(event, self.inner.room_id()).await?) + Ok(machine.decrypt_room_event(event.cast_ref(), self.inner.room_id()).await?) } else { Err(Error::NoOlmMachine) } @@ -1075,7 +1082,9 @@ impl Common { /// Options for [`messages`][Common::messages]. /// -/// See that method and for details. +/// See that method and +/// +/// for details. #[derive(Debug)] #[non_exhaustive] pub struct MessagesOptions<'a> { diff --git a/crates/matrix-sdk/src/room/joined.rs b/crates/matrix-sdk/src/room/joined.rs index 5ffad481e..b54a75093 100644 --- a/crates/matrix-sdk/src/room/joined.rs +++ b/crates/matrix-sdk/src/room/joined.rs @@ -159,6 +159,7 @@ impl Joined { /// /// ```no_run /// use std::time::Duration; + /// /// use matrix_sdk::ruma::api::client::typing::create_typing_event::v3::Typing; /// /// # use matrix_sdk::{ @@ -416,7 +417,6 @@ impl Joined { /// # use url::Url; /// # use futures::executor::block_on; /// # use matrix_sdk::ruma::room_id; - /// # use std::convert::TryFrom; /// # use serde::{Deserialize, Serialize}; /// use matrix_sdk::ruma::{ /// events::{ @@ -513,7 +513,6 @@ impl Joined { /// # use url::Url; /// # use futures::executor::block_on; /// # use matrix_sdk::ruma::room_id; - /// # use std::convert::TryFrom; /// # block_on(async { /// # let homeserver = Url::parse("http://localhost:8080")?; /// # let mut client = Client::new(homeserver).await?; @@ -557,7 +556,7 @@ impl Joined { if event_type == "m.reaction" { debug!( room_id = %self.room_id(), - "Sending plaintext event because the event type is {}", event_type + "Sending plaintext event because the event type is {event_type}", ); (Raw::new(&content)?.cast(), event_type) } else { @@ -577,11 +576,8 @@ impl Joined { let encrypted_content = olm.encrypt_room_event_raw(self.inner.room_id(), content, event_type).await?; - let raw_content = Raw::new(&encrypted_content) - .expect("Failed to serialize encrypted event") - .cast(); - (raw_content, "m.room.encrypted") + (encrypted_content.cast(), "m.room.encrypted") } } else { debug!( diff --git a/crates/matrix-sdk/src/room_member.rs b/crates/matrix-sdk/src/room_member.rs index a180993ee..8a9d86fab 100644 --- a/crates/matrix-sdk/src/room_member.rs +++ b/crates/matrix-sdk/src/room_member.rs @@ -51,9 +51,7 @@ impl RoomMember { /// let client = Client::new(homeserver).await.unwrap(); /// client.login(user, "password", None, None).await.unwrap(); /// let room_id = room_id!("!roomid:example.com"); - /// let room = client - /// .get_joined_room(&room_id) - /// .unwrap(); + /// let room = client.get_joined_room(&room_id).unwrap(); /// let members = room.members().await.unwrap(); /// let member = members.first().unwrap(); /// if let Some(avatar) = member.avatar(MediaFormat::File).await.unwrap() { diff --git a/crates/matrix-sdk/src/sync.rs b/crates/matrix-sdk/src/sync.rs index a580e56c5..3796c38d7 100644 --- a/crates/matrix-sdk/src/sync.rs +++ b/crates/matrix-sdk/src/sync.rs @@ -36,7 +36,7 @@ impl Client { for (room_id, room_info) in &rooms.join { let room = self.get_room(room_id); if room.is_none() { - error!("Can't call event handler, room {} not found", room_id); + error!(%room_id, "Can't call event handler, room not found"); continue; } @@ -53,7 +53,7 @@ impl Client { for (room_id, room_info) in &rooms.leave { let room = self.get_room(room_id); if room.is_none() { - error!("Can't call event handler, room {} not found", room_id); + error!(%room_id, "Can't call event handler, room not found"); continue; } @@ -68,7 +68,7 @@ impl Client { for (room_id, room_info) in &rooms.invite { let room = self.get_room(room_id); if room.is_none() { - error!("Can't call event handler, room {} not found", room_id); + error!(%room_id, "Can't call event handler, room not found"); continue; } @@ -88,7 +88,7 @@ impl Client { let room = match self.get_room(room_id) { Some(room) => room, None => { - warn!("Can't call notification handler, room {} not found", room_id); + warn!(%room_id, "Can't call notification handler, room not found"); continue; } }; @@ -128,9 +128,7 @@ impl Client { Ok(r) } Err(e) => { - error!("Received an invalid response: {}", e); - Self::sleep().await; - + error!("Received an invalid response: {e}"); Err(e) } } diff --git a/crates/matrix-sdk/tests/integration/client.rs b/crates/matrix-sdk/tests/integration/client.rs index c5c7b4cd0..8d9f936e3 100644 --- a/crates/matrix-sdk/tests/integration/client.rs +++ b/crates/matrix-sdk/tests/integration/client.rs @@ -172,7 +172,7 @@ async fn login_with_sso_token() { assert!(can_sso); let sso_url = client.get_sso_login_url("http://127.0.0.1:3030", None).await; - assert!(sso_url.is_ok()); + sso_url.unwrap(); Mock::given(method("POST")) .and(path("/_matrix/client/r0/login")) @@ -280,7 +280,7 @@ async fn devices() { .mount(&server) .await; - assert!(client.devices().await.is_ok()); + client.devices().await.unwrap(); } #[async_test] @@ -360,7 +360,7 @@ async fn resolve_room_alias() { .await; let alias = ruma::room_alias_id!("#alias:example.org"); - assert!(client.resolve_room_alias(alias).await.is_ok()); + client.resolve_room_alias(alias).await.unwrap(); } #[async_test] @@ -520,9 +520,9 @@ async fn get_media_content() { .mount(&server) .await; - assert!(client.get_media_content(&request, true).await.is_ok()); - assert!(client.get_media_content(&request, true).await.is_ok()); - assert!(client.get_media_content(&request, false).await.is_ok()); + client.get_media_content(&request, true).await.unwrap(); + client.get_media_content(&request, true).await.unwrap(); + client.get_media_content(&request, false).await.unwrap(); } #[async_test] @@ -548,8 +548,8 @@ async fn get_media_file() { .mount(&server) .await; - assert!(client.get_file(event_content.clone(), true).await.is_ok()); - assert!(client.get_file(event_content.clone(), true).await.is_ok()); + client.get_file(event_content.clone(), true).await.unwrap(); + client.get_file(event_content.clone(), true).await.unwrap(); Mock::given(method("GET")) .and(path("/_matrix/media/r0/thumbnail/example%2Eorg/image")) @@ -561,14 +561,14 @@ async fn get_media_file() { .mount(&server) .await; - assert!(client + client .get_thumbnail( event_content, MediaThumbnailSize { method: Method::Scale, width: uint!(100), height: uint!(100) }, - true + true, ) .await - .is_ok()); + .unwrap(); } #[async_test] diff --git a/crates/matrix-sdk/tests/integration/room/joined.rs b/crates/matrix-sdk/tests/integration/room/joined.rs index 78d00a239..0d4b0668b 100644 --- a/crates/matrix-sdk/tests/integration/room/joined.rs +++ b/crates/matrix-sdk/tests/integration/room/joined.rs @@ -412,7 +412,7 @@ async fn room_attachment_send_wrong_info() { let response = room.send_attachment("image", &mime::IMAGE_JPEG, &mut media, config).await; - assert!(response.is_err()) + response.unwrap_err(); } #[async_test] diff --git a/labs/sled-state-inspector/src/main.rs b/labs/sled-state-inspector/src/main.rs index 2785345c3..97cd9aac4 100644 --- a/labs/sled-state-inspector/src/main.rs +++ b/labs/sled-state-inspector/src/main.rs @@ -1,4 +1,4 @@ -use std::{convert::TryFrom, fmt::Debug, sync::Arc}; +use std::{fmt::Debug, sync::Arc}; use atty::Stream; use clap::{Arg, ArgMatches, Command as Argparse}; @@ -66,7 +66,7 @@ impl InspectorHelper { fn complete_event_types(&self, arg: Option<&&str>) -> Vec { Self::EVENT_TYPES .iter() - .map(|&t| Pair { display: t.to_owned(), replacement: format!("{} ", t) }) + .map(|&t| Pair { display: t.to_owned(), replacement: format!("{t} ") }) .filter(|r| if let Some(arg) = arg { r.replacement.starts_with(arg) } else { true }) .collect() } @@ -105,7 +105,7 @@ impl Completer for InspectorHelper { ("get-members", "get all the membership events in the given room"), ] .iter() - .map(|(r, d)| Pair { display: format!("{} ({})", r, d), replacement: format!("{} ", r) }) + .map(|(r, d)| Pair { display: format!("{r} ({d})"), replacement: format!("{r} ") }) .collect(); if args.is_empty() { @@ -188,13 +188,13 @@ impl Printer { for line in LinesWithEndings::from(&data) { let ranges: Vec<(Style, &str)> = h.highlight(line, &self.ps); let escaped = as_24_bit_terminal_escaped(&ranges[..], false); - print!("{}", escaped); + print!("{escaped}"); } // Clear the formatting println!("\x1b[0m"); } else { - println!("{}", data); + println!("{data}"); } } } @@ -202,8 +202,12 @@ impl Printer { impl Inspector { fn new(database_path: &str, json: bool, color: bool) -> Self { let printer = Printer::new(json, color); - let store = - Arc::new(StateStore::open_with_path(database_path).expect("Can't open sled database")); + let store = Arc::new( + StateStore::builder() + .path(database_path.into()) + .build() + .expect("Can't open sled database"), + ); Self { store, printer } } @@ -314,7 +318,7 @@ impl Inspector { self.run(m).await; } Err(e) => { - println!("{}", e); + println!("{e}"); } } } diff --git a/testing/matrix-sdk-test-macros/src/lib.rs b/testing/matrix-sdk-test-macros/src/lib.rs index 7e87424b2..5885a2a9b 100644 --- a/testing/matrix-sdk-test-macros/src/lib.rs +++ b/testing/matrix-sdk-test-macros/src/lib.rs @@ -38,13 +38,15 @@ pub fn async_test(_attr: TokenStream, item: TokenStream) -> TokenStream { let fn_call: TokenStream = if fun.sig.asyncness.is_some() { quote! { { - assert!(#fn_name().await.is_ok()); + let res = #fn_name().await; + assert!(res.is_ok(), "{:?}", res); } } } else { quote! { { - assert!(#fn_name().is_ok()); + let res = #fn_name(); + assert!(res.is_ok(), "{:?}", res); } } } diff --git a/testing/matrix-sdk-test/src/appservice.rs b/testing/matrix-sdk-test/src/appservice.rs index 29244f00f..cd3c3fdb8 100644 --- a/testing/matrix-sdk-test/src/appservice.rs +++ b/testing/matrix-sdk-test/src/appservice.rs @@ -1,5 +1,3 @@ -use std::convert::TryFrom; - use ruma::{events::AnyRoomEvent, serde::Raw}; use serde_json::Value; diff --git a/xtask/src/ci.rs b/xtask/src/ci.rs index c3a1df7c7..a08352c60 100644 --- a/xtask/src/ci.rs +++ b/xtask/src/ci.rs @@ -12,6 +12,9 @@ pub struct CiArgs { cmd: Option, } +const WASM_TIMEOUT_ENV_KEY: &str = "WASM_BINDGEN_TEST_TIMEOUT"; +const WASM_TIMEOUT_VALUE: &str = "120"; + #[derive(Subcommand)] enum CiCommand { /// Check style @@ -69,6 +72,7 @@ enum WasmFeatureSet { MatrixSdkIndexeddbStores, IndexeddbNoCrypto, IndexeddbWithCrypto, + Indexeddb, MatrixSdkCommandBot, } @@ -197,6 +201,12 @@ fn run_appservice_tests() -> Result<()> { } fn run_wasm_checks(cmd: Option) -> Result<()> { + if let Some(WasmFeatureSet::Indexeddb) = cmd { + run_wasm_checks(Some(WasmFeatureSet::IndexeddbNoCrypto))?; + run_wasm_checks(Some(WasmFeatureSet::IndexeddbWithCrypto))?; + return Ok(()); + } + let args = BTreeMap::from([ (WasmFeatureSet::MatrixSdkQrcode, "-p matrix-sdk-qrcode"), ( @@ -225,6 +235,7 @@ fn run_wasm_checks(cmd: Option) -> Result<()> { cmd!("rustup run stable cargo clippy --target wasm32-unknown-unknown") .args(arg_set.split_whitespace()) .args(["--", "-D", "warnings"]) + .env(WASM_TIMEOUT_ENV_KEY, WASM_TIMEOUT_VALUE) .run() }; @@ -233,6 +244,7 @@ fn run_wasm_checks(cmd: Option) -> Result<()> { cmd!("rustup run stable cargo clippy --target wasm32-unknown-unknown") .args(["--", "-D", "warnings", "-A", "clippy::unused-unit"]) + .env(WASM_TIMEOUT_ENV_KEY, WASM_TIMEOUT_VALUE) .run() }; @@ -258,6 +270,11 @@ fn run_wasm_checks(cmd: Option) -> Result<()> { } fn run_wasm_pack_tests(cmd: Option) -> Result<()> { + if let Some(WasmFeatureSet::Indexeddb) = cmd { + run_wasm_pack_tests(Some(WasmFeatureSet::IndexeddbNoCrypto))?; + run_wasm_pack_tests(Some(WasmFeatureSet::IndexeddbWithCrypto))?; + return Ok(()); + } let args = BTreeMap::from([ (WasmFeatureSet::MatrixSdkQrcode, ("matrix-sdk-qrcode", "")), ( @@ -286,16 +303,24 @@ fn run_wasm_pack_tests(cmd: Option) -> Result<()> { ]); let run = |(folder, arg_set): (&str, &str)| { - let _p = pushd(format!("crates/{}", folder)); - cmd!("pwd").run()?; // print dir so we know what might have failed - cmd!("wasm-pack test --node -- ").args(arg_set.split_whitespace()).run()?; - cmd!("wasm-pack test --firefox --headless --").args(arg_set.split_whitespace()).run() + let _p = pushd(format!("crates/{folder}")); + cmd!("pwd").env(WASM_TIMEOUT_ENV_KEY, WASM_TIMEOUT_VALUE).run()?; // print dir so we know what might have failed + cmd!("wasm-pack test --node -- ") + .args(arg_set.split_whitespace()) + .env(WASM_TIMEOUT_ENV_KEY, WASM_TIMEOUT_VALUE) + .run()?; + cmd!("wasm-pack test --firefox --headless --") + .args(arg_set.split_whitespace()) + .env(WASM_TIMEOUT_ENV_KEY, WASM_TIMEOUT_VALUE) + .run() }; let test_command_bot = || { let _p = pushd("crates/matrix-sdk/examples/wasm_command_bot"); - cmd!("wasm-pack test --node").run()?; - cmd!("wasm-pack test --firefox --headless").run() + cmd!("wasm-pack test --node").env(WASM_TIMEOUT_ENV_KEY, WASM_TIMEOUT_VALUE).run()?; + cmd!("wasm-pack test --firefox --headless") + .env(WASM_TIMEOUT_ENV_KEY, WASM_TIMEOUT_VALUE) + .run() }; match cmd {