Merge remote-tracking branch 'upstream/main' into ben-wasm-store

This commit is contained in:
Benjamin Kampmann
2021-11-24 11:48:45 +01:00
40 changed files with 2078 additions and 389 deletions

View File

@@ -3,6 +3,7 @@
# Or if we're able to ignore certain lines.
Fo = "Fo"
BA = "BA"
UE = "UE"
[files]
# Our json files contain a bunch of base64 encoded ed25519 keys which aren't

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{net::ToSocketAddrs, result::Result as StdResult};
use std::net::ToSocketAddrs;
use matrix_sdk::{
bytes::Bytes,
@@ -168,7 +168,7 @@ mod handlers {
_user_id: String,
appservice: AppService,
request: http::Request<Bytes>,
) -> StdResult<impl warp::Reply, Rejection> {
) -> Result<impl warp::Reply, Rejection> {
if let Some(user_exists) = appservice.event_handler.users.lock().await.as_mut() {
let request =
query_user::IncomingRequest::try_from_http_request(request).map_err(Error::from)?;
@@ -185,7 +185,7 @@ mod handlers {
_room_id: String,
appservice: AppService,
request: http::Request<Bytes>,
) -> StdResult<impl warp::Reply, Rejection> {
) -> Result<impl warp::Reply, Rejection> {
if let Some(room_exists) = appservice.event_handler.rooms.lock().await.as_mut() {
let request =
query_room::IncomingRequest::try_from_http_request(request).map_err(Error::from)?;
@@ -202,7 +202,7 @@ mod handlers {
_txn_id: String,
appservice: AppService,
request: http::Request<Bytes>,
) -> StdResult<impl warp::Reply, Rejection> {
) -> Result<impl warp::Reply, Rejection> {
let incoming_transaction: ruma::api::appservice::event::push_events::v1::IncomingRequest =
ruma::api::IncomingRequest::try_from_http_request(request).map_err(Error::from)?;
@@ -224,7 +224,7 @@ struct ErrorMessage {
message: String,
}
pub async fn handle_rejection(err: Rejection) -> std::result::Result<impl Reply, Rejection> {
pub async fn handle_rejection(err: Rejection) -> Result<impl Reply, Rejection> {
if err.find::<Unauthorized>().is_some() || err.find::<warp::reject::InvalidQuery>().is_some() {
let code = http::StatusCode::UNAUTHORIZED;
let message = "UNAUTHORIZED";

View File

@@ -58,7 +58,7 @@ default-features = false
features = ["sync", "fs"]
[dev-dependencies]
futures = { version = "0.3.15", default-features = false }
futures = { version = "0.3.15", default-features = false, features = ["executor"] }
http = "0.2.4"
matrix-sdk-test = { version = "0.4.0", path = "../matrix-sdk-test" }

View File

@@ -15,11 +15,26 @@
//! User sessions.
use ruma::{DeviceId, UserId};
use ruma::{DeviceIdBox, UserId};
use serde::{Deserialize, Serialize};
/// A user session, containing an access token and information about the
/// associated user account.
///
/// # Example
///
/// ```
/// use matrix_sdk_base::Session;
/// use ruma::{DeviceIdBox, user_id};
///
/// let session = Session {
/// access_token: "My-Token".to_owned(),
/// user_id: user_id!("@example:localhost"),
/// device_id: DeviceIdBox::from("MYDEVICEID"),
/// };
///
/// assert_eq!(session.device_id.as_str(), "MYDEVICEID");
/// ```
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct Session {
/// The access token used for this session.
@@ -27,5 +42,15 @@ pub struct Session {
/// The user the access token was issued for.
pub user_id: UserId,
/// The ID of the client device
pub device_id: Box<DeviceId>,
pub device_id: DeviceIdBox,
}
impl From<ruma::api::client::r0::session::login::Response> for Session {
fn from(response: ruma::api::client::r0::session::login::Response) -> Self {
Self {
access_token: response.access_token,
user_id: response.user_id,
device_id: response.device_id,
}
}
}

View File

@@ -33,7 +33,7 @@ version = "0.1.9"
features = ["now"]
[target.'cfg(target_arch = "wasm32")'.dependencies]
futures-util = { version = "0.3.15", default-features = false }
futures-locks = { version = "0.6.0", default-features = false }
async-lock = "2.4.0"
futures-util = { version = "0.3.15", default-features = false, features = ["channel"] }
wasm-bindgen-futures = "0.4.24"
uuid = { version = "0.8.2", default-features = false, features = ["v4", "wasm-bindgen"] }

View File

@@ -1,8 +1,4 @@
// could switch to futures-lock completely at some point, blocker:
// https://github.com/asomers/futures-locks/issues/34
// https://www.reddit.com/r/rust/comments/f4zldz/i_audited_3_different_implementation_of_async/
#[cfg(target_arch = "wasm32")]
pub use futures_locks::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
pub use async_lock::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
#[cfg(not(target_arch = "wasm32"))]
pub use tokio::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};

View File

@@ -16,8 +16,9 @@ features = ["docs"]
rustdoc-args = ["--cfg", "feature=\"docs\""]
[features]
default = []
default = ["backups_v1"]
qrcode = ["matrix-qrcode"]
backups_v1 = []
sled_cryptostore = ["sled"]
docs = ["sled_cryptostore"]
indexeddb_cryptostore = ["indexed_db_futures", "wasm-bindgen"]
@@ -27,6 +28,7 @@ aes = { version = "0.7.4", features = ["ctr"] }
aes-gcm = "0.9.2"
atomic = "0.5.0"
base64 = "0.13.0"
bs58 = "0.4.0"
byteorder = "1.4.3"
dashmap = "4.0.2"
futures-util = { version = "0.3.15", default-features = false }
@@ -34,9 +36,13 @@ getrandom = "0.2.3"
hmac = "0.11.0"
matrix-qrcode = { version = "0.2.0", path = "../matrix-qrcode", optional = true }
matrix-sdk-common = { version = "0.4.0", path = "../matrix-sdk-common" }
olm-rs = { version = "2.0.1", features = ["serde"] }
olm-rs = { version = "2.1", features = ["serde"] }
pbkdf2 = { version = "0.9.0", default-features = false }
ruma = { git = "https://github.com/ruma/ruma", rev = "ac6ecc3e5", features = ["client-api-c", "unstable-pre-spec"] }
rand = "0.8.4"
ruma = { git = "https://github.com/ruma/ruma", rev = "ac6ecc3e5", features = [
"client-api-c",
"unstable-pre-spec",
] }
serde = { version = "1.0.126", features = ["derive", "rc"] }
serde_json = "1.0.64"
sha2 = "0.9.5"
@@ -50,17 +56,24 @@ indexed_db_futures = { version = "0.2.0", optional = true }
wasm-bindgen = { version = "0.2.74", features = ["serde-serialize"], optional = true }
[dev-dependencies]
futures = { version = "0.3.15", default-features = false }
futures = { version = "0.3.15", default-features = false, features = ["executor"] }
http = "0.2.4"
indoc = "1.0.3"
matches = "0.1.8"
matrix-sdk-test = { version = "0.4.0", path = "../matrix-sdk-test" }
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
criterion = { version = "0.3.4", features = [
"async",
"async_tokio",
"html_reports",
] }
proptest = "1.0.0"
criterion = { version = "0.3.4", features = ["async", "async_tokio", "html_reports"] }
tokio = { version = "1.7.1", default-features = false, features = ["rt-multi-thread", "macros"] }
tempfile = "3.2.0"
tokio = { version = "1.7.1", default-features = false, features = [
"rt-multi-thread",
"macros",
] }
[target.'cfg(target_os = "linux")'.dev-dependencies]
criterion = { version = "0.3.4", features = ["async", "html_reports"] }

View File

@@ -0,0 +1,148 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
collections::BTreeMap,
sync::{Arc, Mutex},
};
use olm_rs::pk::OlmPkEncryption;
use ruma::{
api::client::r0::backup::{KeyBackupData, KeyBackupDataInit, SessionDataInit},
DeviceKeyId, UserId,
};
use zeroize::Zeroizing;
use super::recovery::DecodeError;
use crate::olm::InboundGroupSession;
/// Inner, shared state of a [`MegolmV1BackupKey`].
///
/// `MegolmV1BackupKey` wraps this in an `Arc`, so all clones of a key share
/// the same version cell.
#[derive(Debug)]
struct InnerBackupKey {
    // The raw public key bytes.
    key: [u8; MegolmV1BackupKey::KEY_SIZE],
    // Signatures over the backup auth data, keyed by user id, then by the
    // device key id of the signing device.
    signatures: BTreeMap<UserId, BTreeMap<DeviceKeyId, String>>,
    // The backup version this key is tied to, if any. Guarded by a mutex so
    // `set_version` can mutate it through a shared reference.
    version: Mutex<Option<String>>,
}
/// The public part of a backup key.
///
/// Cheap to clone: the inner state is shared behind an `Arc`, so clones refer
/// to the same key and backup version.
#[derive(Clone)]
pub struct MegolmV1BackupKey {
    inner: Arc<InnerBackupKey>,
}
impl std::fmt::Debug for MegolmV1BackupKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show the base64 public key and the attached backup version rather
        // than the raw inner state.
        let mut builder = f.debug_struct("MegolmV1BackupKey");
        builder.field("key", &self.to_base64());
        builder.field("version", &self.backup_version());
        builder.finish()
    }
}
impl MegolmV1BackupKey {
    /// Size of the public key in bytes.
    const KEY_SIZE: usize = 32;

    /// Create a backup key from a base64 encoded public key, optionally tying
    /// it to an existing backup version.
    ///
    /// Panics if `public_key` isn't valid base64 of the correct length, so
    /// callers must only pass known-good keys.
    pub(super) fn new(public_key: &str, version: Option<String>) -> Self {
        let key = Self::from_base64(public_key).expect("Invalid backup key");

        if let Some(version) = version {
            key.set_version(version)
        }

        key
    }

    /// Get the full name of the backup algorithm this backup key supports.
    pub fn backup_algorithm(&self) -> &str {
        "m.megolm_backup.v1.curve25519-aes-sha2"
    }

    /// Get all the signatures of this `MegolmV1BackupKey`.
    ///
    /// Returns an owned copy of the signature map (user id -> device key id
    /// -> signature).
    pub fn signatures(&self) -> BTreeMap<UserId, BTreeMap<DeviceKeyId, String>> {
        self.inner.signatures.to_owned()
    }

    /// Try to create a new `MegolmV1BackupKey` from a base 64 encoded string.
    ///
    /// Fails with [`DecodeError`] if the string isn't valid base64 or doesn't
    /// decode to exactly [`Self::KEY_SIZE`] bytes.
    pub fn from_base64(public_key: &str) -> Result<Self, DecodeError> {
        let mut key = [0u8; Self::KEY_SIZE];
        let decoded_key = crate::utilities::decode(public_key)?;

        if decoded_key.len() != Self::KEY_SIZE {
            Err(DecodeError::Length(Self::KEY_SIZE, decoded_key.len()))
        } else {
            key.copy_from_slice(&decoded_key);

            // A freshly decoded key starts without an attached backup version;
            // it can't encrypt room keys until `set_version` is called.
            let inner =
                InnerBackupKey { key, signatures: Default::default(), version: Mutex::new(None) };

            Ok(MegolmV1BackupKey { inner: inner.into() })
        }
    }

    /// Convert the [`MegolmV1BackupKey`] to a base 64 encoded string.
    pub fn to_base64(&self) -> String {
        crate::utilities::encode(&self.inner.key)
    }

    /// Get the backup version that this key is used with, if any.
    pub fn backup_version(&self) -> Option<String> {
        // A poisoned mutex would mean another thread panicked while holding
        // the lock; unwrapping propagates that as a panic here.
        self.inner.version.lock().unwrap().clone()
    }

    /// Set the backup version that this `MegolmV1BackupKey` will be used with.
    ///
    /// The key won't be able to encrypt room keys unless a version has been
    /// set.
    pub fn set_version(&self, version: String) {
        *self.inner.version.lock().unwrap() = Some(version);
    }

    /// Encrypt the given room key for backup, producing the `KeyBackupData`
    /// that gets uploaded to the server.
    pub(crate) async fn encrypt(&self, session: InboundGroupSession) -> KeyBackupData {
        let pk = OlmPkEncryption::new(&self.to_base64());

        // It's ok to truncate here, there's a semantic difference only between
        // 0 and 1+ anyways.
        let forwarded_count = (session.forwarding_key_chain().len() as u32).into();
        let first_message_index = session.first_known_index().into();

        // Convert our key to the backup representation.
        let key = session.to_backup().await;

        // The key gets zeroized in `BackedUpRoomKey` but we're creating a copy
        // here that won't, so let's wrap it up in a `Zeroizing` struct.
        let key =
            Zeroizing::new(serde_json::to_string(&key).expect("Can't serialize exported room key"));

        let message = pk.encrypt(&key);

        let session_data = SessionDataInit {
            ephemeral: message.ephemeral_key,
            ciphertext: message.ciphertext,
            mac: message.mac,
        }
        .into();

        KeyBackupDataInit {
            first_message_index,
            forwarded_count,
            // TODO is this actually used anywhere? seems to be completely
            // useless and requires us to get the Device out of the store?
            // Also should this be checked at the time of the backup or at the
            // time of the room key receival?
            is_verified: false,
            session_data,
        }
        .into()
    }
}

View File

@@ -0,0 +1,54 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Module for the keys that are used to backup room keys.
//!
//! The backup key is split into two parts:
//!
//! ```text
//! ┌───────────────────────────┐
//! │ RecoveryKey | BackupKey │
//! └───────────────────────────┘
//! ```
//!
//! 1. RecoveryKey, a private Curve25519 key that is used to decrypt backed up
//! room keys.
//! 2. BackupKey, the public part of the Curve25519 RecoveryKey, this one is
//! used to encrypt room keys that get backed up.
//!
//! The `RecoveryKey` can be derived from a passphrase or randomly generated.
//! To allow other devices access to this key you'll either have to re-enter
//! the passphrase or the key as a base58 encoded string, or receive it from
//! another device using the secret sharing mechanism.
//!
//! In practice deriving a recovery key from a passphrase isn't done, and is
//! **not** supported by the spec. Instead the secret storage key that encrypts
//! secrets and puts them into your global account data gets derived from a
//! passphrase and presented as a base58 encoded string, confusingly also called
//! recovery key.
//!
//! The `RecoveryKey` can also be uploaded to your account data as an
//! `m.megolm_backup.v1` event, which gets encrypted by the SSSS key.
//!
//! The `MegolmV1BackupKey` is used to encrypt individual room keys so they can
//! be uploaded to the homeserver.
//!
//! The `MegolmV1BackupKey` is a public key and is uploaded to the server using
//! the `/room_keys/version` API endpoint.
mod backup;
mod recovery;
pub use backup::MegolmV1BackupKey;
pub use recovery::{DecodeError, PickledRecoveryKey, RecoveryKey};

View File

@@ -0,0 +1,354 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
convert::TryFrom,
io::{Cursor, Read},
};
use aes::cipher::generic_array::GenericArray;
use aes_gcm::{
aead::{Aead, NewAead},
Aes256Gcm,
};
use bs58;
use olm_rs::{
errors::OlmPkDecryptionError,
pk::{OlmPkDecryption, PkMessage},
};
use rand::{thread_rng, Error as RandomError, Fill};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use zeroize::{Zeroize, Zeroizing};
use super::MegolmV1BackupKey;
use crate::utilities::{decode_url_safe, encode, encode_url_safe};
const NONCE_SIZE: usize = 12;
/// Error type for the decoding of a RecoveryKey.
#[derive(Debug, Error)]
pub enum DecodeError {
/// The decoded recovery key has an invalid prefix.
#[error("The decoded recovery key has an invalid prefix: expected {0:?}, got {1:?}")]
Prefix([u8; 2], [u8; 2]),
/// The parity byte of the recovery key didn't match.
#[error("The parity byte of the recovery key doesn't match: expected {0:?}, got {1:?}")]
Parity(u8, u8),
/// The recovery key has an invalid length.
#[error("The decoded recovery key has a invalid length: expected {0}, got {1}")]
Length(usize, usize),
/// The recovry key isn't valid base58.
#[error(transparent)]
Base58(#[from] bs58::decode::Error),
/// The recovery key isn't valid base64.
#[error(transparent)]
Base64(#[from] base64::DecodeError),
/// The recovery key is too short, we couldn't read enough data.
#[error(transparent)]
Io(#[from] std::io::Error),
}
/// Error type for the import of a pickled [`RecoveryKey`].
#[derive(Debug, Error)]
pub enum UnpicklingError {
    /// The pickle isn't valid JSON.
    #[error(transparent)]
    Json(#[from] serde_json::Error),
    /// The pickle couldn't be decrypted with the given pickle key.
    #[error("Couldn't decrypt the pickle: {0}")]
    Decryption(String),
    /// A field of the pickle couldn't be decoded.
    #[error(transparent)]
    Decode(#[from] DecodeError),
}
/// The private part of a backup key.
///
/// This is the private Curve25519 key that can decrypt room keys which were
/// backed up with the matching public [`MegolmV1BackupKey`].
#[derive(Zeroize)]
pub struct RecoveryKey {
    // Raw private key bytes; wiped from memory when the key is dropped.
    inner: [u8; RecoveryKey::KEY_SIZE],
}
impl Drop for RecoveryKey {
fn drop(&mut self) {
self.inner.zeroize()
}
}
impl std::fmt::Debug for RecoveryKey {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately print no fields so the private key bytes can never
        // leak into logs or debug output.
        formatter.debug_struct("RecoveryKey").finish()
    }
}
/// The pickled version of a recovery key.
///
/// Holds the JSON-serialized, AES-GCM encrypted export produced by
/// [`RecoveryKey::pickle`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PickledRecoveryKey(String);

impl AsRef<str> for PickledRecoveryKey {
    fn as_ref(&self) -> &str {
        &self.0
    }
}
/// Serialization format of the encrypted recovery-key pickle.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct InnerPickle {
    // Format version of the pickle; currently always 1.
    version: u8,
    // URL-safe base64 encoded AES-GCM nonce.
    nonce: String,
    // URL-safe base64 encoded ciphertext of the private key.
    ciphertext: String,
}
impl TryFrom<String> for RecoveryKey {
type Error = DecodeError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::from_base58(&value)
}
}
impl std::fmt::Display for RecoveryKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render the base58 form as space-separated groups of
        // `DISPLAY_CHUNK_SIZE` characters for readability. Every intermediate
        // string contains key material, so keep them zeroized.
        let encoded = Zeroizing::new(self.to_base58());
        let chars: Vec<char> = encoded.chars().collect();

        let grouped = Zeroizing::new(
            chars
                .chunks(Self::DISPLAY_CHUNK_SIZE)
                .map(|chunk| chunk.iter().collect::<String>())
                .collect::<Vec<String>>()
                .join(" "),
        );

        write!(f, "{}", grouped.as_str())
    }
}
impl RecoveryKey {
    /// Size of the private key in bytes.
    const KEY_SIZE: usize = 32;
    /// Prefix bytes every base58 encoded recovery key starts with.
    const PREFIX: [u8; 2] = [0x8b, 0x01];
    /// XOR of the prefix bytes; the starting value for the parity byte.
    const PREFIX_PARITY: u8 = Self::PREFIX[0] ^ Self::PREFIX[1];
    /// Number of characters per group when displaying the base58 form.
    const DISPLAY_CHUNK_SIZE: usize = 4;

    /// Calculate the parity byte over the given key bytes, folding the prefix
    /// parity in as the starting value.
    fn parity_byte(bytes: &[u8]) -> u8 {
        bytes.iter().fold(Self::PREFIX_PARITY, |acc, x| acc ^ x)
    }

    /// Create a new random recovery key.
    ///
    /// # Errors
    ///
    /// Fails only if the system random number generator can't produce enough
    /// randomness.
    pub fn new() -> Result<Self, RandomError> {
        let mut rng = thread_rng();

        let mut key = [0u8; Self::KEY_SIZE];
        key.try_fill(&mut rng)?;

        Ok(Self { inner: key })
    }

    /// Create a new recovery key from the given byte array.
    ///
    /// **Warning**: You need to make sure that the byte array contains correct
    /// random data, either by using a random number generator or by using an
    /// exported version of a previously created [`RecoveryKey`].
    pub fn from_bytes(key: [u8; Self::KEY_SIZE]) -> Self {
        Self { inner: key }
    }

    /// Try to create a [`RecoveryKey`] from a base64 export of a `RecoveryKey`.
    pub fn from_base64(key: &str) -> Result<Self, DecodeError> {
        let decoded = Zeroizing::new(crate::utilities::decode(key)?);

        if decoded.len() != Self::KEY_SIZE {
            // `DecodeError::Length` takes (expected, got), matching
            // `MegolmV1BackupKey::from_base64`.
            Err(DecodeError::Length(Self::KEY_SIZE, decoded.len()))
        } else {
            let mut key = [0u8; Self::KEY_SIZE];
            key.copy_from_slice(&decoded);

            Ok(Self::from_bytes(key))
        }
    }

    /// Export the `RecoveryKey` as a base64 encoded string.
    pub fn to_base64(&self) -> String {
        encode(self.inner)
    }

    /// Try to create a [`RecoveryKey`] from a base58 export of a `RecoveryKey`.
    ///
    /// The base58 form carries a two byte prefix and a trailing parity byte,
    /// both of which are validated here. Whitespace in the input is ignored.
    pub fn from_base58(value: &str) -> Result<Self, DecodeError> {
        // Remove any whitespace we might have
        let value: String = value.chars().filter(|c| !c.is_whitespace()).collect();

        let decoded = bs58::decode(value).with_alphabet(bs58::Alphabet::BITCOIN).into_vec()?;
        let mut decoded = Cursor::new(decoded);

        let mut prefix = [0u8; 2];
        let mut key = [0u8; Self::KEY_SIZE];
        let mut expected_parity = [0u8; 1];

        decoded.read_exact(&mut prefix)?;
        decoded.read_exact(&mut key)?;
        decoded.read_exact(&mut expected_parity)?;

        let expected_parity = expected_parity[0];
        let parity = Self::parity_byte(key.as_ref());

        // The decoded buffer contains the private key, zeroize it before it
        // gets dropped.
        let _ = Zeroizing::new(decoded.into_inner());

        if prefix != Self::PREFIX {
            Err(DecodeError::Prefix(Self::PREFIX, prefix))
        } else if expected_parity != parity {
            Err(DecodeError::Parity(expected_parity, parity))
        } else {
            Ok(Self::from_bytes(key))
        }
    }

    /// Export the `RecoveryKey` as a base58 encoded string.
    pub fn to_base58(&self) -> String {
        // Layout: PREFIX || key bytes || parity byte.
        let bytes = Zeroizing::new(
            [
                Self::PREFIX.as_ref(),
                self.inner.as_ref(),
                [Self::parity_byte(self.inner.as_ref())].as_ref(),
            ]
            .concat(),
        );

        bs58::encode(bytes.as_slice()).with_alphabet(bs58::Alphabet::BITCOIN).into_string()
    }

    /// Construct a libolm `OlmPkDecryption` object from our private key bytes.
    fn get_pk_decryption(&self) -> OlmPkDecryption {
        OlmPkDecryption::from_bytes(self.inner.as_ref())
            .expect("Can't generate a libolm PkDecryption object from our private key")
    }

    /// Extract the megolm.v1 public key from this `RecoveryKey`.
    pub fn megolm_v1_public_key(&self) -> MegolmV1BackupKey {
        let pk = self.get_pk_decryption();
        MegolmV1BackupKey::new(pk.public_key(), None)
    }

    /// Export this [`RecoveryKey`] as an encrypted pickle that can be safely
    /// stored.
    pub fn pickle(&self, pickle_key: &[u8]) -> PickledRecoveryKey {
        let key = GenericArray::from_slice(pickle_key);
        let cipher = Aes256Gcm::new(key);

        // AES-GCM requires a fresh random nonce for every encryption.
        let mut nonce = vec![0u8; NONCE_SIZE];
        let mut rng = thread_rng();
        nonce.try_fill(&mut rng).expect("Can't generate random nonce to pickle the recovery key");
        let nonce = GenericArray::from_slice(nonce.as_slice());

        let ciphertext =
            cipher.encrypt(nonce, self.inner.as_ref()).expect("Can't encrypt recovery key");
        let ciphertext = encode_url_safe(ciphertext);

        let pickle =
            InnerPickle { version: 1, nonce: encode_url_safe(nonce.as_slice()), ciphertext };

        PickledRecoveryKey(
            serde_json::to_string(&pickle).expect("Can't encode the pickled recovery key"),
        )
    }

    /// Try to import a `RecoveryKey` from a previously exported pickle.
    pub fn from_pickle(
        pickle: PickledRecoveryKey,
        pickle_key: &[u8],
    ) -> Result<Self, UnpicklingError> {
        let pickled: InnerPickle = serde_json::from_str(pickle.as_ref())?;

        let key = GenericArray::from_slice(pickle_key);
        let cipher = Aes256Gcm::new(key);

        let nonce = decode_url_safe(pickled.nonce).map_err(DecodeError::from)?;
        let nonce = GenericArray::from_slice(&nonce);
        let ciphertext = &decode_url_safe(pickled.ciphertext).map_err(DecodeError::from)?;

        // The plaintext is our private key, keep it zeroized.
        let decrypted = Zeroizing::new(
            cipher
                .decrypt(nonce, ciphertext.as_slice())
                .map_err(|e| UnpicklingError::Decryption(e.to_string()))?,
        );

        if decrypted.len() != Self::KEY_SIZE {
            // `DecodeError::Length` takes (expected, got).
            Err(DecodeError::Length(Self::KEY_SIZE, decrypted.len()).into())
        } else {
            let mut key = [0u8; Self::KEY_SIZE];
            key.copy_from_slice(&decrypted);

            Ok(Self { inner: key })
        }
    }

    /// Try to decrypt the given ciphertext using this `RecoveryKey`.
    ///
    /// This will use the [`m.megolm_backup.v1.curve25519-aes-sha2`] algorithm
    /// to decrypt the given ciphertext.
    ///
    /// [`m.megolm_backup.v1.curve25519-aes-sha2`]:
    /// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
    pub fn decrypt_v1(
        &self,
        mac: String,
        ephemeral_key: String,
        ciphertext: String,
    ) -> Result<String, OlmPkDecryptionError> {
        let message = PkMessage::new(mac, ephemeral_key, ciphertext);
        let pk = self.get_pk_decryption();

        pk.decrypt(message)
    }
}
#[cfg(test)]
mod test {
    use super::{DecodeError, RecoveryKey};

    // Well-known Curve25519 test vector used to check base58 decoding.
    const TEST_KEY: [u8; 32] = [
        0x77, 0x07, 0x6D, 0x0A, 0x73, 0x18, 0xA5, 0x7D, 0x3C, 0x16, 0xC1, 0x72, 0x51, 0xB2, 0x66,
        0x45, 0xDF, 0x4C, 0x2F, 0x87, 0xEB, 0xC0, 0x99, 0x2A, 0xB1, 0x77, 0xFB, 0xA5, 0x1D, 0xB9,
        0x2C, 0x2A,
    ];

    #[test]
    fn base64_decoding() -> Result<(), DecodeError> {
        let key = RecoveryKey::new().expect("Can't create a new recovery key");

        let base64 = key.to_base64();
        let decoded_key = RecoveryKey::from_base64(&base64)?;
        assert_eq!(key.inner, decoded_key.inner, "The decoded key doesn't match the original");

        RecoveryKey::from_base64("i").expect_err("The recovery key is too short");

        Ok(())
    }

    #[test]
    fn base58_decoding() -> Result<(), DecodeError> {
        let key = RecoveryKey::new().expect("Can't create a new recovery key");

        // Round trip through the base58 encoding.
        let base58 = key.to_base58();
        let decoded_key = RecoveryKey::from_base58(&base58)?;
        assert_eq!(key.inner, decoded_key.inner, "The decoded key doesn't match the original");

        let test_key =
            RecoveryKey::from_base58("EsTcLW2KPGiFwKEA3As5g5c4BXwkqeeJZJV8Q9fugUMNUE4d")?;
        assert_eq!(test_key.inner, TEST_KEY, "The decoded recovery key doesn't match the test key");

        // Whitespace in the input must be ignored.
        let test_key = RecoveryKey::from_base58(
            "EsTc LW2K PGiF wKEA 3As5 g5c4 BXwk qeeJ ZJV8 Q9fu gUMN UE4d",
        )?;
        assert_eq!(test_key.inner, TEST_KEY, "The decoded recovery key doesn't match the test key");

        RecoveryKey::from_base58("EsTc LW2K PGiF wKEA 3As5 g5c4 BXwk qeeJ ZJV8 Q9fu gUMN UE4e")
            .expect_err("Can't create a recovery key if the parity byte is invalid");

        Ok(())
    }
}

View File

@@ -0,0 +1,519 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Server-side backup support for room keys
//!
//! This module largely implements support for server-side backups using the
//! `m.megolm_backup.v1.curve25519-aes-sha2` backup algorithm.
//!
//! Due to various flaws in this backup algorithm it is **not** recommended to
//! use this module or any of its functionality. The module is only provided for
//! backwards compatibility.
//!
//! [spec]: https://spec.matrix.org/unstable/client-server-api/#server-side-key-backups
use std::{
collections::{BTreeMap, BTreeSet},
sync::Arc,
};
use matrix_sdk_common::{locks::RwLock, uuid::Uuid};
use ruma::{
api::client::r0::backup::RoomKeyBackup, DeviceKeyAlgorithm, DeviceKeyId, RoomId, UserId,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::{debug, info, instrument, trace, warn};
use crate::{
olm::{Account, InboundGroupSession},
store::{BackupKeys, Changes, RoomKeyCounts, Store},
CryptoStoreError, KeysBackupRequest, OutgoingRequest,
};
mod keys;
pub use keys::{DecodeError, MegolmV1BackupKey, PickledRecoveryKey, RecoveryKey};
pub use olm_rs::errors::OlmPkDecryptionError;
/// A state machine that handles backing up room keys.
///
/// The state machine can be activated using the
/// [`BackupMachine::enable_backup_v1`] method. After the state machine has been
/// enabled a request that will upload encrypted room keys can be generated
/// using the [`BackupMachine::backup`] method.
#[derive(Debug, Clone)]
pub struct BackupMachine {
    // Our own Olm account, used to check auth-data signatures made by our own
    // device.
    account: Account,
    // The crypto store holding the room keys that get backed up.
    store: Store,
    // The currently active backup key, if any has been enabled.
    backup_key: Arc<RwLock<Option<MegolmV1BackupKey>>>,
    // A backup request that has been created but not yet marked as sent;
    // cleared in `mark_request_as_sent`.
    pending_backup: Arc<RwLock<Option<PendingBackup>>>,
}
/// A key backup request that has been created but not yet sent to the server.
#[derive(Debug, Clone)]
struct PendingBackup {
    // Unique id used to match the response to this request.
    request_id: Uuid,
    // The room key upload request itself.
    request: KeysBackupRequest,
    // Record of which sessions are part of this backup:
    // room id -> sender key -> set of session ids.
    sessions: BTreeMap<RoomId, BTreeMap<String, BTreeSet<String>>>,
}
impl PendingBackup {
    /// Was the given inbound group session included in this pending backup?
    fn session_was_part_of_the_backup(&self, session: &InboundGroupSession) -> bool {
        // Walk the room id -> sender key -> session id record; any missing
        // level means the session wasn't part of the backup.
        self.sessions
            .get(session.room_id())
            .and_then(|senders| senders.get(session.sender_key()))
            .map(|session_ids| session_ids.contains(session.session_id()))
            .unwrap_or(false)
    }
}
impl From<PendingBackup> for OutgoingRequest {
    /// Convert a pending backup into an outgoing request, keeping its id.
    fn from(backup: PendingBackup) -> Self {
        Self { request_id: backup.request_id, request: Arc::new(backup.request.into()) }
    }
}
impl BackupMachine {
const BACKUP_BATCH_SIZE: usize = 100;
pub(crate) fn new(
account: Account,
store: Store,
backup_key: Option<MegolmV1BackupKey>,
) -> Self {
Self {
account,
store,
backup_key: RwLock::new(backup_key).into(),
pending_backup: RwLock::new(None).into(),
}
}
/// Are we able to back up room keys to the server?
///
/// Backing up is only possible once a backup key with an attached backup
/// version has been activated.
pub async fn enabled(&self) -> bool {
    let backup_key = self.backup_key.read().await;

    match backup_key.as_ref() {
        Some(key) => key.backup_version().is_some(),
        None => false,
    }
}
/// Verify some backup auth data that we downloaded from the server.
///
/// The auth data should be fetched from the server using the
/// [`/room_keys/version`] endpoint.
///
/// Ruma models this using the [`BackupAlgorithm`] struct, but care needs to
/// be taken if that struct is used. Some clients might use unspecced fields
/// and the given Ruma struct will lose the unspecced fields after a
/// serialization cycle.
///
/// Returns `Ok(true)` as soon as one valid signature from our own device or
/// from one of our verified devices is found, `Ok(false)` otherwise.
///
/// [`BackupAlgorithm`]: ruma::api::client::r0::backup::BackupAlgorithm
/// [`/room_keys/version`]: https://spec.matrix.org/unstable/client-server-api/#get_matrixclientv3room_keysversion
pub async fn verify_backup(
    &self,
    mut serialized_auth_data: Value,
) -> Result<bool, CryptoStoreError> {
    // Minimal view of the auth data: keep only the public key and the
    // signatures, while `extra` preserves unspecced fields through the
    // serialization cycle.
    #[derive(Debug, Serialize, Deserialize)]
    struct AuthData {
        public_key: String,
        #[serde(default)]
        signatures: BTreeMap<UserId, BTreeMap<DeviceKeyId, String>>,
        #[serde(flatten)]
        extra: BTreeMap<String, Value>,
    }

    let auth_data: AuthData = serde_json::from_value(serialized_auth_data.clone())?;

    trace!(?auth_data, "Verifying backup auth data");

    // Only signatures made by our own user are considered.
    Ok(if let Some(signatures) = auth_data.signatures.get(self.store.user_id()) {
        for device_key_id in signatures.keys() {
            // Only Ed25519 signing keys can have made a valid signature.
            if device_key_id.algorithm() == DeviceKeyAlgorithm::Ed25519 {
                if device_key_id.device_id() == self.account.device_id() {
                    // Our own device: check the signature with our account.
                    let result = self.account.is_signed(&mut serialized_auth_data);

                    trace!(?result, "Checking auth data signature of our own device");

                    if result.is_ok() {
                        return Ok(true);
                    }
                } else {
                    // Another one of our devices: look it up in the store
                    // and accept only if it is verified and its signature
                    // checks out.
                    let device = self
                        .store
                        .get_device(self.store.user_id(), device_key_id.device_id())
                        .await?;

                    trace!(
                        device_id = device_key_id.device_id().as_str(),
                        "Checking backup auth data for device"
                    );

                    if let Some(device) = device {
                        if device.verified()
                            && device.is_signed_by_device(&mut serialized_auth_data).is_ok()
                        {
                            return Ok(true);
                        }
                    } else {
                        trace!(
                            device_id = device_key_id.device_id().as_str(),
                            "Device not found, can't check signature"
                        );
                    }
                }
            }
        }

        false
    } else {
        false
    })
}
/// Activate the given backup key to be used to encrypt and backup room
/// keys.
///
/// This will use the [`m.megolm_backup.v1.curve25519-aes-sha2`] algorithm
/// to encrypt the room keys.
///
/// If the key has no backup version attached, activation is skipped with a
/// warning; this method still returns `Ok(())` in that case.
///
/// [`m.megolm_backup.v1.curve25519-aes-sha2`]:
/// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
pub async fn enable_backup_v1(&self, key: MegolmV1BackupKey) -> Result<(), CryptoStoreError> {
    if key.backup_version().is_some() {
        *self.backup_key.write().await = Some(key.clone());
        info!(backup_key =? key, "Activated a backup");
    } else {
        warn!(backup_key =? key, "Tried to activate a backup without having the backup key uploaded");
    }

    Ok(())
}
/// Get the number of backed up room keys and the total number of room keys.
///
/// Delegates directly to the store's inbound group session counters.
pub async fn room_key_counts(&self) -> Result<RoomKeyCounts, CryptoStoreError> {
    self.store.inbound_group_session_counts().await
}
/// Disable and reset our backup state.
///
/// This will remove any pending backup request, remove the backup key and
/// reset the backup state of each room key we have.
#[instrument(skip(self))]
pub async fn disable_backup(&self) -> Result<(), CryptoStoreError> {
    debug!("Disabling key backup and resetting backup state for room keys");

    // Drop the active key and any not-yet-sent backup request.
    self.backup_key.write().await.take();
    self.pending_backup.write().await.take();

    // Mark every room key in the store as not backed up.
    self.store.reset_backup_state().await?;

    debug!("Done disabling backup");

    Ok(())
}
/// Store the recovery key in the cryptostore.
///
/// This is useful if the client wants to support gossiping of the backup
/// key.
pub async fn save_recovery_key(
&self,
recovery_key: Option<RecoveryKey>,
version: Option<String>,
) -> Result<(), CryptoStoreError> {
let changes = Changes { recovery_key, backup_version: version, ..Default::default() };
self.store.save_changes(changes).await
}
/// Get the backup keys we have saved in our crypto store.
///
/// Loads whatever recovery key and backup version were previously saved
/// via [`BackupMachine::save_recovery_key`].
pub async fn get_backup_keys(&self) -> Result<BackupKeys, CryptoStoreError> {
    self.store.load_backup_keys().await
}
/// Encrypt a batch of room keys and return a request that needs to be sent
/// out to backup the room keys.
///
/// If a previously created request hasn't been marked as sent yet, that
/// same request is returned again instead of creating a new one. Returns
/// `Ok(None)` when there is nothing to back up.
pub async fn backup(&self) -> Result<Option<OutgoingRequest>, CryptoStoreError> {
    // Hold the write lock for the whole check-then-create sequence so two
    // concurrent callers can't both create a new request.
    let mut request = self.pending_backup.write().await;

    if let Some(request) = &*request {
        trace!("Backing up, returning an existing request");

        Ok(Some(request.clone().into()))
    } else {
        trace!("Backing up, creating a new request");

        let new_request = self.backup_helper().await?;
        *request = new_request.clone();

        Ok(new_request.map(|r| r.into()))
    }
}
/// Mark the pending backup request with the given id as sent.
///
/// Flags every room key that was part of the request as backed up in the
/// store and clears the pending request. If the id doesn't match the
/// pending request, or no request is pending, only a warning is logged.
pub(crate) async fn mark_request_as_sent(
    &self,
    request_id: Uuid,
) -> Result<(), CryptoStoreError> {
    let mut request = self.pending_backup.write().await;

    if let Some(r) = &*request {
        if r.request_id == request_id {
            // Re-fetch the sessions and keep only the ones recorded as
            // part of this backup request.
            let sessions: Vec<_> = self
                .store
                .get_inbound_group_sessions()
                .await?
                .into_iter()
                .filter(|s| r.session_was_part_of_the_backup(s))
                .collect();

            for session in &sessions {
                session.mark_as_backed_up()
            }

            trace!(request_id =? r.request_id, keys =? r.sessions, "Marking room keys as backed up");

            // Persist the updated backed-up flags.
            let changes = Changes { inbound_group_sessions: sessions, ..Default::default() };
            self.store.save_changes(changes).await?;

            let counts = self.store.inbound_group_session_counts().await?;

            trace!(
                room_key_counts =? counts,
                request_id =? r.request_id, keys =? r.sessions, "Marked room keys as backed up"
            );

            // Clear the pending request so the next `backup()` call can
            // create a fresh one.
            *request = None;
        } else {
            warn!(
                expected = r.request_id.to_string().as_str(),
                got = request_id.to_string().as_str(),
                "Tried to mark a pending backup as sent but the request id didn't match"
            )
        }
    } else {
        warn!(
            request_id = request_id.to_string().as_str(),
            "Tried to mark a pending backup as sent but there isn't a backup pending"
        );
    }

    Ok(())
}
/// Assemble a fresh [`PendingBackup`] request, if there's anything that
/// still needs to be backed up and the backup is fully enabled.
async fn backup_helper(&self) -> Result<Option<PendingBackup>, CryptoStoreError> {
    // Without an enabled backup key there's nothing we can encrypt to.
    let guard = self.backup_key.read().await;

    let backup_key = match guard.as_ref() {
        Some(key) => key,
        None => {
            warn!("Trying to backup room keys but no backup key was found");
            return Ok(None);
        }
    };

    // The key also needs to know which server-side backup version it
    // belongs to before we can build a request.
    let version = match backup_key.backup_version() {
        Some(version) => version,
        None => {
            warn!("Trying to backup room keys but the backup key wasn't uploaded");
            return Ok(None);
        }
    };

    let sessions =
        self.store.inbound_group_sessions_for_backup(Self::BACKUP_BATCH_SIZE).await?;

    if sessions.is_empty() {
        trace!(?backup_key, "No room keys need to be backed up");
        return Ok(None);
    }

    let key_count = sessions.len();
    let (backup, session_record) = Self::backup_keys(sessions, backup_key).await;

    info!(
        key_count = key_count,
        keys =? session_record,
        backup_key =? backup_key,
        "Successfully created a room keys backup request"
    );

    Ok(Some(PendingBackup {
        request_id: Uuid::new_v4(),
        request: KeysBackupRequest { version, rooms: backup },
        sessions: session_record,
    }))
}
/// Backup all the non-backed up room keys we know about
async fn backup_keys(
    sessions: Vec<InboundGroupSession>,
    backup_key: &MegolmV1BackupKey,
) -> (BTreeMap<RoomId, RoomKeyBackup>, BTreeMap<RoomId, BTreeMap<String, BTreeSet<String>>>)
{
    let mut backup: BTreeMap<RoomId, RoomKeyBackup> = BTreeMap::new();
    let mut session_record: BTreeMap<RoomId, BTreeMap<String, BTreeSet<String>>> =
        BTreeMap::new();

    for session in sessions {
        // Remember the identifiers before the session is consumed by the
        // encryption step below.
        let room_id = session.room_id().to_owned();
        let session_id = session.session_id().to_owned();
        let sender_key = session.sender_key().to_owned();

        let encrypted = backup_key.encrypt(session).await;

        // Record which (room, sender key, session id) triple ended up in
        // this batch so we can mark those keys as backed up later.
        session_record
            .entry(room_id.clone())
            .or_default()
            .entry(sender_key)
            .or_default()
            .insert(session_id.clone());

        backup
            .entry(room_id)
            .or_insert_with(|| RoomKeyBackup::new(BTreeMap::new()))
            .sessions
            .insert(session_id, encrypted);
    }

    (backup, session_record)
}
}
#[cfg(test)]
mod test {
    use matrix_sdk_test::async_test;
    use ruma::{room_id, user_id, DeviceIdBox, RoomId, UserId};

    use super::RecoveryKey;
    use crate::{OlmError, OlmMachine};

    // User id shared by every test in this module.
    fn alice_id() -> UserId {
        user_id!("@alice:example.org")
    }

    // Device id shared by every test in this module.
    fn alice_device_id() -> DeviceIdBox {
        "JLAFKJWSCS".into()
    }

    fn room_id() -> RoomId {
        room_id!("!test:localhost")
    }

    fn room_id2() -> RoomId {
        room_id!("!test2:localhost")
    }

    // Drive a full backup cycle: create room keys, back them up, mark the
    // request as sent and finally disable the backup again.
    async fn backup_flow(machine: OlmMachine) -> Result<(), OlmError> {
        let backup_machine = machine.backup_machine();

        let key_counts = backup_machine.store.inbound_group_session_counts().await?;
        assert_eq!(key_counts.total, 0, "Initially no keys exist");
        assert_eq!(key_counts.backed_up, 0, "Initially no backed up keys exist");

        machine.create_outbound_group_session_with_defaults(&room_id()).await?;
        machine.create_outbound_group_session_with_defaults(&room_id2()).await?;

        let key_counts = backup_machine.store.inbound_group_session_counts().await?;
        assert_eq!(key_counts.total, 2, "Two room keys need to exist in the store");
        assert_eq!(key_counts.backed_up, 0, "No room keys have been backed up yet");

        let recovery_key = RecoveryKey::new().expect("Can't create new recovery key");
        let backup_key = recovery_key.megolm_v1_public_key();
        backup_key.set_version("1".to_owned());
        backup_machine.enable_backup_v1(backup_key).await?;

        let backup_request =
            backup_machine.backup().await?.expect("Created a backup request successfully");
        assert_eq!(
            Some(backup_request.request_id),
            backup_machine.backup().await?.map(|r| r.request_id),
            "Calling backup again without uploading creates the same backup request"
        );

        backup_machine.mark_request_as_sent(backup_request.request_id).await?;

        let key_counts = backup_machine.store.inbound_group_session_counts().await?;
        assert_eq!(key_counts.total, 2);
        assert_eq!(key_counts.backed_up, 2, "All room keys have been backed up");

        assert!(
            backup_machine.backup().await?.is_none(),
            "No room keys need to be backed up, no request needs to be created"
        );

        backup_machine.disable_backup().await?;

        let key_counts = backup_machine.store.inbound_group_session_counts().await?;
        assert_eq!(key_counts.total, 2);
        assert_eq!(
            key_counts.backed_up, 0,
            "Disabling the backup resets the backup flag on the room keys"
        );

        Ok(())
    }

    #[async_test]
    async fn memory_store_backups() -> Result<(), OlmError> {
        let machine = OlmMachine::new(&alice_id(), &alice_device_id());

        backup_flow(machine).await
    }

    #[async_test]
    #[cfg(feature = "sled_cryptostore")]
    async fn default_store_backups() -> Result<(), OlmError> {
        use tempfile::tempdir;

        let tmpdir = tempdir().expect("Can't create a temporary dir");
        let machine = OlmMachine::new_with_default_store(
            &alice_id(),
            &alice_device_id(),
            tmpdir.as_ref(),
            None,
        )
        .await?;

        backup_flow(machine).await
    }

    #[async_test]
    #[cfg(feature = "sled_cryptostore")]
    async fn recovery_key_storing() -> Result<(), OlmError> {
        use tempfile::tempdir;

        let tmpdir = tempdir().expect("Can't create a temporary dir");
        let machine = OlmMachine::new_with_default_store(
            &alice_id(),
            &alice_device_id(),
            tmpdir.as_ref(),
            Some("test"),
        )
        .await?;
        let backup_machine = machine.backup_machine();

        let recovery_key = RecoveryKey::new().expect("Can't create new recovery key");
        let encoded_key = recovery_key.to_base64();

        backup_machine.save_recovery_key(Some(recovery_key), Some("1".to_owned())).await?;

        // Load everything back and check it round-trips through the store.
        let loaded_backup = backup_machine.get_backup_keys().await?;

        assert_eq!(
            encoded_key,
            loaded_backup
                .recovery_key
                .expect("The recovery key wasn't loaded from the store")
                .to_base64(),
            "The loaded key matches to the one we stored"
        );

        assert_eq!(
            Some("1"),
            loaded_backup.backup_version.as_deref(),
            "The loaded version matches to the one we stored"
        );

        Ok(())
    }
}

View File

@@ -770,6 +770,8 @@ impl GossipMachine {
return Ok(None);
};
let secret = std::mem::take(&mut event.content.secret);
if let Some(request) = self.store.get_outgoing_secret_requests(request_id).await? {
match &request.info {
SecretInfo::KeyRequest(_) => {
@@ -791,30 +793,32 @@ impl GossipMachine {
self.store.get_device_from_curve_key(&event.sender, sender_key).await?
{
if device.verified() {
match self
.store
.import_secret(
secret_name,
std::mem::take(&mut event.content.secret),
)
.await
{
Ok(_) => self.mark_as_done(request).await?,
Err(e) => {
// If this is a store error propagate it up
// the call stack.
if let SecretImportError::Store(e) = e {
return Err(e);
} else {
// Otherwise warn that there was
// something wrong with the secret.
warn!(
secret_name = secret_name.as_ref(),
error =? e,
"Error while importing a secret"
)
if secret_name != &SecretName::RecoveryKey {
match self.store.import_secret(secret_name, secret).await {
Ok(_) => self.mark_as_done(request).await?,
Err(e) => {
// If this is a store error propagate it up
// the call stack.
if let SecretImportError::Store(e) = e {
return Err(e);
} else {
// Otherwise warn that there was
// something wrong with the secret.
warn!(
secret_name = secret_name.as_ref(),
error =? e,
"Error while importing a secret"
)
}
}
}
} else {
// Skip importing the recovery key here since
// we'll want to check if the public key matches
// to the latest version on the server. We
// instead leave the key in the event and let
// the user import it later.
event.content.secret = secret;
}
} else {
warn!(

View File

@@ -539,7 +539,7 @@ impl ReadOnlyDevice {
Ok(())
}
fn is_signed_by_device(&self, json: &mut Value) -> Result<(), SignatureError> {
pub(crate) fn is_signed_by_device(&self, json: &mut Value) -> Result<(), SignatureError> {
let signing_key =
self.get_key(DeviceKeyAlgorithm::Ed25519).ok_or(SignatureError::MissingSigningKey)?;
@@ -623,9 +623,6 @@ impl PartialEq for ReadOnlyDevice {
#[cfg(test)]
pub(crate) mod test {
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::wasm_bindgen_test;
use matrix_sdk_test::async_test;
use std::convert::TryFrom;
use ruma::{encryption::DeviceKeys, user_id, DeviceKeyAlgorithm};

View File

@@ -30,6 +30,9 @@
compile_error!("indexeddb_cryptostore only works for wasm32 target");
#[cfg(feature = "backups_v1")]
#[cfg_attr(feature = "docs", doc(cfg(backups_v1)))]
pub mod backups;
mod error;
mod file_encryption;
mod gossiping;
@@ -72,7 +75,7 @@ pub use matrix_qrcode;
pub(crate) use olm::ReadOnlyAccount;
pub use olm::{CrossSigningStatus, EncryptionSettings};
pub use requests::{
IncomingResponse, KeysQueryRequest, OutgoingRequest, OutgoingRequests,
IncomingResponse, KeysBackupRequest, KeysQueryRequest, OutgoingRequest, OutgoingRequests,
OutgoingVerificationRequest, RoomMessageRequest, ToDeviceRequest, UploadSigningKeysRequest,
};
pub use store::{CrossSigningKeyExport, CryptoStoreError, SecretImportError};

View File

@@ -46,11 +46,14 @@ use ruma::{
secret::request::SecretName,
AnyMessageEventContent, AnyRoomEvent, AnyToDeviceEvent, EventContent,
},
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, EventEncryptionAlgorithm, RoomId, UInt, UserId,
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, UInt,
UserId,
};
use serde_json::Value;
use tracing::{debug, error, info, trace, warn};
#[cfg(feature = "backups_v1")]
use crate::backups::BackupMachine;
#[cfg(feature = "sled_cryptostore")]
use crate::store::sled::SledStore;
use crate::{
@@ -104,6 +107,9 @@ pub struct OlmMachine {
/// State machine handling public user identities and devices, keeping track
/// of when a key query needs to be done and handling one.
identity_manager: IdentityManager,
/// A state machine that handles creating room key backups.
#[cfg(feature = "backups_v1")]
backup_machine: BackupMachine,
}
#[cfg(not(tarpaulin_include))]
@@ -180,6 +186,9 @@ impl OlmMachine {
let identity_manager =
IdentityManager::new(user_id.clone(), device_id.clone(), store.clone());
#[cfg(feature = "backups_v1")]
let backup_machine = BackupMachine::new(account.clone(), store.clone(), None);
OlmMachine {
user_id,
device_id,
@@ -191,6 +200,8 @@ impl OlmMachine {
verification_machine,
key_request_machine,
identity_manager,
#[cfg(feature = "backups_v1")]
backup_machine,
}
}
@@ -371,6 +382,10 @@ impl OlmMachine {
IncomingResponse::RoomMessage(_) => {
self.verification_machine.mark_request_as_sent(request_id);
}
IncomingResponse::KeysBackup(_) => {
#[cfg(feature = "backups_v1")]
self.backup_machine.mark_request_as_sent(*request_id).await?;
}
};
Ok(())
@@ -651,18 +666,18 @@ impl OlmMachine {
Ok(())
}
// #[cfg(test)]
// pub(crate) async fn create_inbound_session(
// &self,
// room_id: &RoomId,
// ) -> OlmResult<InboundGroupSession> {
// let (_, session) = self
// .group_session_manager
// .create_outbound_group_session(room_id, EncryptionSettings::default())
// .await?;
#[cfg(test)]
pub(crate) async fn create_inbound_session(
&self,
room_id: &RoomId,
) -> OlmResult<InboundGroupSession> {
let (_, session) = self
.group_session_manager
.create_outbound_group_session(room_id, EncryptionSettings::default())
.await?;
// Ok(session)
// }
Ok(session)
}
/// Encrypt a room message for the given room.
///
@@ -1453,6 +1468,62 @@ impl OlmMachine {
) -> Result<CrossSigningStatus, SecretImportError> {
self.store.import_cross_signing_keys(export).await
}
async fn sign_account(
&self,
message: &str,
signatures: &mut BTreeMap<UserId, BTreeMap<DeviceKeyId, String>>,
) {
let device_key_id = DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.device_id());
let signature = self.account.sign(message).await;
signatures.entry(self.user_id().to_owned()).or_default().insert(device_key_id, signature);
}
async fn sign_master(
&self,
message: &str,
signatures: &mut BTreeMap<UserId, BTreeMap<DeviceKeyId, String>>,
) -> Result<(), crate::SignatureError> {
let identity = &*self.user_identity.lock().await;
let master_key: DeviceIdBox = identity
.master_public_key()
.await
.and_then(|m| m.get_first_key().map(|k| k.to_owned()))
.ok_or(crate::SignatureError::MissingSigningKey)?
.into();
let device_key_id = DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &master_key);
let signature = identity.sign(message).await?;
signatures.entry(self.user_id().to_owned()).or_default().insert(device_key_id, signature);
Ok(())
}
/// Sign the given message using our device key and if available cross
/// signing master key.
pub async fn sign(&self, message: &str) -> BTreeMap<UserId, BTreeMap<DeviceKeyId, String>> {
let mut signatures: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
self.sign_account(message, &mut signatures).await;
if let Err(e) = self.sign_master(message, &mut signatures).await {
warn!(error =? e, "Couldn't sign the message using the cross signing master key")
}
signatures
}
/// Get a reference to the backup related state machine.
///
/// This state machine can be used to incrementally backup all room keys to
/// the server.
#[cfg(feature = "backups_v1")]
pub fn backup_machine(&self) -> &BackupMachine {
&self.backup_machine
}
}
#[cfg(test)]

View File

@@ -686,6 +686,19 @@ impl ReadOnlyAccount {
self.inner.lock().await.sign(string)
}
#[cfg(feature = "backups_v1")]
pub(crate) fn is_signed(&self, json: &mut Value) -> Result<(), SignatureError> {
let signing_key = self.identity_keys.ed25519();
let utility = crate::olm::Utility::new();
utility.verify_json(
&self.user_id,
&DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.device_id()),
signing_key,
json,
)
}
/// Store the account as a base64 encoded string.
///
/// # Arguments
@@ -806,7 +819,8 @@ impl ReadOnlyAccount {
) -> Result<SignatureUploadRequest, SignatureError> {
let public_key =
master_key.get_first_key().ok_or(SignatureError::MissingSigningKey)?.to_string();
let mut cross_signing_key = master_key.into();
let mut cross_signing_key: CrossSigningKey = master_key.into();
cross_signing_key.signatures.clear();
self.sign_cross_signing_key(&mut cross_signing_key).await?;
let mut signed_keys = BTreeMap::new();

View File

@@ -12,7 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::BTreeMap, convert::TryFrom, fmt, mem, sync::Arc};
use std::{
collections::BTreeMap,
convert::TryFrom,
fmt, mem,
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
};
use matrix_sdk_common::locks::Mutex;
pub use olm_rs::{
@@ -39,7 +47,7 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use zeroize::Zeroizing;
use super::{ExportedGroupSessionKey, ExportedRoomKey, GroupSessionKey};
use super::{BackedUpRoomKey, ExportedGroupSessionKey, ExportedRoomKey, GroupSessionKey};
use crate::error::{EventError, MegolmResult};
// TODO add creation times to the inbound group sessions so we can export
@@ -61,7 +69,7 @@ pub struct InboundGroupSession {
pub(crate) room_id: Arc<RoomId>,
forwarding_chains: Arc<Vec<String>>,
imported: bool,
backed_up: bool,
backed_up: Arc<AtomicBool>,
}
impl InboundGroupSession {
@@ -105,7 +113,7 @@ impl InboundGroupSession {
room_id: room_id.clone().into(),
forwarding_chains: Vec::new().into(),
imported: false,
backed_up: false,
backed_up: AtomicBool::new(false).into(),
})
}
@@ -123,6 +131,25 @@ impl InboundGroupSession {
Self::try_from(exported_session.into())
}
#[allow(dead_code)]
fn from_backup(
room_id: &RoomId,
backup: BackedUpRoomKey,
) -> Result<Self, OlmGroupSessionError> {
let session = OlmInboundGroupSession::new(&backup.session_key.0)?;
let session_id = session.session_id();
Self::from_export(ExportedRoomKey {
algorithm: backup.algorithm,
room_id: room_id.to_owned(),
sender_key: backup.sender_key,
session_id,
session_key: backup.session_key,
sender_claimed_keys: backup.sender_claimed_keys,
forwarding_curve25519_key_chain: backup.forwarding_curve25519_key_chain,
})
}
/// Create a new inbound group session from a forwarded room key content.
///
/// # Arguments
@@ -157,7 +184,7 @@ impl InboundGroupSession {
room_id: content.room_id.clone().into(),
forwarding_chains: forwarding_chains.into(),
imported: true,
backed_up: false,
backed_up: AtomicBool::new(false).into(),
})
}
@@ -177,7 +204,7 @@ impl InboundGroupSession {
room_id: (&*self.room_id).clone(),
forwarding_chains: self.forwarding_key_chain().to_vec(),
imported: self.imported,
backed_up: self.backed_up,
backed_up: self.backed_up(),
history_visibility: self.history_visibility.as_ref().clone(),
}
}
@@ -195,6 +222,21 @@ impl InboundGroupSession {
&self.sender_key
}
/// Has the session been backed up to the server.
pub fn backed_up(&self) -> bool {
self.backed_up.load(SeqCst)
}
/// Reset the backup state of the inbound group session.
pub(crate) fn reset_backup_state(&self) {
self.backed_up.store(false, SeqCst)
}
#[cfg(feature = "backups_v1")]
pub(crate) fn mark_as_backed_up(&self) {
self.backed_up.store(true, SeqCst)
}
/// Get the map of signing keys this session was received from.
pub fn signing_keys(&self) -> &BTreeMap<DeviceKeyAlgorithm, String> {
&self.signing_keys
@@ -256,7 +298,7 @@ impl InboundGroupSession {
signing_keys: pickle.signing_key.into(),
room_id: pickle.room_id.into(),
forwarding_chains: pickle.forwarding_chains.into(),
backed_up: pickle.backed_up,
backed_up: AtomicBool::from(pickle.backed_up).into(),
imported: pickle.imported,
})
}
@@ -291,6 +333,11 @@ impl InboundGroupSession {
self.inner.lock().await.decrypt(message)
}
#[cfg(feature = "backups_v1")]
pub(crate) async fn to_backup(&self) -> BackedUpRoomKey {
self.export().await.into()
}
/// Decrypt an event from a room timeline.
///
/// # Arguments
@@ -413,16 +460,16 @@ impl TryFrom<ExportedRoomKey> for InboundGroupSession {
let first_known_index = session.first_known_index();
Ok(InboundGroupSession {
inner: Arc::new(Mutex::new(session)),
inner: Mutex::new(session).into(),
session_id: key.session_id.into(),
sender_key: key.sender_key.into(),
history_visibility: None.into(),
first_known_index,
signing_keys: Arc::new(key.sender_claimed_keys),
room_id: Arc::new(key.room_id),
forwarding_chains: Arc::new(key.forwarding_curve25519_key_chain),
signing_keys: key.sender_claimed_keys.into(),
room_id: key.room_id.into(),
forwarding_chains: key.forwarding_curve25519_key_chain.into(),
imported: true,
backed_up: false,
backed_up: AtomicBool::from(false).into(),
})
}
}

View File

@@ -43,7 +43,7 @@ pub struct GroupSessionKey(pub String);
#[zeroize(drop)]
pub struct ExportedGroupSessionKey(pub String);
/// An exported version of a `InboundGroupSession`
/// An exported version of an `InboundGroupSession`
///
/// This can be used to share the `InboundGroupSession` in an exported file.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
@@ -71,6 +71,28 @@ pub struct ExportedRoomKey {
pub forwarding_curve25519_key_chain: Vec<String>,
}
/// A backed up version of an `InboundGroupSession`
///
/// This can be used to backup the `InboundGroupSession` to the server.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct BackedUpRoomKey {
/// The encryption algorithm that the session uses.
pub algorithm: EventEncryptionAlgorithm,
/// The Curve25519 key of the device which initiated the session originally.
pub sender_key: String,
/// The key for the session.
pub session_key: ExportedGroupSessionKey,
/// The Ed25519 key of the device which initiated the session originally.
pub sender_claimed_keys: BTreeMap<DeviceKeyAlgorithm, String>,
/// Chain of Curve25519 keys through which this session was forwarded, via
/// m.forwarded_room_key events.
pub forwarding_curve25519_key_chain: Vec<String>,
}
impl TryInto<ToDeviceForwardedRoomKeyEventContent> for ExportedRoomKey {
type Error = ();
@@ -104,6 +126,18 @@ impl TryInto<ToDeviceForwardedRoomKeyEventContent> for ExportedRoomKey {
}
}
impl From<ExportedRoomKey> for BackedUpRoomKey {
fn from(k: ExportedRoomKey) -> Self {
Self {
algorithm: k.algorithm,
sender_key: k.sender_key,
session_key: k.session_key,
sender_claimed_keys: k.sender_claimed_keys,
forwarding_curve25519_key_chain: k.forwarding_curve25519_key_chain,
}
}
}
impl From<ToDeviceForwardedRoomKeyEventContent> for ExportedRoomKey {
/// Convert the content of a forwarded room key into an exported room key.
fn from(forwarded_key: ToDeviceForwardedRoomKeyEventContent) -> Self {
@@ -125,6 +159,9 @@ impl From<ToDeviceForwardedRoomKeyEventContent> for ExportedRoomKey {
#[cfg(test)]
mod test {
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::wasm_bindgen_test;
use matrix_sdk_test::async_test;
use std::{
sync::Arc,
time::{Duration, Instant},

View File

@@ -458,6 +458,17 @@ impl PrivateCrossSigningIdentity {
Ok(SignatureUploadRequest::new(signed_keys))
}
pub(crate) async fn sign(&self, message: &str) -> Result<String, SignatureError> {
Ok(self
.master_key
.lock()
.await
.as_ref()
.ok_or(SignatureError::MissingSigningKey)?
.sign(message)
.await)
}
/// Create a new identity for the given Olm Account.
///
/// Returns the new identity, the upload signing keys request and a
@@ -770,6 +781,9 @@ mod test {
let master = user_signing.sign_user(&bob_public).await.unwrap();
let num_signatures: usize = master.signatures.iter().map(|(_, u)| u.len()).sum();
assert_eq!(num_signatures, 1, "We're only uploading our own signature");
bob_public.master_key = master.into();
user_signing.public_key.verify_master_key(bob_public.master_key()).unwrap();

View File

@@ -155,6 +155,10 @@ impl MasterSigning {
Ok(Self { inner, public_key: pickle.public_key.into() })
}
pub async fn sign(&self, message: &str) -> String {
self.inner.sign(message).await.0
}
pub async fn sign_subkey<'a>(&self, subkey: &mut CrossSigningKey) {
let subkey_without_signatures = json!({
"user_id": subkey.user_id.clone(),
@@ -164,7 +168,7 @@ impl MasterSigning {
let message = serde_json::to_string(&subkey_without_signatures)
.expect("Can't serialize cross signing subkey");
let signature = self.inner.sign(&message).await;
let signature = self.sign(&message).await;
subkey
.signatures
@@ -176,7 +180,7 @@ impl MasterSigning {
self.inner.public_key().as_str().into(),
)
.to_string(),
signature.0,
signature,
);
}
}
@@ -203,10 +207,10 @@ impl UserSigning {
&self,
user: &ReadOnlyUserIdentity,
) -> Result<CrossSigningKey, SignatureError> {
let signature = self.sign_user_helper(user).await?;
let signatures = self.sign_user_helper(user).await?;
let mut master_key: CrossSigningKey = user.master_key().to_owned().into();
master_key.signatures.extend(signature);
master_key.signatures = signatures;
Ok(master_key)
}

View File

@@ -17,6 +17,7 @@ use std::{collections::BTreeMap, iter, sync::Arc, time::Duration};
use matrix_sdk_common::uuid::Uuid;
use ruma::{
api::client::r0::{
backup::{add_backup_keys::Response as KeysBackupResponse, RoomKeyBackup},
keys::{
claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse},
get_keys::Response as KeysQueryResponse,
@@ -197,6 +198,8 @@ pub enum OutgoingRequests {
/// A room message request, usually for sending in-room interactive
/// verification events.
RoomMessage(RoomMessageRequest),
/// A request that will back up a batch of room keys to the server.
KeysBackup(KeysBackupRequest),
}
#[cfg(test)]
@@ -215,6 +218,12 @@ impl From<KeysQueryRequest> for OutgoingRequests {
}
}
impl From<KeysBackupRequest> for OutgoingRequests {
fn from(r: KeysBackupRequest) -> Self {
OutgoingRequests::KeysBackup(r)
}
}
impl From<KeysClaimRequest> for OutgoingRequests {
fn from(r: KeysClaimRequest) -> Self {
Self::KeysClaim(r)
@@ -278,6 +287,8 @@ pub enum IncomingResponse<'a> {
SignatureUpload(&'a SignatureUploadResponse),
/// A room message response, usually for interactive verifications.
RoomMessage(&'a RoomMessageResponse),
/// Response for the server-side room key backup request.
KeysBackup(&'a KeysBackupResponse),
}
impl<'a> From<&'a KeysUploadResponse> for IncomingResponse<'a> {
@@ -286,6 +297,12 @@ impl<'a> From<&'a KeysUploadResponse> for IncomingResponse<'a> {
}
}
impl<'a> From<&'a KeysBackupResponse> for IncomingResponse<'a> {
fn from(response: &'a KeysBackupResponse) -> Self {
IncomingResponse::KeysBackup(response)
}
}
impl<'a> From<&'a KeysQueryResponse> for IncomingResponse<'a> {
fn from(response: &'a KeysQueryResponse) -> Self {
IncomingResponse::KeysQuery(response)
@@ -356,6 +373,16 @@ pub struct RoomMessageRequest {
pub content: AnyMessageEventContent,
}
/// A request that will back up a batch of room keys to the server.
#[derive(Clone, Debug)]
pub struct KeysBackupRequest {
/// The backup version that these room keys should be part of.
pub version: String,
/// The map from room id to a backed up room key that we're going to upload
/// to the server.
pub rooms: BTreeMap<RoomId, RoomKeyBackup>,
}
/// An enum over the different outgoing verification based requests.
#[derive(Clone, Debug)]
pub enum OutgoingVerificationRequest {

View File

@@ -111,6 +111,11 @@ impl GroupSessionStore {
.collect()
}
/// Get the number of `InboundGroupSession`s we have.
pub fn count(&self) -> usize {
self.entries.iter().map(|d| d.value().values().map(|t| t.len()).sum::<usize>()).sum()
}
/// Get a inbound group session from our store.
///
/// # Arguments

View File

@@ -23,7 +23,8 @@ use ruma::{DeviceId, DeviceIdBox, RoomId, UserId};
use super::{
caches::{DeviceStore, GroupSessionStore, SessionStore},
Changes, CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session,
BackupKeys, Changes, CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, RoomKeyCounts,
Session,
};
use crate::{
gossiping::{GossipRequest, SecretInfo},
@@ -163,6 +164,34 @@ impl CryptoStore for MemoryStore {
Ok(self.inbound_group_sessions.get_all())
}
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
let backed_up =
self.get_inbound_group_sessions().await?.into_iter().filter(|s| s.backed_up()).count();
Ok(RoomKeyCounts { total: self.inbound_group_sessions.count(), backed_up })
}
async fn inbound_group_sessions_for_backup(
&self,
limit: usize,
) -> Result<Vec<InboundGroupSession>> {
Ok(self
.get_inbound_group_sessions()
.await?
.into_iter()
.filter(|s| !s.backed_up())
.take(limit)
.collect())
}
async fn reset_backup_state(&self) -> Result<()> {
for session in self.get_inbound_group_sessions().await? {
session.reset_backup_state()
}
Ok(())
}
async fn get_outbound_group_sessions(
&self,
_: &RoomId,
@@ -271,6 +300,10 @@ impl CryptoStore for MemoryStore {
Ok(())
}
async fn load_backup_keys(&self) -> Result<BackupKeys> {
Ok(BackupKeys::default())
}
}
#[cfg(test)]

View File

@@ -107,11 +107,15 @@ pub(crate) struct Store {
verification_machine: VerificationMachine,
}
#[derive(Clone, Debug, Default)]
#[derive(Default, Debug)]
#[allow(missing_docs)]
pub struct Changes {
pub account: Option<ReadOnlyAccount>,
pub private_identity: Option<PrivateCrossSigningIdentity>,
#[cfg(feature = "backups_v1")]
pub backup_version: Option<String>,
#[cfg(feature = "backups_v1")]
pub recovery_key: Option<crate::backups::RecoveryKey>,
pub sessions: Vec<Session>,
pub message_hashes: Vec<OlmMessageHash>,
pub inbound_group_sessions: Vec<InboundGroupSession>,
@@ -170,6 +174,26 @@ impl DeviceChanges {
}
}
/// Struct holding info about how many room keys the store has.
#[derive(Debug, Clone, Default)]
pub struct RoomKeyCounts {
/// The total number of room keys the store has.
pub total: usize,
/// The number of backed up room keys the store has.
pub backed_up: usize,
}
/// Stored versions of the backup keys.
#[derive(Default, Debug)]
pub struct BackupKeys {
/// The recovery key, the one used to decrypt backed up room keys.
#[cfg(feature = "backups_v1")]
pub recovery_key: Option<crate::backups::RecoveryKey>,
/// The version that we are using for backups.
#[cfg(feature = "backups_v1")]
pub backup_version: Option<String>,
}
/// A struct containing private cross signing keys that can be backed up or
/// uploaded to the secret store.
#[derive(Zeroize)]
@@ -431,7 +455,18 @@ impl Store {
| SecretName::CrossSigningSelfSigningKey => {
self.identity.lock().await.export_secret(secret_name).await
}
SecretName::RecoveryKey => None,
SecretName::RecoveryKey => {
#[cfg(feature = "backups_v1")]
if let Some(key) = self.load_backup_keys().await.unwrap().recovery_key {
let exported = key.to_base64();
Some(exported)
} else {
None
}
#[cfg(not(feature = "backups_v1"))]
None
}
name => {
warn!(secret =? name, "Unknown secret was requested");
None
@@ -496,7 +531,12 @@ impl Store {
self.save_changes(changes).await?;
}
}
SecretName::RecoveryKey => (),
SecretName::RecoveryKey => {
// We don't import the recovery key here since we'll want to
// check if the public key matches to the latest version on the
// server. We instead leave the key in the event and let the
// user import it later.
}
name => {
warn!(secret =? name, "Tried to import an unknown secret");
}
@@ -630,6 +670,22 @@ pub trait CryptoStore: AsyncTraitDeps {
/// Get all the inbound group sessions we have stored.
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>>;
/// Get the number of inbound group sessions we have and how many of them are
/// backed up.
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts>;
/// Get all the inbound group sessions we have not backed up yet.
async fn inbound_group_sessions_for_backup(
&self,
limit: usize,
) -> Result<Vec<InboundGroupSession>>;
/// Reset the backup state of all the stored inbound group sessions.
async fn reset_backup_state(&self) -> Result<()>;
/// Get the backup keys we have stored.
async fn load_backup_keys(&self) -> Result<BackupKeys>;
/// Get the outbound group sessions we have stored that is used for the
/// given room.
async fn get_outbound_group_sessions(

View File

@@ -29,14 +29,14 @@ use ruma::{
pub use sled::Error;
use sled::{
transaction::{ConflictableTransactionError, TransactionError},
Config, Db, Transactional, Tree,
Config, Db, IVec, Transactional, Tree,
};
use tracing::trace;
use uuid::Uuid;
use super::{
caches::SessionStore, Changes, CryptoStore, CryptoStoreError, InboundGroupSession, PickleKey,
ReadOnlyAccount, Result, Session,
caches::SessionStore, BackupKeys, Changes, CryptoStore, CryptoStoreError, InboundGroupSession,
PickleKey, ReadOnlyAccount, Result, RoomKeyCounts, Session,
};
use crate::{
gossiping::{GossipRequest, SecretInfo},
@@ -205,9 +205,46 @@ impl SledStore {
self.account_info.read().unwrap().clone()
}
fn upgrade_database(db: &Db) -> Result<()> {
let version = db
.get("version")?
async fn reset_backup_state(&self) -> Result<()> {
let mut pickles: Vec<(IVec, PickledInboundGroupSession)> = self
.inbound_group_sessions
.iter()
.map(|p| {
let item = p?;
Ok((
item.0,
serde_json::from_slice(&item.1).map_err(CryptoStoreError::Serialization)?,
))
})
.collect::<Result<_>>()?;
for (_, pickle) in &mut pickles {
pickle.backed_up = false;
}
let ret: Result<(), TransactionError<serde_json::Error>> =
self.inbound_group_sessions.transaction(|inbound_sessions| {
for (key, pickle) in &pickles {
inbound_sessions.insert(
key,
serde_json::to_vec(&pickle).map_err(ConflictableTransactionError::Abort)?,
)?;
}
Ok(())
});
ret?;
self.inner.flush_async().await?;
Ok(())
}
fn upgrade(&self) -> Result<()> {
let version = self
.inner
.get("store_version")?
.map(|v| {
let (version_bytes, _) = v.split_at(std::mem::size_of::<u8>());
u8::from_be_bytes(version_bytes.try_into().unwrap_or_default())
@@ -225,17 +262,17 @@ impl SledStore {
if version == 0 {
// We changed the schema but migrating this isn't important since we
// rotate the group sessions relatively often anyways so we just
// drop it.
db.drop_tree("outbound_group_sessions")?;
// clear the tree.
self.outbound_group_sessions.clear()?;
}
db.insert("version", DATABASE_VERSION.to_be_bytes().as_ref())?;
self.inner.insert("store_version", DATABASE_VERSION.to_be_bytes().as_ref())?;
self.inner.flush()?;
Ok(())
}
fn open_helper(db: Db, path: Option<PathBuf>, passphrase: Option<&str>) -> Result<Self> {
Self::upgrade_database(&db)?;
let account = db.open_tree("account")?;
let private_identity = db.open_tree("private_identity")?;
@@ -263,7 +300,7 @@ impl SledStore {
.expect("Can't create default pickle key")
};
Ok(Self {
let database = Self {
account_info: RwLock::new(None).into(),
path,
inner: db,
@@ -283,7 +320,11 @@ impl SledStore {
tracked_users,
olm_hashes,
identities,
})
};
database.upgrade()?;
Ok(database)
}
fn get_or_create_pickle_key(passphrase: &str, database: &Db) -> Result<PickleKey> {
@@ -361,6 +402,9 @@ impl SledStore {
None
};
#[cfg(feature = "backups_v1")]
let recovery_key_pickle = changes.recovery_key.map(|r| r.pickle(self.get_pickle_key()));
let device_changes = changes.devices;
let mut session_changes = HashMap::new();
@@ -399,6 +443,8 @@ impl SledStore {
let identity_changes = changes.identities;
let olm_hashes = changes.message_hashes;
let key_requests = changes.key_requests;
#[cfg(feature = "backups_v1")]
let backup_version = changes.backup_version;
let ret: Result<(), TransactionError<serde_json::Error>> = (
&self.account,
@@ -441,6 +487,22 @@ impl SledStore {
)?;
}
#[cfg(feature = "backups_v1")]
if let Some(r) = &recovery_key_pickle {
account.insert(
"recovery_key_v1".encode(),
serde_json::to_vec(r).map_err(ConflictableTransactionError::Abort)?,
)?;
}
#[cfg(feature = "backups_v1")]
if let Some(b) = &backup_version {
account.insert(
"backup_version_v1".encode(),
serde_json::to_vec(b).map_err(ConflictableTransactionError::Abort)?,
)?;
}
for device in device_changes.new.iter().chain(&device_changes.changed) {
let key = (device.user_id().as_str(), device.device_id().as_str()).encode();
let device = serde_json::to_vec(&device)
@@ -655,6 +717,57 @@ impl CryptoStore for SledStore {
.collect())
}
/// Count the stored inbound group sessions and how many of them have
/// already been uploaded to the backup.
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
    // Walk the session tree once, deserializing each pickle and tallying
    // totals in a single pass instead of materializing all pickles first.
    let mut total = 0;
    let mut backed_up = 0;

    for entry in self.inbound_group_sessions.iter() {
        let (_, value) = entry?;
        let pickle: PickledInboundGroupSession =
            serde_json::from_slice(&value).map_err(CryptoStoreError::Serialization)?;

        total += 1;
        if pickle.backed_up {
            backed_up += 1;
        }
    }

    Ok(RoomKeyCounts { total, backed_up })
}
/// Fetch up to `limit` inbound group sessions that have not yet been
/// backed up, restoring each one from its serialized pickle.
async fn inbound_group_sessions_for_backup(
    &self,
    limit: usize,
) -> Result<Vec<InboundGroupSession>> {
    let mut sessions = Vec::new();

    for entry in self.inbound_group_sessions.iter() {
        // Stop as soon as we have gathered the requested number of sessions.
        if sessions.len() == limit {
            break;
        }

        let (_, value) = entry?;
        let pickle: PickledInboundGroupSession =
            serde_json::from_slice(&value).map_err(CryptoStoreError::from)?;

        // Sessions that are already backed up are skipped.
        if !pickle.backed_up {
            sessions.push(InboundGroupSession::from_pickle(pickle, self.get_pickle_mode())?);
        }
    }

    Ok(sessions)
}
/// `CryptoStore` trait entry point; delegates to the inherent
/// `reset_backup_state` method (inherent methods win method resolution
/// over trait methods, so this does not recurse).
async fn reset_backup_state(&self) -> Result<()> {
self.reset_backup_state().await
}
async fn get_outbound_group_sessions(
&self,
room_id: &RoomId,
@@ -670,14 +783,14 @@ impl CryptoStore for SledStore {
!self.users_for_key_query_cache.is_empty()
}
/// Get the set of users whose devices this store tracks.
///
/// Iterates the in-memory cache directly; the previous code called
/// `.to_owned()` first, cloning the whole container only to iterate it and
/// clone each element again.
fn tracked_users(&self) -> HashSet<UserId> {
    self.tracked_users_cache.iter().map(|u| u.clone()).collect()
}
/// Snapshot the users that still need a `/keys/query` request.
fn users_for_key_query(&self) -> HashSet<UserId> {
    let mut users = HashSet::new();
    for user in self.users_for_key_query_cache.iter() {
        users.insert(user.clone());
    }
    users
}
// NOTE(review): this duplicates `tracked_users` defined just above — almost
// certainly merge/diff residue from the method being moved; one of the two
// copies should be removed, the file will not compile with both.
fn tracked_users(&self) -> HashSet<UserId> {
self.tracked_users_cache.to_owned().iter().map(|u| u.clone()).collect()
}
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
let already_added = self.tracked_users_cache.insert(user.clone());
@@ -796,6 +909,32 @@ impl CryptoStore for SledStore {
Ok(())
}
/// Load the persisted backup key material (backup version and, when the
/// `backups_v1` feature is enabled, the recovery key).
async fn load_backup_keys(&self) -> Result<BackupKeys> {
// The backup version is stored as plain JSON under a fixed key in the
// account tree.
let version = self
.account
.get("backup_version_v1".encode())?
.map(|v| serde_json::from_slice(&v))
.transpose()?;
// The recovery key is stored as an encrypted pickle; restoring it needs
// the store's pickle key. A value that fails to unpickle is surfaced as
// `UnpicklingError` rather than silently dropped.
#[cfg(feature = "backups_v1")]
let recovery_key = {
self.account
.get("recovery_key_v1".encode())?
.map(|p| serde_json::from_slice(&p))
.transpose()?
.map(|p| {
crate::backups::RecoveryKey::from_pickle(p, self.get_pickle_key())
.map_err(|_| CryptoStoreError::UnpicklingError)
})
.transpose()?
};
// Without the feature there is no recovery-key type to restore.
#[cfg(not(feature = "backups_v1"))]
let recovery_key = None;
Ok(BackupKeys { backup_version: version, recovery_key })
}
}
#[cfg(test)]

View File

@@ -775,11 +775,12 @@ impl TryFrom<OutgoingRequest> for OutgoingContent {
fn try_from(value: OutgoingRequest) -> Result<Self, Self::Error> {
match value.request() {
crate::OutgoingRequests::KeysUpload(_) => Err("Invalid request type".to_owned()),
crate::OutgoingRequests::KeysQuery(_) => Err("Invalid request type".to_owned()),
crate::OutgoingRequests::KeysUpload(_)
| crate::OutgoingRequests::KeysQuery(_)
| crate::OutgoingRequests::KeysBackup(_)
| crate::OutgoingRequests::SignatureUpload(_)
| crate::OutgoingRequests::KeysClaim(_) => Err("Invalid request type".to_owned()),
crate::OutgoingRequests::ToDeviceRequest(r) => Self::try_from(r.clone()),
crate::OutgoingRequests::SignatureUpload(_) => Err("Invalid request type".to_owned()),
crate::OutgoingRequests::KeysClaim(_) => Err("Invalid request type".to_owned()),
crate::OutgoingRequests::RoomMessage(r) => Ok(Self::from(r.clone())),
}
}

View File

@@ -535,7 +535,7 @@ impl IdentitiesBeingVerified {
};
if should_request_secrets {
let secret_requests = self.request_missing_secrets().await;
let secret_requests = self.request_missing_secrets().await?;
changes.key_requests = secret_requests;
}
@@ -547,9 +547,16 @@ impl IdentitiesBeingVerified {
.unwrap_or(VerificationResult::Ok))
}
async fn request_missing_secrets(&self) -> Vec<GossipRequest> {
let secrets = self.private_identity.get_missing_secrets().await;
GossipMachine::request_missing_secrets(self.user_id(), secrets)
async fn request_missing_secrets(&self) -> Result<Vec<GossipRequest>, CryptoStoreError> {
#[allow(unused_mut)]
let mut secrets = self.private_identity.get_missing_secrets().await;
#[cfg(feature = "backups_v1")]
if self.store.inner.load_backup_keys().await?.recovery_key.is_none() {
secrets.push(ruma::events::secret::request::SecretName::RecoveryKey);
}
Ok(GossipMachine::request_missing_secrets(self.user_id(), secrets))
}
async fn mark_identity_as_verified(
@@ -667,9 +674,6 @@ impl IdentitiesBeingVerified {
#[cfg(test)]
pub(crate) mod test {
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::wasm_bindgen_test;
use matrix_sdk_test::async_test;
use std::convert::TryInto;
use ruma::{

View File

@@ -110,7 +110,7 @@ features = ["wasm-bindgen"]
[dev-dependencies]
anyhow = "1.0"
dirs = "3.0.2"
futures = { version = "0.3.15", default-features = false }
futures = { version = "0.3.15", default-features = false, features = ["executor"] }
lazy_static = "1.4.0"
matches = "0.1.8"
matrix-sdk-test = { version = "0.4.0", path = "../matrix-sdk-test" }

View File

@@ -26,15 +26,14 @@ use std::{
use anymap2::any::CloneAnySendSync;
use dashmap::DashMap;
use futures_core::stream::Stream;
use futures_timer::Delay as sleep;
use matrix_sdk_base::{
deserialized_responses::{JoinedRoom, LeftRoom, SyncResponse},
deserialized_responses::SyncResponse,
media::{MediaEventContent, MediaFormat, MediaRequest, MediaThumbnailSize, MediaType},
BaseClient, Session, Store,
};
use matrix_sdk_common::{
instant::{Duration, Instant},
locks::{Mutex, RwLock},
locks::{Mutex, RwLock, RwLockReadGuard},
};
use mime::{self, Mime};
use ruma::{
@@ -109,27 +108,31 @@ pub enum LoopCtrl {
/// All of the state is held in an `Arc` so the `Client` can be cloned freely.
#[derive(Clone)]
pub struct Client {
pub(crate) inner: Arc<ClientInner>,
}
pub(crate) struct ClientInner {
/// The URL of the homeserver to connect to.
homeserver: Arc<RwLock<Url>>,
/// The underlying HTTP client.
http_client: HttpClient,
/// User session data.
pub(crate) base_client: BaseClient,
base_client: BaseClient,
/// Locks making sure we only have one group session sharing request in
/// flight per room.
#[cfg(feature = "encryption")]
pub(crate) group_session_locks: Arc<DashMap<RoomId, Arc<Mutex<()>>>>,
pub(crate) group_session_locks: DashMap<RoomId, Arc<Mutex<()>>>,
#[cfg(feature = "encryption")]
/// Lock making sure we're only doing one key claim request at a time.
pub(crate) key_claim_lock: Arc<Mutex<()>>,
pub(crate) members_request_locks: Arc<DashMap<RoomId, Arc<Mutex<()>>>>,
pub(crate) typing_notice_times: Arc<DashMap<RoomId, Instant>>,
pub(crate) key_claim_lock: Mutex<()>,
pub(crate) members_request_locks: DashMap<RoomId, Arc<Mutex<()>>>,
pub(crate) typing_notice_times: DashMap<RoomId, Instant>,
/// Event handlers. See `register_event_handler`.
pub(crate) event_handlers: Arc<RwLock<EventHandlerMap>>,
event_handlers: RwLock<EventHandlerMap>,
/// Custom event handler context. See `register_event_handler_context`.
pub(crate) event_handler_data: Arc<StdRwLock<AnyMap>>,
event_handler_data: StdRwLock<AnyMap>,
/// Notification handlers. See `register_notification_handler`.
notification_handlers: Arc<RwLock<Vec<NotificationHandlerFn>>>,
notification_handlers: RwLock<Vec<NotificationHandlerFn>>,
/// Whether the client should operate in application service style mode.
/// This is low-level functionality. For an high-level API check the
/// `matrix_sdk_appservice` crate.
@@ -142,7 +145,7 @@ pub struct Client {
/// synchronization, e.g. if we send out a request to create a room, we can
/// wait for the sync to get the data to fetch a room object from the state
/// store.
pub(crate) sync_beat: Arc<event_listener::Event>,
pub(crate) sync_beat: event_listener::Event,
}
#[cfg(not(tarpaulin_include))]
@@ -186,7 +189,7 @@ impl Client {
let http_client =
HttpClient::new(client, homeserver.clone(), session, config.request_config);
Ok(Self {
let inner = Arc::new(ClientInner {
homeserver,
http_client,
base_client,
@@ -201,8 +204,10 @@ impl Client {
notification_handlers: Default::default(),
appservice_mode: config.appservice_mode,
use_discovery_response: config.use_discovery_response,
sync_beat: event_listener::Event::new().into(),
})
sync_beat: event_listener::Event::new(),
});
Ok(Self { inner })
}
/// Create a new [`Client`] using homeserver auto discovery.
@@ -226,12 +231,12 @@ impl Client {
/// // First let's try to construct an user id, presumably from user input.
/// let alice = UserId::try_from("@alice:example.org")?;
///
/// // Now let's try to discover the homeserver and create client object.
/// // Now let's try to discover the homeserver and create a client object.
/// let client = Client::new_from_user_id(&alice).await?;
///
/// // Finally let's try to login.
/// client.login(alice, "password", None, None).await?;
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
/// ```
///
/// [spec]: https://spec.matrix.org/unstable/client-server-api/#well-known-uri
@@ -268,6 +273,24 @@ impl Client {
Ok(client)
}
/// Crate-internal accessor for the underlying [`BaseClient`] now that the
/// client state lives behind `self.inner`.
pub(crate) fn base_client(&self) -> &BaseClient {
&self.inner.base_client
}
/// Convenience accessor for the base client's Olm state machine; `None`
/// if no machine has been initialized yet.
#[cfg(feature = "encryption")]
pub(crate) async fn olm_machine(&self) -> Option<matrix_sdk_base::crypto::OlmMachine> {
self.base_client().olm_machine().await
}
/// Forward the response for a previously sent out request to the base
/// client so it can mark the outgoing request as done.
#[cfg(feature = "encryption")]
pub(crate) async fn mark_request_as_sent(
&self,
request_id: &matrix_sdk_base::uuid::Uuid,
response: impl Into<matrix_sdk_base::crypto::IncomingResponse<'_>>,
) -> Result<(), matrix_sdk_base::Error> {
self.base_client().mark_request_as_sent(request_id, response).await
}
fn homeserver_from_user_id(user_id: &UserId) -> Result<Url> {
let homeserver = format!("https://{}", user_id.server_name());
#[allow(unused_mut)]
@@ -290,7 +313,7 @@ impl Client {
///
/// * `homeserver_url` - The new URL to use.
pub async fn set_homeserver(&self, homeserver_url: Url) {
    // The homeserver URL moved into `ClientInner`; the residual line reading
    // `self.homeserver` referenced a field that no longer exists on `Client`.
    let mut homeserver = self.inner.homeserver.write().await;
    *homeserver = homeserver_url;
}
@@ -324,23 +347,23 @@ impl Client {
/// Is the client logged in.
pub async fn logged_in(&self) -> bool {
self.base_client.logged_in().await
self.inner.base_client.logged_in().await
}
/// The Homeserver of the client.
/// The Homeserver of the client.
pub async fn homeserver(&self) -> Url {
    // Read through `self.inner` — the direct `self.homeserver` field no
    // longer exists after the `ClientInner` refactor.
    self.inner.homeserver.read().await.clone()
}
/// Get the user id of the current owner of the client.
/// Get the user id of the current owner of the client, if a session is
/// available.
pub async fn user_id(&self) -> Option<UserId> {
    let session = self.inner.base_client.session().read().await;
    session.as_ref().cloned().map(|s| s.user_id)
}
/// Get the device id that identifies the current session.
/// Get the device id that identifies the current session, if any.
pub async fn device_id(&self) -> Option<DeviceIdBox> {
    let session = self.inner.base_client.session().read().await;
    session.as_ref().map(|s| s.device_id.clone())
}
@@ -351,7 +374,7 @@ impl Client {
/// Can be used with [`Client::restore_login`] to restore a previously
/// logged in session.
/// Get the current session, usable with [`Client::restore_login`] to
/// restore a previously logged in session.
pub async fn session(&self) -> Option<Session> {
    self.inner.base_client.session().read().await.clone()
}
/// Fetches the display name of the owner of the client.
@@ -469,7 +492,7 @@ impl Client {
/// Get a reference to the store.
/// Get a reference to the state store.
pub fn store(&self) -> &Store {
    self.inner.base_client.store()
}
/// Sets the mxc avatar url of the client's owner. The avatar gets unset if
@@ -611,32 +634,41 @@ impl Client {
<H::Future as Future>::Output: EventHandlerResult,
{
let event_type = H::ID.1;
self.event_handlers.write().await.entry(H::ID).or_default().push(Box::new(move |data| {
let maybe_fut = serde_json::from_str(data.raw.get())
.map(|ev| handler.clone().handle_event(ev, data));
self.inner.event_handlers.write().await.entry(H::ID).or_default().push(Box::new(
move |data| {
let maybe_fut = serde_json::from_str(data.raw.get())
.map(|ev| handler.clone().handle_event(ev, data));
Box::pin(async move {
match maybe_fut {
Ok(Some(fut)) => {
fut.await.print_error(event_type);
}
Ok(None) => {
error!("Event handler for {} has an invalid context argument", event_type);
}
Err(e) => {
warn!(
"Failed to deserialize `{}` event, skipping event handler.\n\
Box::pin(async move {
match maybe_fut {
Ok(Some(fut)) => {
fut.await.print_error(event_type);
}
Ok(None) => {
error!(
"Event handler for {} has an invalid context argument",
event_type
);
}
Err(e) => {
warn!(
"Failed to deserialize `{}` event, skipping event handler.\n\
Deserialization error: {}",
event_type, e,
);
event_type, e,
);
}
}
}
})
}));
})
},
));
self
}
/// Crate-internal read access to the registered event handlers; the guard
/// keeps the handler map locked for as long as the caller holds it.
pub(crate) async fn event_handlers(&self) -> RwLockReadGuard<'_, EventHandlerMap> {
self.inner.event_handlers.read().await
}
/// Add an arbitrary value for use as event handler context.
///
/// The value can be obtained in an event handler by adding an argument of
@@ -678,10 +710,18 @@ impl Client {
where
T: Clone + Send + Sync + 'static,
{
self.event_handler_data.write().unwrap().insert(ctx);
self.inner.event_handler_data.write().unwrap().insert(ctx);
self
}
/// Fetch a previously registered event-handler context value of type `T`,
/// cloning it out so the read lock is released before returning.
pub(crate) fn event_handler_context<T>(&self) -> Option<T>
where
    T: Clone + Send + Sync + 'static,
{
    self.inner.event_handler_data.read().unwrap().get::<T>().cloned()
}
/// Register a handler for a notification.
///
/// Similar to [`Client::register_event_handler`], but only allows functions
@@ -692,13 +732,19 @@ impl Client {
H: Fn(Notification, room::Room, Client) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
self.notification_handlers.write().await.push(Box::new(
self.inner.notification_handlers.write().await.push(Box::new(
move |notification, room, client| Box::pin((handler)(notification, room, client)),
));
self
}
/// Crate-internal read access to the registered notification handlers.
pub(crate) async fn notification_handlers(
&self,
) -> RwLockReadGuard<'_, Vec<NotificationHandlerFn>> {
self.inner.notification_handlers.read().await
}
/// Get all the rooms the client knows about.
///
/// This will return the list of joined, invited, and left rooms.
@@ -849,7 +895,7 @@ impl Client {
/// "Logged in as {}, got device_id {} and access_token {}",
/// user, response.device_id, response.access_token
/// );
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
/// ```
///
/// [`restore_login`]: #method.restore_login
@@ -1166,7 +1212,7 @@ impl Client {
///
/// * `response` - A successful login response.
async fn receive_login_response(&self, response: &login::Response) -> Result<()> {
if self.use_discovery_response {
if self.inner.use_discovery_response {
if let Some(well_known) = &response.well_known {
if let Ok(homeserver) = Url::parse(&well_known.homeserver.base_url) {
self.set_homeserver(homeserver).await;
@@ -1174,7 +1220,7 @@ impl Client {
}
}
self.base_client.receive_login_response(response).await?;
self.inner.base_client.receive_login_response(response).await?;
Ok(())
}
@@ -1192,9 +1238,52 @@ impl Client {
/// * `session` - A session that the user already has from a
/// previous login call.
///
/// # Examples
///
/// ```no_run
/// use matrix_sdk::{Client, Session, ruma::{DeviceIdBox, user_id}};
/// # use url::Url;
/// # use futures::executor::block_on;
/// # block_on(async {
///
/// let homeserver = Url::parse("http://example.com")?;
/// let client = Client::new(homeserver).await?;
///
/// let session = Session {
/// access_token: "My-Token".to_owned(),
/// user_id: user_id!("@example:localhost"),
/// device_id: DeviceIdBox::from("MYDEVICEID"),
/// };
///
/// client.restore_login(session).await?;
/// # anyhow::Result::<()>::Ok(()) });
/// ```
///
/// The `Session` object can also be created from the response the
/// [`Client::login()`] method returns:
///
/// ```no_run
/// use matrix_sdk::{Client, Session, ruma::{DeviceIdBox, user_id}};
/// # use url::Url;
/// # use futures::executor::block_on;
/// # block_on(async {
///
/// let homeserver = Url::parse("http://example.com")?;
/// let client = Client::new(homeserver).await?;
///
/// let session: Session = client
/// .login("example", "my-password", None, None)
/// .await?
/// .into();
///
/// // Persist the `Session` so it can later be used to restore the login.
/// client.restore_login(session).await?;
/// # anyhow::Result::<()>::Ok(()) });
/// ```
///
/// [`login`]: #method.login
pub async fn restore_login(&self, session: Session) -> Result<()> {
    // Delegate through `self.inner`; the duplicate line reading
    // `self.base_client` directly was stale diff residue.
    Ok(self.inner.base_client.restore_login(session).await?)
}
/// Register a user to the server.
@@ -1241,8 +1330,8 @@ impl Client {
let homeserver = self.homeserver().await;
info!("Registering to {}", homeserver);
let config = if self.appservice_mode {
Some(self.http_client.request_config.force_auth())
let config = if self.inner.appservice_mode {
Some(self.inner.http_client.request_config.force_auth())
} else {
None
};
@@ -1307,14 +1396,14 @@ impl Client {
filter_name: &str,
definition: FilterDefinition<'_>,
) -> Result<String> {
if let Some(filter) = self.base_client.get_filter(filter_name).await? {
if let Some(filter) = self.inner.base_client.get_filter(filter_name).await? {
Ok(filter)
} else {
let user_id = self.user_id().await.ok_or(Error::AuthenticationRequired)?;
let request = FilterUploadRequest::new(&user_id, definition);
let response = self.send(request, None).await?;
self.base_client.receive_filter_upload(filter_name, &response).await?;
self.inner.base_client.receive_filter_upload(filter_name, &response).await?;
Ok(response.filter_id)
}
@@ -1468,7 +1557,7 @@ impl Client {
/// for room in response.chunk {
/// println!("Found room {:?}", room);
/// }
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
/// ```
pub async fn public_rooms_filtered(
&self,
@@ -1526,8 +1615,8 @@ impl Client {
content_type: Some(content_type.essence_str()),
});
let request_config = self.http_client.request_config.timeout(timeout);
Ok(self.http_client.upload(request, Some(request_config)).await?)
let request_config = self.inner.http_client.request_config.timeout(timeout);
Ok(self.inner.http_client.upload(request, Some(request_config)).await?)
}
/// Send an arbitrary request to the server, without updating client state.
@@ -1568,7 +1657,7 @@ impl Client {
///
/// // Check the corresponding Response struct to find out what types are
/// // returned
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
/// ```
pub async fn send<Request>(
&self,
@@ -1579,7 +1668,7 @@ impl Client {
Request: OutgoingRequest + Debug,
HttpError: From<FromHttpResponseError<Request::EndpointError>>,
{
Ok(self.http_client.send(request, config).await?)
Ok(self.inner.http_client.send(request, config).await?)
}
/// Get information of all our own devices.
@@ -1603,7 +1692,7 @@ impl Client {
/// device.display_name.as_deref().unwrap_or("")
/// );
/// }
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
/// ```
pub async fn devices(&self) -> HttpResult<get_devices::Response> {
let request = get_devices::Request::new();
@@ -1656,7 +1745,7 @@ impl Client {
/// .await?;
/// }
/// }
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
pub async fn delete_devices(
&self,
devices: &[DeviceIdBox],
@@ -1668,135 +1757,6 @@ impl Client {
self.send(request, None).await
}
/// Feed a raw `/sync` response to the base client and dispatch the
/// resulting events to the registered event and notification handlers.
///
/// Returns the deserialized [`SyncResponse`] produced by the base client.
pub(crate) async fn process_sync(
&self,
response: sync_events::Response,
) -> Result<SyncResponse> {
// Let the base client update its state first; everything below works on
// the processed response.
let response = self.base_client.receive_sync_response(response).await?;
let SyncResponse {
next_batch: _,
rooms,
presence,
account_data,
to_device: _,
device_lists: _,
device_one_time_keys_count: _,
ambiguity_changes: _,
notifications,
} = &response;
// Global (non-room) events have no room attached, hence `&None`.
self.handle_sync_events(EventKind::GlobalAccountData, &None, &account_data.events).await?;
self.handle_sync_events(EventKind::Presence, &None, &presence.events).await?;
// Joined rooms: ephemeral, account-data, state and timeline events.
for (room_id, room_info) in &rooms.join {
let room = self.get_room(room_id);
if room.is_none() {
error!("Can't call event handler, room {} not found", room_id);
continue;
}
let JoinedRoom { unread_notifications: _, timeline, state, account_data, ephemeral } =
room_info;
self.handle_sync_events(EventKind::EphemeralRoomData, &room, &ephemeral.events).await?;
self.handle_sync_events(EventKind::RoomAccountData, &room, &account_data.events)
.await?;
self.handle_sync_state_events(&room, &state.events).await?;
self.handle_sync_timeline_events(&room, &timeline.events).await?;
}
// Left rooms: no ephemeral events, otherwise the same dispatch.
for (room_id, room_info) in &rooms.leave {
let room = self.get_room(room_id);
if room.is_none() {
error!("Can't call event handler, room {} not found", room_id);
continue;
}
let LeftRoom { timeline, state, account_data } = room_info;
self.handle_sync_events(EventKind::RoomAccountData, &room, &account_data.events)
.await?;
self.handle_sync_state_events(&room, &state.events).await?;
self.handle_sync_timeline_events(&room, &timeline.events).await?;
}
// Invited rooms only carry stripped state.
for (room_id, room_info) in &rooms.invite {
let room = self.get_room(room_id);
if room.is_none() {
error!("Can't call event handler, room {} not found", room_id);
continue;
}
// FIXME: Destructure room_info
self.handle_sync_events(
EventKind::StrippedState,
&room,
&room_info.invite_state.events,
)
.await?;
}
// Construct notification event handler futures
let mut futures = Vec::new();
for handler in &*self.notification_handlers.read().await {
for (room_id, room_notifications) in notifications {
let room = match self.get_room(room_id) {
Some(room) => room,
None => {
warn!("Can't call notification handler, room {} not found", room_id);
continue;
}
};
futures.extend(room_notifications.iter().map(|notification| {
(handler)(notification.clone(), room.clone(), self.clone())
}));
}
}
// Run the notification handler futures with the
// `self.notification_handlers` lock no longer being held, in order.
for fut in futures {
fut.await;
}
Ok(response)
}
async fn sync_loop_helper(
&self,
sync_settings: &mut crate::config::SyncSettings<'_>,
) -> Result<SyncResponse> {
let response = self.sync_once(sync_settings.clone()).await;
match response {
Ok(r) => {
sync_settings.token = Some(r.next_batch.clone());
Ok(r)
}
Err(e) => {
error!("Received an invalid response: {}", e);
sleep::new(Duration::from_secs(1)).await;
Err(e)
}
}
}
/// Throttle the sync loop: if the previous sync finished no more than a
/// second ago, pause for a second so a server that ignores the sync timeout
/// doesn't get hammered, then record the current time as the last sync time.
async fn delay_sync(last_sync_time: &mut Option<Instant>) {
    let now = Instant::now();

    let too_soon =
        matches!(*last_sync_time, Some(t) if now - t <= Duration::from_secs(1));
    if too_soon {
        sleep::new(Duration::from_secs(1)).await;
    }

    *last_sync_time = Some(now);
}
/// Synchronize the client's state with the latest state on the server.
///
/// ## Syncing Events
@@ -1874,7 +1834,7 @@ impl Client {
/// // Now keep on syncing forever. `sync()` will use the stored sync token
/// // from our `sync_once()` call automatically.
/// client.sync(SyncSettings::default()).await;
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
/// ```
///
/// [`sync`]: #method.sync
@@ -1901,9 +1861,9 @@ impl Client {
timeout: sync_settings.timeout,
});
let request_config = self.http_client.request_config.timeout(
let request_config = self.inner.http_client.request_config.timeout(
sync_settings.timeout.unwrap_or_else(|| Duration::from_secs(0))
+ self.http_client.request_config.timeout,
+ self.inner.http_client.request_config.timeout,
);
let response = self.send(request, Some(request_config)).await?;
@@ -1914,7 +1874,7 @@ impl Client {
error!(error =? e, "Error while sending outgoing E2EE requests");
};
self.sync_beat.notify(usize::MAX);
self.inner.sync_beat.notify(usize::MAX);
Ok(response)
}
@@ -1931,7 +1891,7 @@ impl Client {
/// method to react to individual events. If you instead wish to handle
/// events in a bulk manner the [`Client::sync_with_callback`] and
/// [`Client::sync_stream`] methods can be used instead. Those two methods
/// repeadetly return the whole sync response.
/// repeatedly return the whole sync response.
///
/// # Arguments
///
@@ -1965,7 +1925,7 @@ impl Client {
/// // Now keep on syncing forever. `sync()` will use the latest sync token
/// // automatically.
/// client.sync(SyncSettings::default()).await;
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
/// ```
///
/// [argument docs]: #method.sync_once
@@ -2094,7 +2054,7 @@ impl Client {
/// }
/// }
///
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
/// ```
#[instrument]
pub async fn sync_stream<'a>(
@@ -2119,7 +2079,7 @@ impl Client {
/// Get the current, if any, sync token of the client.
/// This will be None if the client didn't sync at least once.
pub async fn sync_token(&self) -> Option<String> {
self.base_client.sync_token().await
self.inner.base_client.sync_token().await
}
/// Get a media file's content.
@@ -2138,7 +2098,7 @@ impl Client {
use_cache: bool,
) -> Result<Vec<u8>> {
let content = if use_cache {
self.base_client.store().get_media_content(request).await?
self.inner.base_client.store().get_media_content(request).await?
} else {
None
};
@@ -2182,7 +2142,7 @@ impl Client {
};
if use_cache {
self.base_client.store().add_media_content(request, content.clone()).await?;
self.inner.base_client.store().add_media_content(request, content.clone()).await?;
}
Ok(content)
@@ -2195,7 +2155,7 @@ impl Client {
///
/// * `request` - The `MediaRequest` of the content.
pub async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
Ok(self.base_client.store().remove_media_content(request).await?)
Ok(self.inner.base_client.store().remove_media_content(request).await?)
}
/// Delete all the media content corresponding to the given
@@ -2205,7 +2165,7 @@ impl Client {
///
/// * `uri` - The `MxcUri` of the files.
pub async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
Ok(self.base_client.store().remove_media_content_for_uri(uri).await?)
Ok(self.inner.base_client.store().remove_media_content_for_uri(uri).await?)
}
/// Get the file of the given media event content.
@@ -2724,7 +2684,7 @@ pub(crate) mod test {
.add_state_event(EventsJson::PowerLevels)
.build_sync_response();
client.base_client.receive_sync_response(response).await.unwrap();
client.inner.base_client.receive_sync_response(response).await.unwrap();
let room_id = room_id!("!SVkFJHzfwvuaIEawgC:localhost");
assert_eq!(client.homeserver().await, Url::parse(&mockito::server_url()).unwrap());

View File

@@ -90,7 +90,7 @@ impl ClientConfig {
/// let client_config = ClientConfig::new()
/// .proxy("http://localhost:8080")?;
///
/// # matrix_sdk::Result::Ok(())
/// # Result::<_, matrix_sdk::Error>::Ok(())
/// ```
#[cfg(not(target_arch = "wasm32"))]
pub fn proxy(mut self, proxy: &str) -> Result<Self> {
@@ -105,7 +105,7 @@ impl ClientConfig {
}
/// Set a custom HTTP user agent for the client.
pub fn user_agent(mut self, user_agent: &str) -> std::result::Result<Self, InvalidHeaderValue> {
pub fn user_agent(mut self, user_agent: &str) -> Result<Self, InvalidHeaderValue> {
self.user_agent = Some(HeaderValue::from_str(user_agent)?);
Ok(self)
}

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{ops::Deref, result::Result as StdResult};
use std::ops::Deref;
use matrix_sdk_base::crypto::{
store::CryptoStoreError, Device as BaseDevice, LocalTrust, ReadOnlyDevice,
@@ -250,7 +250,7 @@ impl Device {
/// }
/// # anyhow::Result::<()>::Ok(()) });
/// ```
pub async fn verify(&self) -> std::result::Result<(), ManualVerifyError> {
pub async fn verify(&self) -> Result<(), ManualVerifyError> {
let request = self.inner.verify().await?;
self.client.send(request, None).await?;
@@ -391,10 +391,7 @@ impl Device {
/// # Arguments
///
/// * `trust_state` - The new trust state that should be set for the device.
pub async fn set_local_trust(
&self,
trust_state: LocalTrust,
) -> StdResult<(), CryptoStoreError> {
pub async fn set_local_trust(&self, trust_state: LocalTrust) -> Result<(), CryptoStoreError> {
self.inner.set_local_trust(trust_state).await
}
}

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{result::Result, sync::Arc};
use std::sync::Arc;
use matrix_sdk_base::{
crypto::{

View File

@@ -266,6 +266,7 @@ use matrix_sdk_base::{
use matrix_sdk_common::{instant::Duration, uuid::Uuid};
use ruma::{
api::client::r0::{
backup::add_backup_keys::Response as KeysBackupResponse,
keys::{get_keys, upload_keys, upload_signing_keys::Request as UploadSigningKeysRequest},
message::send_message_event,
to_device::send_event_to_device::{
@@ -294,7 +295,7 @@ impl Client {
/// called the fingerprint of the device.
#[cfg(feature = "encryption")]
pub async fn ed25519_key(&self) -> Option<String> {
self.base_client.olm_machine().await.map(|o| o.identity_keys().ed25519().to_owned())
self.olm_machine().await.map(|o| o.identity_keys().ed25519().to_owned())
}
/// Get the status of the private cross signing keys.
@@ -303,7 +304,7 @@ impl Client {
/// stored locally.
#[cfg(feature = "encryption")]
pub async fn cross_signing_status(&self) -> Option<CrossSigningStatus> {
if let Some(machine) = self.base_client.olm_machine().await {
if let Some(machine) = self.olm_machine().await {
Some(machine.cross_signing_status().await)
} else {
None
@@ -316,13 +317,13 @@ impl Client {
/// capable devices up to date.
#[cfg(feature = "encryption")]
pub async fn tracked_users(&self) -> HashSet<UserId> {
self.base_client.olm_machine().await.map(|o| o.tracked_users()).unwrap_or_default()
self.olm_machine().await.map(|o| o.tracked_users()).unwrap_or_default()
}
/// Get a verification object with the given flow id.
#[cfg(feature = "encryption")]
pub async fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
let olm = self.base_client.olm_machine().await?;
let olm = self.olm_machine().await?;
olm.get_verification(user_id, flow_id).map(|v| match v {
matrix_sdk_base::crypto::Verification::SasV1(s) => {
SasVerification { inner: s, client: self.clone() }.into()
@@ -342,7 +343,7 @@ impl Client {
user_id: &UserId,
flow_id: impl AsRef<str>,
) -> Option<VerificationRequest> {
let olm = self.base_client.olm_machine().await?;
let olm = self.olm_machine().await?;
olm.get_verification_request(user_id, flow_id)
.map(|r| VerificationRequest { inner: r, client: self.clone() })
@@ -387,7 +388,7 @@ impl Client {
user_id: &UserId,
device_id: &DeviceId,
) -> StdResult<Option<Device>, CryptoStoreError> {
let device = self.base_client.get_device(user_id, device_id).await?;
let device = self.base_client().get_device(user_id, device_id).await?;
Ok(device.map(|d| Device { inner: d, client: self.clone() }))
}
@@ -424,7 +425,7 @@ impl Client {
&self,
user_id: &UserId,
) -> StdResult<UserDevices, CryptoStoreError> {
let devices = self.base_client.get_user_devices(user_id).await?;
let devices = self.base_client().get_user_devices(user_id).await?;
Ok(UserDevices { inner: devices, client: self.clone() })
}
@@ -467,7 +468,7 @@ impl Client {
) -> StdResult<Option<crate::encryption::identities::UserIdentity>, CryptoStoreError> {
use crate::encryption::identities::UserIdentity;
if let Some(olm) = self.base_client.olm_machine().await {
if let Some(olm) = self.olm_machine().await {
let identity = olm.get_identity(user_id).await?;
Ok(identity.map(|i| match i {
@@ -525,7 +526,7 @@ impl Client {
/// # anyhow::Result::<()>::Ok(()) });
#[cfg(feature = "encryption")]
pub async fn bootstrap_cross_signing(&self, auth_data: Option<AuthData<'_>>) -> Result<()> {
let olm = self.base_client.olm_machine().await.ok_or(Error::AuthenticationRequired)?;
let olm = self.olm_machine().await.ok_or(Error::AuthenticationRequired)?;
let (request, signature_request) = olm.bootstrap_cross_signing(false).await?;
@@ -600,7 +601,7 @@ impl Client {
passphrase: &str,
predicate: impl FnMut(&matrix_sdk_base::crypto::olm::InboundGroupSession) -> bool,
) -> Result<()> {
let olm = self.base_client.olm_machine().await.ok_or(Error::AuthenticationRequired)?;
let olm = self.olm_machine().await.ok_or(Error::AuthenticationRequired)?;
let keys = olm.export_keys(predicate).await?;
let passphrase = zeroize::Zeroizing::new(passphrase.to_owned());
@@ -660,7 +661,7 @@ impl Client {
path: PathBuf,
passphrase: &str,
) -> StdResult<RoomKeyImportResult, RoomKeyImportError> {
let olm = self.base_client.olm_machine().await.ok_or(RoomKeyImportError::StoreClosed)?;
let olm = self.olm_machine().await.ok_or(RoomKeyImportError::StoreClosed)?;
let passphrase = zeroize::Zeroizing::new(passphrase.to_owned());
let decrypt = move || {
@@ -683,7 +684,7 @@ impl Client {
) -> serde_json::Result<RoomEvent> {
use ruma::serde::JsonObject;
if let Some(machine) = self.base_client.olm_machine().await {
if let Some(machine) = self.olm_machine().await {
if let AnyRoomEvent::Message(event) = event {
if let AnyMessageEvent::RoomEncrypted(_) = event {
let room_id = event.room_id();
@@ -726,7 +727,7 @@ impl Client {
let request = assign!(get_keys::Request::new(), { device_keys });
let response = self.send(request, None).await?;
self.base_client.mark_request_as_sent(request_id, &response).await?;
self.mark_request_as_sent(request_id, &response).await?;
Ok(response)
}
@@ -843,7 +844,7 @@ impl Client {
if let Some(room) = self.get_joined_room(&response.room_id) {
Ok(Some(room))
} else {
self.sync_beat.listen().wait_timeout(SYNC_WAIT_TIME);
self.inner.sync_beat.listen().wait_timeout(SYNC_WAIT_TIME);
Ok(self.get_joined_room(&response.room_id))
}
}
@@ -859,11 +860,11 @@ impl Client {
&self,
users: impl Iterator<Item = &UserId>,
) -> Result<()> {
let _lock = self.key_claim_lock.lock().await;
let _lock = self.inner.key_claim_lock.lock().await;
if let Some((request_id, request)) = self.base_client.get_missing_sessions(users).await? {
if let Some((request_id, request)) = self.base_client().get_missing_sessions(users).await? {
let response = self.send(request, None).await?;
self.base_client.mark_request_as_sent(&request_id, &response).await?;
self.mark_request_as_sent(&request_id, &response).await?;
}
Ok(())
@@ -892,7 +893,7 @@ impl Client {
);
let response = self.send(request.clone(), None).await?;
self.base_client.mark_request_as_sent(request_id, &response).await?;
self.mark_request_as_sent(request_id, &response).await?;
Ok(response)
}
@@ -970,25 +971,41 @@ impl Client {
}
OutgoingRequests::ToDeviceRequest(request) => {
let response = self.send_to_device(request).await?;
self.base_client.mark_request_as_sent(r.request_id(), &response).await?;
self.mark_request_as_sent(r.request_id(), &response).await?;
}
OutgoingRequests::SignatureUpload(request) => {
let response = self.send(request.clone(), None).await?;
self.base_client.mark_request_as_sent(r.request_id(), &response).await?;
self.mark_request_as_sent(r.request_id(), &response).await?;
}
OutgoingRequests::RoomMessage(request) => {
let response = self.room_send_helper(request).await?;
self.base_client.mark_request_as_sent(r.request_id(), &response).await?;
self.mark_request_as_sent(r.request_id(), &response).await?;
}
OutgoingRequests::KeysClaim(request) => {
let response = self.send(request.clone(), None).await?;
self.base_client.mark_request_as_sent(r.request_id(), &response).await?;
self.mark_request_as_sent(r.request_id(), &response).await?;
}
OutgoingRequests::KeysBackup(request) => {
let response = self.send_backup_request(request).await?;
self.mark_request_as_sent(r.request_id(), &response).await?;
}
}
Ok(())
}
async fn send_backup_request(
&self,
request: &matrix_sdk_base::crypto::KeysBackupRequest,
) -> Result<KeysBackupResponse> {
let request = ruma::api::client::r0::backup::add_backup_keys::Request::new(
&request.version,
request.rooms.to_owned(),
);
Ok(self.send(request, None).await?)
}
pub(crate) async fn send_outgoing_requests(&self) -> Result<()> {
const MAX_CONCURRENT_REQUESTS: usize = 20;
@@ -998,7 +1015,7 @@ impl Client {
warn!("Error while claiming one-time keys {:?}", e);
}
let outgoing_requests = stream::iter(self.base_client.outgoing_requests().await?)
let outgoing_requests = stream::iter(self.base_client().outgoing_requests().await?)
.map(|r| self.send_outgoing_request(r));
let requests = outgoing_requests.buffer_unordered(MAX_CONCURRENT_REQUESTS);

View File

@@ -71,7 +71,7 @@ impl QrVerification {
///
/// The [`to_bytes()`](#method.to_bytes) method can be used to instead
/// output the raw bytes that should be encoded as a QR code.
pub fn to_qr_code(&self) -> std::result::Result<QrCode, EncodingError> {
pub fn to_qr_code(&self) -> Result<QrCode, EncodingError> {
self.inner.to_qr_code()
}
@@ -80,7 +80,7 @@ impl QrVerification {
///
/// The [`to_qr_code()`](#method.to_qr_code) method can be used to instead
/// output a `QrCode` object that can be rendered.
pub fn to_bytes(&self) -> std::result::Result<Vec<u8>, EncodingError> {
pub fn to_bytes(&self) -> Result<Vec<u8>, EncodingError> {
self.inner.to_bytes()
}

View File

@@ -40,7 +40,7 @@ use thiserror::Error;
use url::ParseError as UrlParseError;
/// Result type of the matrix-sdk.
pub type Result<T> = std::result::Result<T, Error>;
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Result type of a pure HTTP request.
pub type HttpResult<T> = std::result::Result<T, HttpError>;

View File

@@ -186,8 +186,7 @@ pub struct Ctx<T>(pub T);
impl<T: Clone + Send + Sync + 'static> EventHandlerContext for Ctx<T> {
fn from_data(data: &EventHandlerData<'_>) -> Option<Self> {
let anymap = data.client.event_handler_data.read().unwrap();
Some(Ctx(anymap.get::<T>()?.clone()))
data.client.event_handler_context::<T>().map(Ctx)
}
}
@@ -330,8 +329,7 @@ impl Client {
// Construct event handler futures
let futures: Vec<_> = self
.event_handlers
.read()
.event_handlers()
.await
.get(&event_handler_id)
.into_iter()

View File

@@ -53,8 +53,8 @@ pub mod event_handler;
mod http_client;
/// High-level room API
pub mod room;
/// High-level room API
mod room_member;
mod sync;
#[cfg(feature = "encryption")]
pub mod encryption;

View File

@@ -168,14 +168,17 @@ impl Common {
pub(crate) async fn request_members(&self) -> Result<Option<MembersResponse>> {
if let Some(mutex) =
self.client.members_request_locks.get(self.inner.room_id()).map(|m| m.clone())
self.client.inner.members_request_locks.get(self.inner.room_id()).map(|m| m.clone())
{
mutex.lock().await;
Ok(None)
} else {
let mutex = Arc::new(Mutex::new(()));
self.client.members_request_locks.insert(self.inner.room_id().clone(), mutex.clone());
self.client
.inner
.members_request_locks
.insert(self.inner.room_id().clone(), mutex.clone());
let _guard = mutex.lock().await;
@@ -183,9 +186,9 @@ impl Common {
let response = self.client.send(request, None).await?;
let response =
self.client.base_client.receive_members(self.inner.room_id(), &response).await?;
self.client.base_client().receive_members(self.inner.room_id(), &response).await?;
self.client.members_request_locks.remove(self.inner.room_id());
self.client.inner.members_request_locks.remove(self.inner.room_id());
Ok(Some(response))
}

View File

@@ -170,37 +170,39 @@ impl Joined {
/// if let Some(room) = client.get_joined_room(&room_id) {
/// room.typing_notice(true).await?
/// }
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
/// ```
pub async fn typing_notice(&self, typing: bool) -> Result<()> {
// Only send a request to the homeserver if the old timeout has elapsed
// or the typing notice changed state within the
// TYPING_NOTICE_TIMEOUT
let send =
if let Some(typing_time) = self.client.typing_notice_times.get(self.inner.room_id()) {
if typing_time.elapsed() > TYPING_NOTICE_RESEND_TIMEOUT {
// We always reactivate the typing notice if typing is true or
// we may need to deactivate it if it's
// currently active if typing is false
typing || typing_time.elapsed() <= TYPING_NOTICE_TIMEOUT
} else {
// Only send a request when we need to deactivate typing
!typing
}
let send = if let Some(typing_time) =
self.client.inner.typing_notice_times.get(self.inner.room_id())
{
if typing_time.elapsed() > TYPING_NOTICE_RESEND_TIMEOUT {
// We always reactivate the typing notice if typing is true or
// we may need to deactivate it if it's
// currently active if typing is false
typing || typing_time.elapsed() <= TYPING_NOTICE_TIMEOUT
} else {
// Typing notice is currently deactivated, therefore, send a request
// only when it's about to be activated
typing
};
// Only send a request when we need to deactivate typing
!typing
}
} else {
// Typing notice is currently deactivated, therefore, send a request
// only when it's about to be activated
typing
};
if send {
let typing = if typing {
self.client
.inner
.typing_notice_times
.insert(self.inner.room_id().clone(), Instant::now());
Typing::Yes(TYPING_NOTICE_TIMEOUT)
} else {
self.client.typing_notice_times.remove(self.inner.room_id());
self.client.inner.typing_notice_times.remove(self.inner.room_id());
Typing::No
};
@@ -278,7 +280,7 @@ impl Joined {
/// if let Some(room) = client.get_joined_room(&room_id) {
/// room.enable_encryption().await?
/// }
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
/// ```
pub async fn enable_encryption(&self) -> Result<()> {
use ruma::{
@@ -294,7 +296,7 @@ impl Joined {
// TODO do we want to return an error here if we time out? This
// could be quite useful if someone wants to enable encryption and
// send a message right after it's enabled.
self.client.sync_beat.listen().wait_timeout(SYNC_WAIT_TIME);
self.client.inner.sync_beat.listen().wait_timeout(SYNC_WAIT_TIME);
}
Ok(())
@@ -311,7 +313,7 @@ impl Joined {
// TODO expose this publicly so people can pre-share a group session if
// e.g. a user starts to type a message for a room.
if let Some(mutex) =
self.client.group_session_locks.get(self.inner.room_id()).map(|m| m.clone())
self.client.inner.group_session_locks.get(self.inner.room_id()).map(|m| m.clone())
{
// If a group session share request is already going on,
// await the release of the lock.
@@ -320,7 +322,10 @@ impl Joined {
// Otherwise create a new lock and share the group
// session.
let mutex = Arc::new(Mutex::new(()));
self.client.group_session_locks.insert(self.inner.room_id().clone(), mutex.clone());
self.client
.inner
.group_session_locks
.insert(self.inner.room_id().clone(), mutex.clone());
let _guard = mutex.lock().await;
@@ -334,13 +339,13 @@ impl Joined {
let response = self.share_group_session().await;
self.client.group_session_locks.remove(self.inner.room_id());
self.client.inner.group_session_locks.remove(self.inner.room_id());
// If one of the responses failed invalidate the group
// session as using it would end up in undecryptable
// messages.
if let Err(r) = response {
self.client.base_client.invalidate_group_session(self.inner.room_id()).await?;
self.client.base_client().invalidate_group_session(self.inner.room_id()).await?;
return Err(r);
}
}
@@ -357,12 +362,12 @@ impl Joined {
#[cfg(feature = "encryption")]
async fn share_group_session(&self) -> Result<()> {
let mut requests =
self.client.base_client.share_group_session(self.inner.room_id()).await?;
self.client.base_client().share_group_session(self.inner.room_id()).await?;
for request in requests.drain(..) {
let response = self.client.send_to_device(&request).await?;
self.client.base_client.mark_request_as_sent(&request.txn_id, &response).await?;
self.client.mark_request_as_sent(&request.txn_id, &response).await?;
}
Ok(())
@@ -449,7 +454,7 @@ impl Joined {
/// if let Some(room) = client.get_joined_room(&room_id) {
/// room.send(content, Some(txn_id)).await?;
/// }
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
/// ```
///
/// [`SyncMessageEvent`]: ruma::events::SyncMessageEvent
@@ -517,7 +522,7 @@ impl Joined {
/// if let Some(room) = client.get_joined_room(&room_id) {
/// room.send_raw(content, "m.room.message", None).await?;
/// }
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
/// ```
///
/// [`SyncMessageEvent`]: ruma::events::SyncMessageEvent
@@ -554,8 +559,7 @@ impl Joined {
self.preshare_group_session().await?;
let olm =
self.client.base_client.olm_machine().await.expect("Olm machine wasn't started");
let olm = self.client.olm_machine().await.expect("Olm machine wasn't started");
let encrypted_content =
olm.encrypt_raw(self.inner.room_id(), content, event_type).await?;
@@ -631,7 +635,7 @@ impl Joined {
/// None,
/// ).await?;
/// }
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
/// ```
pub async fn send_attachment<R: Read>(
&self,
@@ -698,7 +702,7 @@ impl Joined {
/// if let Some(room) = client.get_joined_room(&room_id) {
/// room.send_state_event(content, "").await?;
/// }
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
/// ```
pub async fn send_state_event(
&self,
@@ -791,7 +795,7 @@ impl Joined {
/// let reason = Some("Indecent material");
/// room.redact(&event_id, reason, None).await?;
/// }
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
/// ```
pub async fn redact(
&self,
@@ -834,7 +838,7 @@ impl Joined {
///
/// room.set_tag("u.work", tag_info ).await?;
/// }
/// # matrix_sdk::Result::Ok(()) });
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
/// ```
pub async fn set_tag(&self, tag: &str, tag_info: TagInfo) -> HttpResult<create_tag::Response> {
let user_id = self.client.user_id().await.ok_or(HttpError::AuthenticationRequired)?;

View File

@@ -0,0 +1,144 @@
use std::time::Duration;
use futures_timer::Delay as sleep;
use matrix_sdk_base::{
deserialized_responses::{JoinedRoom, LeftRoom, SyncResponse},
instant::Instant,
};
use ruma::api::client::r0::sync::sync_events;
use tracing::{error, warn};
use crate::{event_handler::EventKind, Client, Result};
/// Internal functionality related to getting events from the server
/// (`sync_events` endpoint)
impl Client {
pub(crate) async fn process_sync(
&self,
response: sync_events::Response,
) -> Result<SyncResponse> {
let response = self.base_client().receive_sync_response(response).await?;
let SyncResponse {
next_batch: _,
rooms,
presence,
account_data,
to_device: _,
device_lists: _,
device_one_time_keys_count: _,
ambiguity_changes: _,
notifications,
} = &response;
self.handle_sync_events(EventKind::GlobalAccountData, &None, &account_data.events).await?;
self.handle_sync_events(EventKind::Presence, &None, &presence.events).await?;
for (room_id, room_info) in &rooms.join {
let room = self.get_room(room_id);
if room.is_none() {
error!("Can't call event handler, room {} not found", room_id);
continue;
}
let JoinedRoom { unread_notifications: _, timeline, state, account_data, ephemeral } =
room_info;
self.handle_sync_events(EventKind::EphemeralRoomData, &room, &ephemeral.events).await?;
self.handle_sync_events(EventKind::RoomAccountData, &room, &account_data.events)
.await?;
self.handle_sync_state_events(&room, &state.events).await?;
self.handle_sync_timeline_events(&room, &timeline.events).await?;
}
for (room_id, room_info) in &rooms.leave {
let room = self.get_room(room_id);
if room.is_none() {
error!("Can't call event handler, room {} not found", room_id);
continue;
}
let LeftRoom { timeline, state, account_data } = room_info;
self.handle_sync_events(EventKind::RoomAccountData, &room, &account_data.events)
.await?;
self.handle_sync_state_events(&room, &state.events).await?;
self.handle_sync_timeline_events(&room, &timeline.events).await?;
}
for (room_id, room_info) in &rooms.invite {
let room = self.get_room(room_id);
if room.is_none() {
error!("Can't call event handler, room {} not found", room_id);
continue;
}
// FIXME: Destructure room_info
self.handle_sync_events(
EventKind::StrippedState,
&room,
&room_info.invite_state.events,
)
.await?;
}
// Construct notification event handler futures
let mut futures = Vec::new();
for handler in &*self.notification_handlers().await {
for (room_id, room_notifications) in notifications {
let room = match self.get_room(room_id) {
Some(room) => room,
None => {
warn!("Can't call notification handler, room {} not found", room_id);
continue;
}
};
futures.extend(room_notifications.iter().map(|notification| {
(handler)(notification.clone(), room.clone(), self.clone())
}));
}
}
// Run the notification handler futures with the
// `self.notification_handlers` lock no longer being held, in order.
for fut in futures {
fut.await;
}
Ok(response)
}
pub(crate) async fn sync_loop_helper(
&self,
sync_settings: &mut crate::config::SyncSettings<'_>,
) -> Result<SyncResponse> {
let response = self.sync_once(sync_settings.clone()).await;
match response {
Ok(r) => {
sync_settings.token = Some(r.next_batch.clone());
Ok(r)
}
Err(e) => {
error!("Received an invalid response: {}", e);
sleep::new(Duration::from_secs(1)).await;
Err(e)
}
}
}
pub(crate) async fn delay_sync(last_sync_time: &mut Option<Instant>) {
let now = Instant::now();
// If the last sync happened less than a second ago, sleep for a
// while to not hammer out requests if the server doesn't respect
// the sync timeout.
if let Some(t) = last_sync_time {
if now - *t <= Duration::from_secs(1) {
sleep::new(Duration::from_secs(1)).await;
}
}
*last_sync_time = Some(now);
}
}