mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-05-06 15:04:11 -04:00
Merge remote-tracking branch 'upstream/main' into ben-wasm-store
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
# Or if we're able to ignore certain lines.
|
||||
Fo = "Fo"
|
||||
BA = "BA"
|
||||
UE = "UE"
|
||||
|
||||
[files]
|
||||
# Our json files contain a bunch of base64 encoded ed25519 keys which aren't
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{net::ToSocketAddrs, result::Result as StdResult};
|
||||
use std::net::ToSocketAddrs;
|
||||
|
||||
use matrix_sdk::{
|
||||
bytes::Bytes,
|
||||
@@ -168,7 +168,7 @@ mod handlers {
|
||||
_user_id: String,
|
||||
appservice: AppService,
|
||||
request: http::Request<Bytes>,
|
||||
) -> StdResult<impl warp::Reply, Rejection> {
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
if let Some(user_exists) = appservice.event_handler.users.lock().await.as_mut() {
|
||||
let request =
|
||||
query_user::IncomingRequest::try_from_http_request(request).map_err(Error::from)?;
|
||||
@@ -185,7 +185,7 @@ mod handlers {
|
||||
_room_id: String,
|
||||
appservice: AppService,
|
||||
request: http::Request<Bytes>,
|
||||
) -> StdResult<impl warp::Reply, Rejection> {
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
if let Some(room_exists) = appservice.event_handler.rooms.lock().await.as_mut() {
|
||||
let request =
|
||||
query_room::IncomingRequest::try_from_http_request(request).map_err(Error::from)?;
|
||||
@@ -202,7 +202,7 @@ mod handlers {
|
||||
_txn_id: String,
|
||||
appservice: AppService,
|
||||
request: http::Request<Bytes>,
|
||||
) -> StdResult<impl warp::Reply, Rejection> {
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
let incoming_transaction: ruma::api::appservice::event::push_events::v1::IncomingRequest =
|
||||
ruma::api::IncomingRequest::try_from_http_request(request).map_err(Error::from)?;
|
||||
|
||||
@@ -224,7 +224,7 @@ struct ErrorMessage {
|
||||
message: String,
|
||||
}
|
||||
|
||||
pub async fn handle_rejection(err: Rejection) -> std::result::Result<impl Reply, Rejection> {
|
||||
pub async fn handle_rejection(err: Rejection) -> Result<impl Reply, Rejection> {
|
||||
if err.find::<Unauthorized>().is_some() || err.find::<warp::reject::InvalidQuery>().is_some() {
|
||||
let code = http::StatusCode::UNAUTHORIZED;
|
||||
let message = "UNAUTHORIZED";
|
||||
|
||||
@@ -58,7 +58,7 @@ default-features = false
|
||||
features = ["sync", "fs"]
|
||||
|
||||
[dev-dependencies]
|
||||
futures = { version = "0.3.15", default-features = false }
|
||||
futures = { version = "0.3.15", default-features = false, features = ["executor"] }
|
||||
http = "0.2.4"
|
||||
matrix-sdk-test = { version = "0.4.0", path = "../matrix-sdk-test" }
|
||||
|
||||
|
||||
@@ -15,11 +15,26 @@
|
||||
|
||||
//! User sessions.
|
||||
|
||||
use ruma::{DeviceId, UserId};
|
||||
use ruma::{DeviceIdBox, UserId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A user session, containing an access token and information about the
|
||||
/// associated user account.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use matrix_sdk_base::Session;
|
||||
/// use ruma::{DeviceIdBox, user_id};
|
||||
///
|
||||
/// let session = Session {
|
||||
/// access_token: "My-Token".to_owned(),
|
||||
/// user_id: user_id!("@example:localhost"),
|
||||
/// device_id: DeviceIdBox::from("MYDEVICEID"),
|
||||
/// };
|
||||
///
|
||||
/// assert_eq!(session.device_id.as_str(), "MYDEVICEID");
|
||||
/// ```
|
||||
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
|
||||
pub struct Session {
|
||||
/// The access token used for this session.
|
||||
@@ -27,5 +42,15 @@ pub struct Session {
|
||||
/// The user the access token was issued for.
|
||||
pub user_id: UserId,
|
||||
/// The ID of the client device
|
||||
pub device_id: Box<DeviceId>,
|
||||
pub device_id: DeviceIdBox,
|
||||
}
|
||||
|
||||
impl From<ruma::api::client::r0::session::login::Response> for Session {
|
||||
fn from(response: ruma::api::client::r0::session::login::Response) -> Self {
|
||||
Self {
|
||||
access_token: response.access_token,
|
||||
user_id: response.user_id,
|
||||
device_id: response.device_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ version = "0.1.9"
|
||||
features = ["now"]
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
futures-util = { version = "0.3.15", default-features = false }
|
||||
futures-locks = { version = "0.6.0", default-features = false }
|
||||
async-lock = "2.4.0"
|
||||
futures-util = { version = "0.3.15", default-features = false, features = ["channel"] }
|
||||
wasm-bindgen-futures = "0.4.24"
|
||||
uuid = { version = "0.8.2", default-features = false, features = ["v4", "wasm-bindgen"] }
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
// could switch to futures-lock completely at some point, blocker:
|
||||
// https://github.com/asomers/futures-locks/issues/34
|
||||
// https://www.reddit.com/r/rust/comments/f4zldz/i_audited_3_different_implementation_of_async/
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub use futures_locks::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
|
||||
pub use async_lock::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub use tokio::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
|
||||
|
||||
@@ -16,8 +16,9 @@ features = ["docs"]
|
||||
rustdoc-args = ["--cfg", "feature=\"docs\""]
|
||||
|
||||
[features]
|
||||
default = []
|
||||
default = ["backups_v1"]
|
||||
qrcode = ["matrix-qrcode"]
|
||||
backups_v1 = []
|
||||
sled_cryptostore = ["sled"]
|
||||
docs = ["sled_cryptostore"]
|
||||
indexeddb_cryptostore = ["indexed_db_futures", "wasm-bindgen"]
|
||||
@@ -27,6 +28,7 @@ aes = { version = "0.7.4", features = ["ctr"] }
|
||||
aes-gcm = "0.9.2"
|
||||
atomic = "0.5.0"
|
||||
base64 = "0.13.0"
|
||||
bs58 = "0.4.0"
|
||||
byteorder = "1.4.3"
|
||||
dashmap = "4.0.2"
|
||||
futures-util = { version = "0.3.15", default-features = false }
|
||||
@@ -34,9 +36,13 @@ getrandom = "0.2.3"
|
||||
hmac = "0.11.0"
|
||||
matrix-qrcode = { version = "0.2.0", path = "../matrix-qrcode", optional = true }
|
||||
matrix-sdk-common = { version = "0.4.0", path = "../matrix-sdk-common" }
|
||||
olm-rs = { version = "2.0.1", features = ["serde"] }
|
||||
olm-rs = { version = "2.1", features = ["serde"] }
|
||||
pbkdf2 = { version = "0.9.0", default-features = false }
|
||||
ruma = { git = "https://github.com/ruma/ruma", rev = "ac6ecc3e5", features = ["client-api-c", "unstable-pre-spec"] }
|
||||
rand = "0.8.4"
|
||||
ruma = { git = "https://github.com/ruma/ruma", rev = "ac6ecc3e5", features = [
|
||||
"client-api-c",
|
||||
"unstable-pre-spec",
|
||||
] }
|
||||
serde = { version = "1.0.126", features = ["derive", "rc"] }
|
||||
serde_json = "1.0.64"
|
||||
sha2 = "0.9.5"
|
||||
@@ -50,17 +56,24 @@ indexed_db_futures = { version = "0.2.0", optional = true }
|
||||
wasm-bindgen = { version = "0.2.74", features = ["serde-serialize"], optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
futures = { version = "0.3.15", default-features = false }
|
||||
futures = { version = "0.3.15", default-features = false, features = ["executor"] }
|
||||
http = "0.2.4"
|
||||
indoc = "1.0.3"
|
||||
matches = "0.1.8"
|
||||
matrix-sdk-test = { version = "0.4.0", path = "../matrix-sdk-test" }
|
||||
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
|
||||
criterion = { version = "0.3.4", features = [
|
||||
"async",
|
||||
"async_tokio",
|
||||
"html_reports",
|
||||
] }
|
||||
proptest = "1.0.0"
|
||||
criterion = { version = "0.3.4", features = ["async", "async_tokio", "html_reports"] }
|
||||
tokio = { version = "1.7.1", default-features = false, features = ["rt-multi-thread", "macros"] }
|
||||
tempfile = "3.2.0"
|
||||
tokio = { version = "1.7.1", default-features = false, features = [
|
||||
"rt-multi-thread",
|
||||
"macros",
|
||||
] }
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dev-dependencies]
|
||||
criterion = { version = "0.3.4", features = ["async", "html_reports"] }
|
||||
|
||||
148
crates/matrix-sdk-crypto/src/backups/keys/backup.rs
Normal file
148
crates/matrix-sdk-crypto/src/backups/keys/backup.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use olm_rs::pk::OlmPkEncryption;
|
||||
use ruma::{
|
||||
api::client::r0::backup::{KeyBackupData, KeyBackupDataInit, SessionDataInit},
|
||||
DeviceKeyId, UserId,
|
||||
};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use super::recovery::DecodeError;
|
||||
use crate::olm::InboundGroupSession;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InnerBackupKey {
|
||||
key: [u8; MegolmV1BackupKey::KEY_SIZE],
|
||||
signatures: BTreeMap<UserId, BTreeMap<DeviceKeyId, String>>,
|
||||
version: Mutex<Option<String>>,
|
||||
}
|
||||
|
||||
/// The public part of a backup key.
|
||||
#[derive(Clone)]
|
||||
pub struct MegolmV1BackupKey {
|
||||
inner: Arc<InnerBackupKey>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MegolmV1BackupKey {
|
||||
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
formatter
|
||||
.debug_struct("MegolmV1BackupKey")
|
||||
.field("key", &self.to_base64())
|
||||
.field("version", &self.backup_version())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl MegolmV1BackupKey {
|
||||
const KEY_SIZE: usize = 32;
|
||||
|
||||
pub(super) fn new(public_key: &str, version: Option<String>) -> Self {
|
||||
let key = Self::from_base64(public_key).expect("Invalid backup key");
|
||||
|
||||
if let Some(version) = version {
|
||||
key.set_version(version)
|
||||
}
|
||||
|
||||
key
|
||||
}
|
||||
|
||||
/// Get the full name of the backup algorithm this backup key supports.
|
||||
pub fn backup_algorithm(&self) -> &str {
|
||||
"m.megolm_backup.v1.curve25519-aes-sha2"
|
||||
}
|
||||
|
||||
/// Get all the signatures of this `MegolmV1BackupKey`.
|
||||
pub fn signatures(&self) -> BTreeMap<UserId, BTreeMap<DeviceKeyId, String>> {
|
||||
self.inner.signatures.to_owned()
|
||||
}
|
||||
|
||||
/// Try to create a new `MegolmV1BackupKey` from a base 64 encoded string.
|
||||
pub fn from_base64(public_key: &str) -> Result<Self, DecodeError> {
|
||||
let mut key = [0u8; Self::KEY_SIZE];
|
||||
let decoded_key = crate::utilities::decode(public_key)?;
|
||||
|
||||
if decoded_key.len() != Self::KEY_SIZE {
|
||||
Err(DecodeError::Length(Self::KEY_SIZE, decoded_key.len()))
|
||||
} else {
|
||||
key.copy_from_slice(&decoded_key);
|
||||
|
||||
let inner =
|
||||
InnerBackupKey { key, signatures: Default::default(), version: Mutex::new(None) };
|
||||
|
||||
Ok(MegolmV1BackupKey { inner: inner.into() })
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert the [`MegolmV1BackupKey`] to a base 64 encoded string.
|
||||
pub fn to_base64(&self) -> String {
|
||||
crate::utilities::encode(&self.inner.key)
|
||||
}
|
||||
|
||||
/// Get the backup version that this key is used with, if any.
|
||||
pub fn backup_version(&self) -> Option<String> {
|
||||
self.inner.version.lock().unwrap().clone()
|
||||
}
|
||||
|
||||
/// Set the backup version that this `MegolmV1BackupKey` will be used with.
|
||||
///
|
||||
/// The key won't be able to encrypt room keys unless a version has been
|
||||
/// set.
|
||||
pub fn set_version(&self, version: String) {
|
||||
*self.inner.version.lock().unwrap() = Some(version);
|
||||
}
|
||||
|
||||
pub(crate) async fn encrypt(&self, session: InboundGroupSession) -> KeyBackupData {
|
||||
let pk = OlmPkEncryption::new(&self.to_base64());
|
||||
|
||||
// It's ok to truncate here, there's a semantic difference only between
|
||||
// 0 and 1+ anyways.
|
||||
let forwarded_count = (session.forwarding_key_chain().len() as u32).into();
|
||||
let first_message_index = session.first_known_index().into();
|
||||
|
||||
// Convert our key to the backup representation.
|
||||
let key = session.to_backup().await;
|
||||
|
||||
// The key gets zeroized in `BackedUpRoomKey` but we're creating a copy
|
||||
// here that won't, so let's wrap it up in a `Zeroizing` struct.
|
||||
let key =
|
||||
Zeroizing::new(serde_json::to_string(&key).expect("Can't serialize exported room key"));
|
||||
|
||||
let message = pk.encrypt(&key);
|
||||
|
||||
let session_data = SessionDataInit {
|
||||
ephemeral: message.ephemeral_key,
|
||||
ciphertext: message.ciphertext,
|
||||
mac: message.mac,
|
||||
}
|
||||
.into();
|
||||
|
||||
KeyBackupDataInit {
|
||||
first_message_index,
|
||||
forwarded_count,
|
||||
// TODO is this actually used anywhere? seems to be completely
|
||||
// useless and requires us to get the Device out of the store?
|
||||
// Also should this be checked at the time of the backup or at the
|
||||
// time of the room key receival?
|
||||
is_verified: false,
|
||||
session_data,
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
54
crates/matrix-sdk-crypto/src/backups/keys/mod.rs
Normal file
54
crates/matrix-sdk-crypto/src/backups/keys/mod.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Module for the keys that are used to backup room keys.
|
||||
//!
|
||||
//! The backup key is split into two parts:
|
||||
//!
|
||||
//! ```text
|
||||
//! ┌───────────────────────────┐
|
||||
//! │ RecoveryKey | BackupKey │
|
||||
//! └───────────────────────────┘
|
||||
//! ```
|
||||
//!
|
||||
//! 1. RecoveryKey, a private Curve25519 key that is used to decrypt backed up
|
||||
//! room keys.
|
||||
//! 2. BackupKey, the public part of the Curve25519 RecoveryKey, this one is
|
||||
//! used to encrypt room keys that get backed up.
|
||||
//!
|
||||
//! The `RecoveryKey` can be derived from a passphrase or randomly generated.
|
||||
//! To allow other devices access to this key you'll either have to re-enter the
|
||||
//! passphrase or the key as a base58 encoded string or received from another
|
||||
//! device using the secret sharing mechanism.
|
||||
//!
|
||||
//! In practice deriving a recovery key from a passphrase isn't done, and is
|
||||
//! **not** supported by the spec. Instead the secret storage key that encrypts
|
||||
//! secrets and puts them into your global account data gets derived from a
|
||||
//! passphrase and presented as a base58 encoded string, confusingly also called
|
||||
//! recovery key.
|
||||
//!
|
||||
//! The `RecoveryKey` can also be uploaded to your account data as an
|
||||
//! `m.megolm.v1` event, which gets encrypted by the SSSS key.
|
||||
//!
|
||||
//! The `MegolmV1BackupKey` is used to encrypt individual room keys so they can
|
||||
//! be uploaded to the homeserver.
|
||||
//!
|
||||
//! The `MegolmV1BackupKey` is a public key and is uploaded to the server using
|
||||
//! the `/room_keys/version` API endpoint.
|
||||
|
||||
mod backup;
|
||||
mod recovery;
|
||||
|
||||
pub use backup::MegolmV1BackupKey;
|
||||
pub use recovery::{DecodeError, PickledRecoveryKey, RecoveryKey};
|
||||
354
crates/matrix-sdk-crypto/src/backups/keys/recovery.rs
Normal file
354
crates/matrix-sdk-crypto/src/backups/keys/recovery.rs
Normal file
@@ -0,0 +1,354 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
io::{Cursor, Read},
|
||||
};
|
||||
|
||||
use aes::cipher::generic_array::GenericArray;
|
||||
use aes_gcm::{
|
||||
aead::{Aead, NewAead},
|
||||
Aes256Gcm,
|
||||
};
|
||||
use bs58;
|
||||
use olm_rs::{
|
||||
errors::OlmPkDecryptionError,
|
||||
pk::{OlmPkDecryption, PkMessage},
|
||||
};
|
||||
use rand::{thread_rng, Error as RandomError, Fill};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use zeroize::{Zeroize, Zeroizing};
|
||||
|
||||
use super::MegolmV1BackupKey;
|
||||
use crate::utilities::{decode_url_safe, encode, encode_url_safe};
|
||||
|
||||
const NONCE_SIZE: usize = 12;
|
||||
|
||||
/// Error type for the decoding of a RecoveryKey.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum DecodeError {
|
||||
/// The decoded recovery key has an invalid prefix.
|
||||
#[error("The decoded recovery key has an invalid prefix: expected {0:?}, got {1:?}")]
|
||||
Prefix([u8; 2], [u8; 2]),
|
||||
/// The parity byte of the recovery key didn't match.
|
||||
#[error("The parity byte of the recovery key doesn't match: expected {0:?}, got {1:?}")]
|
||||
Parity(u8, u8),
|
||||
/// The recovery key has an invalid length.
|
||||
#[error("The decoded recovery key has a invalid length: expected {0}, got {1}")]
|
||||
Length(usize, usize),
|
||||
/// The recovry key isn't valid base58.
|
||||
#[error(transparent)]
|
||||
Base58(#[from] bs58::decode::Error),
|
||||
/// The recovery key isn't valid base64.
|
||||
#[error(transparent)]
|
||||
Base64(#[from] base64::DecodeError),
|
||||
/// The recovery key is too short, we couldn't read enough data.
|
||||
#[error(transparent)]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum UnpicklingError {
|
||||
#[error(transparent)]
|
||||
Json(#[from] serde_json::Error),
|
||||
#[error("Couldn't decrypt the pickle: {0}")]
|
||||
Decryption(String),
|
||||
#[error(transparent)]
|
||||
Decode(#[from] DecodeError),
|
||||
}
|
||||
|
||||
/// The private part of a backup key.
|
||||
#[derive(Zeroize)]
|
||||
pub struct RecoveryKey {
|
||||
inner: [u8; RecoveryKey::KEY_SIZE],
|
||||
}
|
||||
|
||||
impl Drop for RecoveryKey {
|
||||
fn drop(&mut self) {
|
||||
self.inner.zeroize()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for RecoveryKey {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RecoveryKey").finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// The pickled version of a recovery key.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PickledRecoveryKey(String);
|
||||
|
||||
impl AsRef<str> for PickledRecoveryKey {
|
||||
fn as_ref(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct InnerPickle {
|
||||
version: u8,
|
||||
nonce: String,
|
||||
ciphertext: String,
|
||||
}
|
||||
|
||||
impl TryFrom<String> for RecoveryKey {
|
||||
type Error = DecodeError;
|
||||
|
||||
fn try_from(value: String) -> Result<Self, Self::Error> {
|
||||
Self::from_base58(&value)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for RecoveryKey {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let string = Zeroizing::new(self.to_base58());
|
||||
|
||||
let string = Zeroizing::new(
|
||||
string
|
||||
.chars()
|
||||
.collect::<Vec<char>>()
|
||||
.chunks(Self::DISPLAY_CHUNK_SIZE)
|
||||
.map(|c| c.iter().collect::<String>())
|
||||
.collect::<Vec<_>>()
|
||||
.join(" "),
|
||||
);
|
||||
|
||||
write!(f, "{}", string.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl RecoveryKey {
|
||||
const KEY_SIZE: usize = 32;
|
||||
const PREFIX: [u8; 2] = [0x8b, 0x01];
|
||||
const PREFIX_PARITY: u8 = Self::PREFIX[0] ^ Self::PREFIX[1];
|
||||
const DISPLAY_CHUNK_SIZE: usize = 4;
|
||||
|
||||
fn parity_byte(bytes: &[u8]) -> u8 {
|
||||
bytes.iter().fold(Self::PREFIX_PARITY, |acc, x| acc ^ x)
|
||||
}
|
||||
|
||||
/// Create a new random recovery key.
|
||||
pub fn new() -> Result<Self, RandomError> {
|
||||
let mut rng = thread_rng();
|
||||
|
||||
let mut key = [0u8; Self::KEY_SIZE];
|
||||
key.try_fill(&mut rng)?;
|
||||
|
||||
Ok(Self { inner: key })
|
||||
}
|
||||
|
||||
/// Create a new recovery key from the given byte array.
|
||||
///
|
||||
/// **Warning**: You need to make sure that the byte array contains correct
|
||||
/// random data, either by using a random number generator or by using an
|
||||
/// exported version of a previously created [`RecoveryKey`].
|
||||
pub fn from_bytes(key: [u8; Self::KEY_SIZE]) -> Self {
|
||||
Self { inner: key }
|
||||
}
|
||||
|
||||
/// Try to create a [`RecoveryKey`] from a base64 export of a `RecoveryKey`.
|
||||
pub fn from_base64(key: &str) -> Result<Self, DecodeError> {
|
||||
let decoded = Zeroizing::new(crate::utilities::decode(key)?);
|
||||
|
||||
if decoded.len() != Self::KEY_SIZE {
|
||||
Err(DecodeError::Length(decoded.len(), Self::KEY_SIZE))
|
||||
} else {
|
||||
let mut key = [0u8; Self::KEY_SIZE];
|
||||
key.copy_from_slice(&decoded);
|
||||
|
||||
Ok(Self::from_bytes(key))
|
||||
}
|
||||
}
|
||||
|
||||
/// Export the `RecoveryKey` as a base64 encoded string.
|
||||
pub fn to_base64(&self) -> String {
|
||||
encode(self.inner)
|
||||
}
|
||||
|
||||
/// Try to create a [`RecoveryKey`] from a base58 export of a `RecoveryKey`.
|
||||
pub fn from_base58(value: &str) -> Result<Self, DecodeError> {
|
||||
// Remove any whitespace we might have
|
||||
let value: String = value.chars().filter(|c| !c.is_whitespace()).collect();
|
||||
|
||||
let decoded = bs58::decode(value).with_alphabet(bs58::Alphabet::BITCOIN).into_vec()?;
|
||||
let mut decoded = Cursor::new(decoded);
|
||||
|
||||
let mut prefix = [0u8; 2];
|
||||
let mut key = [0u8; Self::KEY_SIZE];
|
||||
let mut expected_parity = [0u8; 1];
|
||||
|
||||
decoded.read_exact(&mut prefix)?;
|
||||
decoded.read_exact(&mut key)?;
|
||||
decoded.read_exact(&mut expected_parity)?;
|
||||
|
||||
let expected_parity = expected_parity[0];
|
||||
let parity = Self::parity_byte(key.as_ref());
|
||||
|
||||
let _ = Zeroizing::new(decoded.into_inner());
|
||||
|
||||
if prefix != Self::PREFIX {
|
||||
Err(DecodeError::Prefix(Self::PREFIX, prefix))
|
||||
} else if expected_parity != parity {
|
||||
Err(DecodeError::Parity(expected_parity, parity))
|
||||
} else {
|
||||
Ok(Self::from_bytes(key))
|
||||
}
|
||||
}
|
||||
|
||||
/// Export the `RecoveryKey` as a base58 encoded string.
|
||||
pub fn to_base58(&self) -> String {
|
||||
let bytes = Zeroizing::new(
|
||||
[
|
||||
Self::PREFIX.as_ref(),
|
||||
self.inner.as_ref(),
|
||||
[Self::parity_byte(self.inner.as_ref())].as_ref(),
|
||||
]
|
||||
.concat(),
|
||||
);
|
||||
|
||||
bs58::encode(bytes.as_slice()).with_alphabet(bs58::Alphabet::BITCOIN).into_string()
|
||||
}
|
||||
|
||||
fn get_pk_decrytpion(&self) -> OlmPkDecryption {
|
||||
OlmPkDecryption::from_bytes(self.inner.as_ref())
|
||||
.expect("Can't generate a libom PkDecryption object from our private key")
|
||||
}
|
||||
|
||||
/// Extract the megolm.v1 public key from this `RecoveryKey`.
|
||||
pub fn megolm_v1_public_key(&self) -> MegolmV1BackupKey {
|
||||
let pk = self.get_pk_decrytpion();
|
||||
let public_key = MegolmV1BackupKey::new(pk.public_key(), None);
|
||||
public_key
|
||||
}
|
||||
|
||||
/// Export this [`RecoveryKey`] as an encrypted pickle that can be safely
|
||||
/// stored.
|
||||
pub fn pickle(&self, pickle_key: &[u8]) -> PickledRecoveryKey {
|
||||
let key = GenericArray::from_slice(pickle_key);
|
||||
let cipher = Aes256Gcm::new(key);
|
||||
|
||||
let mut nonce = vec![0u8; NONCE_SIZE];
|
||||
let mut rng = thread_rng();
|
||||
|
||||
nonce.try_fill(&mut rng).expect("Can't generate random nocne to pickle the recovery key");
|
||||
let nonce = GenericArray::from_slice(nonce.as_slice());
|
||||
|
||||
let ciphertext =
|
||||
cipher.encrypt(nonce, self.inner.as_ref()).expect("Can't encrypt recovery key");
|
||||
|
||||
let ciphertext = encode_url_safe(ciphertext);
|
||||
|
||||
let pickle =
|
||||
InnerPickle { version: 1, nonce: encode_url_safe(nonce.as_slice()), ciphertext };
|
||||
|
||||
PickledRecoveryKey(serde_json::to_string(&pickle).expect("Can't encode pickled signing"))
|
||||
}
|
||||
|
||||
/// Try to import a `RecoveryKey` from a previously exported pickle.
|
||||
pub fn from_pickle(
|
||||
pickle: PickledRecoveryKey,
|
||||
pickle_key: &[u8],
|
||||
) -> Result<Self, UnpicklingError> {
|
||||
let pickled: InnerPickle = serde_json::from_str(pickle.as_ref())?;
|
||||
|
||||
let key = GenericArray::from_slice(pickle_key);
|
||||
let cipher = Aes256Gcm::new(key);
|
||||
|
||||
let nonce = decode_url_safe(pickled.nonce).map_err(DecodeError::from)?;
|
||||
let nonce = GenericArray::from_slice(&nonce);
|
||||
let ciphertext = &decode_url_safe(pickled.ciphertext).map_err(DecodeError::from)?;
|
||||
|
||||
let decrypted = cipher
|
||||
.decrypt(nonce, ciphertext.as_slice())
|
||||
.map_err(|e| UnpicklingError::Decryption(e.to_string()))?;
|
||||
|
||||
if decrypted.len() != Self::KEY_SIZE {
|
||||
Err(DecodeError::Length(decrypted.len(), Self::KEY_SIZE).into())
|
||||
} else {
|
||||
let mut key = [0u8; Self::KEY_SIZE];
|
||||
key.copy_from_slice(&decrypted);
|
||||
|
||||
Ok(Self { inner: key })
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to decrypt the given ciphertext using this `RecoveryKey`.
|
||||
///
|
||||
/// This will use the [`m.megolm_backup.v1.curve25519-aes-sha2`] algorithm
|
||||
/// to decrypt the given ciphertext.
|
||||
///
|
||||
/// [`m.megolm_backup.v1.curve25519-aes-sha2`]:
|
||||
/// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
|
||||
pub fn decrypt_v1(
|
||||
&self,
|
||||
mac: String,
|
||||
ephemeral_key: String,
|
||||
ciphertext: String,
|
||||
) -> Result<String, OlmPkDecryptionError> {
|
||||
let message = PkMessage::new(mac, ephemeral_key, ciphertext);
|
||||
let pk = self.get_pk_decrytpion();
|
||||
|
||||
pk.decrypt(message)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::{DecodeError, RecoveryKey};
|
||||
|
||||
const TEST_KEY: [u8; 32] = [
|
||||
0x77, 0x07, 0x6D, 0x0A, 0x73, 0x18, 0xA5, 0x7D, 0x3C, 0x16, 0xC1, 0x72, 0x51, 0xB2, 0x66,
|
||||
0x45, 0xDF, 0x4C, 0x2F, 0x87, 0xEB, 0xC0, 0x99, 0x2A, 0xB1, 0x77, 0xFB, 0xA5, 0x1D, 0xB9,
|
||||
0x2C, 0x2A,
|
||||
];
|
||||
|
||||
#[test]
|
||||
fn base64_decoding() -> Result<(), DecodeError> {
|
||||
let key = RecoveryKey::new().expect("Can't create a new recovery key");
|
||||
|
||||
let base64 = key.to_base64();
|
||||
let decoded_key = RecoveryKey::from_base64(&base64)?;
|
||||
assert_eq!(key.inner, decoded_key.inner, "The decode key does't match the original");
|
||||
|
||||
RecoveryKey::from_base64("i").expect_err("The recovery key is too short");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn base58_decoding() -> Result<(), DecodeError> {
|
||||
let key = RecoveryKey::new().expect("Can't create a new recovery key");
|
||||
|
||||
let base64 = key.to_base58();
|
||||
let decoded_key = RecoveryKey::from_base58(&base64)?;
|
||||
assert_eq!(key.inner, decoded_key.inner, "The decode key does't match the original");
|
||||
|
||||
let test_key =
|
||||
RecoveryKey::from_base58("EsTcLW2KPGiFwKEA3As5g5c4BXwkqeeJZJV8Q9fugUMNUE4d")?;
|
||||
assert_eq!(test_key.inner, TEST_KEY, "The decoded recovery key doesn't match the test key");
|
||||
|
||||
let test_key = RecoveryKey::from_base58(
|
||||
"EsTc LW2K PGiF wKEA 3As5 g5c4 BXwk qeeJ ZJV8 Q9fu gUMN UE4d",
|
||||
)?;
|
||||
assert_eq!(test_key.inner, TEST_KEY, "The decoded recovery key doesn't match the test key");
|
||||
|
||||
RecoveryKey::from_base58("EsTc LW2K PGiF wKEA 3As5 g5c4 BXwk qeeJ ZJV8 Q9fu gUMN UE4e")
|
||||
.expect_err("Can't create a recovery key if the parity byte is invalid");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
519
crates/matrix-sdk-crypto/src/backups/mod.rs
Normal file
519
crates/matrix-sdk-crypto/src/backups/mod.rs
Normal file
@@ -0,0 +1,519 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Server-side backup support for room keys
|
||||
//!
|
||||
//! This module largely implements support for server-side backups using the
|
||||
//! `m.megolm_backup.v1.curve25519-aes-sha2` backup algorithm.
|
||||
//!
|
||||
//! Due to various flaws in this backup algorithm it is **not** recommended to
|
||||
//! use this module or any of its functionality. The module is only provided for
|
||||
//! backwards compatibility.
|
||||
//!
|
||||
//! [spec]: https://spec.matrix.org/unstable/client-server-api/#server-side-key-backups
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use matrix_sdk_common::{locks::RwLock, uuid::Uuid};
|
||||
use ruma::{
|
||||
api::client::r0::backup::RoomKeyBackup, DeviceKeyAlgorithm, DeviceKeyId, RoomId, UserId,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tracing::{debug, info, instrument, trace, warn};
|
||||
|
||||
use crate::{
|
||||
olm::{Account, InboundGroupSession},
|
||||
store::{BackupKeys, Changes, RoomKeyCounts, Store},
|
||||
CryptoStoreError, KeysBackupRequest, OutgoingRequest,
|
||||
};
|
||||
|
||||
mod keys;
|
||||
|
||||
pub use keys::{DecodeError, MegolmV1BackupKey, PickledRecoveryKey, RecoveryKey};
|
||||
pub use olm_rs::errors::OlmPkDecryptionError;
|
||||
|
||||
/// A state machine that handles backing up room keys.
|
||||
///
|
||||
/// The state machine can be activated using the
|
||||
/// [`BackupMachine::enable_backup_v1`] method. After the state machine has been
|
||||
/// enabled a request that will upload encrypted room keys can be generated
|
||||
/// using the [`BackupMachine::backup`] method.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BackupMachine {
|
||||
account: Account,
|
||||
store: Store,
|
||||
backup_key: Arc<RwLock<Option<MegolmV1BackupKey>>>,
|
||||
pending_backup: Arc<RwLock<Option<PendingBackup>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct PendingBackup {
|
||||
request_id: Uuid,
|
||||
request: KeysBackupRequest,
|
||||
sessions: BTreeMap<RoomId, BTreeMap<String, BTreeSet<String>>>,
|
||||
}
|
||||
|
||||
impl PendingBackup {
|
||||
fn session_was_part_of_the_backup(&self, session: &InboundGroupSession) -> bool {
|
||||
self.sessions
|
||||
.get(session.room_id())
|
||||
.and_then(|r| r.get(session.sender_key()).map(|s| s.contains(session.session_id())))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PendingBackup> for OutgoingRequest {
|
||||
fn from(b: PendingBackup) -> Self {
|
||||
OutgoingRequest { request_id: b.request_id, request: Arc::new(b.request.into()) }
|
||||
}
|
||||
}
|
||||
|
||||
impl BackupMachine {
|
||||
const BACKUP_BATCH_SIZE: usize = 100;
|
||||
|
||||
pub(crate) fn new(
|
||||
account: Account,
|
||||
store: Store,
|
||||
backup_key: Option<MegolmV1BackupKey>,
|
||||
) -> Self {
|
||||
Self {
|
||||
account,
|
||||
store,
|
||||
backup_key: RwLock::new(backup_key).into(),
|
||||
pending_backup: RwLock::new(None).into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Are we able to back up room keys to the server?
|
||||
pub async fn enabled(&self) -> bool {
|
||||
self.backup_key.read().await.as_ref().map(|b| b.backup_version().is_some()).unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Verify some backup auth data that we downloaded from the server.
|
||||
///
|
||||
/// The auth data should be fetched from the server using the
|
||||
/// [`/room_keys/version`] endpoint.
|
||||
///
|
||||
/// Ruma models this using the [`BackupAlgorithm`] struct, but care needs to
|
||||
/// be taken if that struct is used. Some clients might use unspecced fields
|
||||
/// and the given Ruma struct will lose the unspecced fields after a
|
||||
/// serialization cycle.
|
||||
///
|
||||
/// [`BackupAlgorithm`]: ruma::api::client::r0::backup::BackupAlgorithm
|
||||
/// [`/room_keys/version`]: https://spec.matrix.org/unstable/client-server-api/#get_matrixclientv3room_keysversion
|
||||
pub async fn verify_backup(
|
||||
&self,
|
||||
mut serialized_auth_data: Value,
|
||||
) -> Result<bool, CryptoStoreError> {
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct AuthData {
|
||||
public_key: String,
|
||||
#[serde(default)]
|
||||
signatures: BTreeMap<UserId, BTreeMap<DeviceKeyId, String>>,
|
||||
#[serde(flatten)]
|
||||
extra: BTreeMap<String, Value>,
|
||||
}
|
||||
|
||||
let auth_data: AuthData = serde_json::from_value(serialized_auth_data.clone())?;
|
||||
|
||||
trace!(?auth_data, "Verifying backup auth data");
|
||||
|
||||
Ok(if let Some(signatures) = auth_data.signatures.get(self.store.user_id()) {
|
||||
for device_key_id in signatures.keys() {
|
||||
if device_key_id.algorithm() == DeviceKeyAlgorithm::Ed25519 {
|
||||
if device_key_id.device_id() == self.account.device_id() {
|
||||
let result = self.account.is_signed(&mut serialized_auth_data);
|
||||
|
||||
trace!(?result, "Checking auth data signature of our own device");
|
||||
|
||||
if result.is_ok() {
|
||||
return Ok(true);
|
||||
}
|
||||
} else {
|
||||
let device = self
|
||||
.store
|
||||
.get_device(self.store.user_id(), device_key_id.device_id())
|
||||
.await?;
|
||||
|
||||
trace!(
|
||||
device_id = device_key_id.device_id().as_str(),
|
||||
"Checking backup auth data for device"
|
||||
);
|
||||
|
||||
if let Some(device) = device {
|
||||
if device.verified()
|
||||
&& device.is_signed_by_device(&mut serialized_auth_data).is_ok()
|
||||
{
|
||||
return Ok(true);
|
||||
}
|
||||
} else {
|
||||
trace!(
|
||||
device_id = device_key_id.device_id().as_str(),
|
||||
"Device not found, can't check signature"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
} else {
|
||||
false
|
||||
})
|
||||
}
|
||||
|
||||
/// Activate the given backup key to be used to encrypt and backup room
|
||||
/// keys.
|
||||
///
|
||||
/// This will use the [`m.megolm_backup.v1.curve25519-aes-sha2`] algorithm
|
||||
/// to encrypt the room keys.
|
||||
///
|
||||
/// [`m.megolm_backup.v1.curve25519-aes-sha2`]:
|
||||
/// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
|
||||
pub async fn enable_backup_v1(&self, key: MegolmV1BackupKey) -> Result<(), CryptoStoreError> {
|
||||
if key.backup_version().is_some() {
|
||||
*self.backup_key.write().await = Some(key.clone());
|
||||
info!(backup_key =? key, "Activated a backup");
|
||||
} else {
|
||||
warn!(backup_key =? key, "Tried to activate a backup without having the backup key uploaded");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the number of backed up room keys and the total number of room keys.
|
||||
pub async fn room_key_counts(&self) -> Result<RoomKeyCounts, CryptoStoreError> {
|
||||
self.store.inbound_group_session_counts().await
|
||||
}
|
||||
|
||||
/// Disable and reset our backup state.
|
||||
///
|
||||
/// This will remove any pending backup request, remove the backup key and
|
||||
/// reset the backup state of each room key we have.
|
||||
#[instrument(skip(self))]
|
||||
pub async fn disable_backup(&self) -> Result<(), CryptoStoreError> {
|
||||
debug!("Disabling key backup and resetting backup state for room keys");
|
||||
|
||||
self.backup_key.write().await.take();
|
||||
self.pending_backup.write().await.take();
|
||||
|
||||
self.store.reset_backup_state().await?;
|
||||
|
||||
debug!("Done disabling backup");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Store the recovery key in the cryptostore.
|
||||
///
|
||||
/// This is useful if the client wants to support gossiping of the backup
|
||||
/// key.
|
||||
pub async fn save_recovery_key(
|
||||
&self,
|
||||
recovery_key: Option<RecoveryKey>,
|
||||
version: Option<String>,
|
||||
) -> Result<(), CryptoStoreError> {
|
||||
let changes = Changes { recovery_key, backup_version: version, ..Default::default() };
|
||||
self.store.save_changes(changes).await
|
||||
}
|
||||
|
||||
/// Get the backup keys we have saved in our crypto store.
|
||||
pub async fn get_backup_keys(&self) -> Result<BackupKeys, CryptoStoreError> {
|
||||
self.store.load_backup_keys().await
|
||||
}
|
||||
|
||||
/// Encrypt a batch of room keys and return a request that needs to be sent
|
||||
/// out to backup the room keys.
|
||||
pub async fn backup(&self) -> Result<Option<OutgoingRequest>, CryptoStoreError> {
|
||||
let mut request = self.pending_backup.write().await;
|
||||
|
||||
if let Some(request) = &*request {
|
||||
trace!("Backing up, returning an existing request");
|
||||
|
||||
Ok(Some(request.clone().into()))
|
||||
} else {
|
||||
trace!("Backing up, creating a new request");
|
||||
|
||||
let new_request = self.backup_helper().await?;
|
||||
*request = new_request.clone();
|
||||
|
||||
Ok(new_request.map(|r| r.into()))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn mark_request_as_sent(
|
||||
&self,
|
||||
request_id: Uuid,
|
||||
) -> Result<(), CryptoStoreError> {
|
||||
let mut request = self.pending_backup.write().await;
|
||||
|
||||
if let Some(r) = &*request {
|
||||
if r.request_id == request_id {
|
||||
let sessions: Vec<_> = self
|
||||
.store
|
||||
.get_inbound_group_sessions()
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter(|s| r.session_was_part_of_the_backup(s))
|
||||
.collect();
|
||||
|
||||
for session in &sessions {
|
||||
session.mark_as_backed_up()
|
||||
}
|
||||
|
||||
trace!(request_id =? r.request_id, keys =? r.sessions, "Marking room keys as backed up");
|
||||
|
||||
let changes = Changes { inbound_group_sessions: sessions, ..Default::default() };
|
||||
self.store.save_changes(changes).await?;
|
||||
|
||||
let counts = self.store.inbound_group_session_counts().await?;
|
||||
|
||||
trace!(
|
||||
room_key_counts =? counts,
|
||||
request_id =? r.request_id, keys =? r.sessions, "Marked room keys as backed up"
|
||||
);
|
||||
|
||||
*request = None;
|
||||
} else {
|
||||
warn!(
|
||||
expected = r.request_id.to_string().as_str(),
|
||||
got = request_id.to_string().as_str(),
|
||||
"Tried to mark a pending backup as sent but the request id didn't match"
|
||||
)
|
||||
}
|
||||
} else {
|
||||
warn!(
|
||||
request_id = request_id.to_string().as_str(),
|
||||
"Tried to mark a pending backup as sent but there isn't a backup pending"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn backup_helper(&self) -> Result<Option<PendingBackup>, CryptoStoreError> {
|
||||
if let Some(backup_key) = &*self.backup_key.read().await {
|
||||
if let Some(version) = backup_key.backup_version() {
|
||||
let sessions =
|
||||
self.store.inbound_group_sessions_for_backup(Self::BACKUP_BATCH_SIZE).await?;
|
||||
|
||||
if !sessions.is_empty() {
|
||||
let key_count = sessions.len();
|
||||
let (backup, session_record) = Self::backup_keys(sessions, backup_key).await;
|
||||
|
||||
info!(
|
||||
key_count = key_count,
|
||||
keys =? session_record,
|
||||
backup_key =? backup_key,
|
||||
"Successfully created a room keys backup request"
|
||||
);
|
||||
|
||||
let request = PendingBackup {
|
||||
request_id: Uuid::new_v4(),
|
||||
request: KeysBackupRequest { version, rooms: backup },
|
||||
sessions: session_record,
|
||||
};
|
||||
|
||||
Ok(Some(request))
|
||||
} else {
|
||||
trace!(?backup_key, "No room keys need to be backed up");
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
warn!("Trying to backup room keys but the backup key wasn't uploaded");
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
warn!("Trying to backup room keys but no backup key was found");
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Backup all the non-backed up room keys we know about
|
||||
async fn backup_keys(
|
||||
sessions: Vec<InboundGroupSession>,
|
||||
backup_key: &MegolmV1BackupKey,
|
||||
) -> (BTreeMap<RoomId, RoomKeyBackup>, BTreeMap<RoomId, BTreeMap<String, BTreeSet<String>>>)
|
||||
{
|
||||
let mut backup: BTreeMap<RoomId, RoomKeyBackup> = BTreeMap::new();
|
||||
let mut session_record: BTreeMap<RoomId, BTreeMap<String, BTreeSet<String>>> =
|
||||
BTreeMap::new();
|
||||
|
||||
for session in sessions.into_iter() {
|
||||
let room_id = session.room_id().to_owned();
|
||||
let session_id = session.session_id().to_owned();
|
||||
let sender_key = session.sender_key().to_owned();
|
||||
let session = backup_key.encrypt(session).await;
|
||||
|
||||
session_record
|
||||
.entry(room_id.clone())
|
||||
.or_default()
|
||||
.entry(sender_key)
|
||||
.or_default()
|
||||
.insert(session_id.clone());
|
||||
|
||||
backup
|
||||
.entry(room_id)
|
||||
.or_insert_with(|| RoomKeyBackup::new(BTreeMap::new()))
|
||||
.sessions
|
||||
.insert(session_id, session);
|
||||
}
|
||||
|
||||
(backup, session_record)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use matrix_sdk_test::async_test;
|
||||
use ruma::{room_id, user_id, DeviceIdBox, RoomId, UserId};
|
||||
|
||||
use super::RecoveryKey;
|
||||
use crate::{OlmError, OlmMachine};
|
||||
|
||||
fn alice_id() -> UserId {
|
||||
user_id!("@alice:example.org")
|
||||
}
|
||||
|
||||
fn alice_device_id() -> DeviceIdBox {
|
||||
"JLAFKJWSCS".into()
|
||||
}
|
||||
|
||||
fn room_id() -> RoomId {
|
||||
room_id!("!test:localhost")
|
||||
}
|
||||
|
||||
fn room_id2() -> RoomId {
|
||||
room_id!("!test2:localhost")
|
||||
}
|
||||
|
||||
async fn backup_flow(machine: OlmMachine) -> Result<(), OlmError> {
|
||||
let backup_machine = machine.backup_machine();
|
||||
let counts = backup_machine.store.inbound_group_session_counts().await?;
|
||||
|
||||
assert_eq!(counts.total, 0, "Initially no keys exist");
|
||||
assert_eq!(counts.backed_up, 0, "Initially no backed up keys exist");
|
||||
|
||||
machine.create_outbound_group_session_with_defaults(&room_id()).await?;
|
||||
machine.create_outbound_group_session_with_defaults(&room_id2()).await?;
|
||||
|
||||
let counts = backup_machine.store.inbound_group_session_counts().await?;
|
||||
assert_eq!(counts.total, 2, "Two room keys need to exist in the store");
|
||||
assert_eq!(counts.backed_up, 0, "No room keys have been backed up yet");
|
||||
|
||||
let recovery_key = RecoveryKey::new().expect("Can't create new recovery key");
|
||||
let backup_key = recovery_key.megolm_v1_public_key();
|
||||
backup_key.set_version("1".to_owned());
|
||||
|
||||
backup_machine.enable_backup_v1(backup_key).await?;
|
||||
|
||||
let request =
|
||||
backup_machine.backup().await?.expect("Created a backup request successfully");
|
||||
assert_eq!(
|
||||
Some(request.request_id),
|
||||
backup_machine.backup().await?.map(|r| r.request_id),
|
||||
"Calling backup again without uploading creates the same backup request"
|
||||
);
|
||||
|
||||
backup_machine.mark_request_as_sent(request.request_id).await?;
|
||||
|
||||
let counts = backup_machine.store.inbound_group_session_counts().await?;
|
||||
assert_eq!(counts.total, 2);
|
||||
assert_eq!(counts.backed_up, 2, "All room keys have been backed up");
|
||||
|
||||
assert!(
|
||||
backup_machine.backup().await?.is_none(),
|
||||
"No room keys need to be backed up, no request needs to be created"
|
||||
);
|
||||
|
||||
backup_machine.disable_backup().await?;
|
||||
|
||||
let counts = backup_machine.store.inbound_group_session_counts().await?;
|
||||
assert_eq!(counts.total, 2);
|
||||
assert_eq!(
|
||||
counts.backed_up, 0,
|
||||
"Disabling the backup resets the backup flag on the room keys"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn memory_store_backups() -> Result<(), OlmError> {
|
||||
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
|
||||
|
||||
backup_flow(machine).await
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
#[cfg(feature = "sled_cryptostore")]
|
||||
async fn default_store_backups() -> Result<(), OlmError> {
|
||||
use tempfile::tempdir;
|
||||
|
||||
let tmpdir = tempdir().expect("Can't create a temporary dir");
|
||||
let machine = OlmMachine::new_with_default_store(
|
||||
&alice_id(),
|
||||
&alice_device_id(),
|
||||
tmpdir.as_ref(),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
backup_flow(machine).await
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
#[cfg(feature = "sled_cryptostore")]
|
||||
async fn recovery_key_storing() -> Result<(), OlmError> {
|
||||
use tempfile::tempdir;
|
||||
|
||||
let tmpdir = tempdir().expect("Can't create a temporary dir");
|
||||
let machine = OlmMachine::new_with_default_store(
|
||||
&alice_id(),
|
||||
&alice_device_id(),
|
||||
tmpdir.as_ref(),
|
||||
Some("test"),
|
||||
)
|
||||
.await?;
|
||||
let backup_machine = machine.backup_machine();
|
||||
|
||||
let recovery_key = RecoveryKey::new().expect("Can't create new recovery key");
|
||||
let encoded_key = recovery_key.to_base64();
|
||||
|
||||
backup_machine.save_recovery_key(Some(recovery_key), Some("1".to_owned())).await?;
|
||||
|
||||
let loded_backup = backup_machine.get_backup_keys().await?;
|
||||
|
||||
assert_eq!(
|
||||
encoded_key,
|
||||
loded_backup
|
||||
.recovery_key
|
||||
.expect("The recovery key wasn't loaded from the store")
|
||||
.to_base64(),
|
||||
"The loaded key matches to the one we stored"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Some("1"),
|
||||
loded_backup.backup_version.as_deref(),
|
||||
"The loaded version matches to the one we stored"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -770,6 +770,8 @@ impl GossipMachine {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let secret = std::mem::take(&mut event.content.secret);
|
||||
|
||||
if let Some(request) = self.store.get_outgoing_secret_requests(request_id).await? {
|
||||
match &request.info {
|
||||
SecretInfo::KeyRequest(_) => {
|
||||
@@ -791,30 +793,32 @@ impl GossipMachine {
|
||||
self.store.get_device_from_curve_key(&event.sender, sender_key).await?
|
||||
{
|
||||
if device.verified() {
|
||||
match self
|
||||
.store
|
||||
.import_secret(
|
||||
secret_name,
|
||||
std::mem::take(&mut event.content.secret),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => self.mark_as_done(request).await?,
|
||||
Err(e) => {
|
||||
// If this is a store error propagate it up
|
||||
// the call stack.
|
||||
if let SecretImportError::Store(e) = e {
|
||||
return Err(e);
|
||||
} else {
|
||||
// Otherwise warn that there was
|
||||
// something wrong with the secret.
|
||||
warn!(
|
||||
secret_name = secret_name.as_ref(),
|
||||
error =? e,
|
||||
"Error while importing a secret"
|
||||
)
|
||||
if secret_name != &SecretName::RecoveryKey {
|
||||
match self.store.import_secret(secret_name, secret).await {
|
||||
Ok(_) => self.mark_as_done(request).await?,
|
||||
Err(e) => {
|
||||
// If this is a store error propagate it up
|
||||
// the call stack.
|
||||
if let SecretImportError::Store(e) = e {
|
||||
return Err(e);
|
||||
} else {
|
||||
// Otherwise warn that there was
|
||||
// something wrong with the secret.
|
||||
warn!(
|
||||
secret_name = secret_name.as_ref(),
|
||||
error =? e,
|
||||
"Error while importing a secret"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Skip importing the recovery key here since
|
||||
// we'll want to check if the public key matches
|
||||
// to the latest version on the server. We
|
||||
// instead leave the key in the event and let
|
||||
// the user import it later.
|
||||
event.content.secret = secret;
|
||||
}
|
||||
} else {
|
||||
warn!(
|
||||
|
||||
@@ -539,7 +539,7 @@ impl ReadOnlyDevice {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_signed_by_device(&self, json: &mut Value) -> Result<(), SignatureError> {
|
||||
pub(crate) fn is_signed_by_device(&self, json: &mut Value) -> Result<(), SignatureError> {
|
||||
let signing_key =
|
||||
self.get_key(DeviceKeyAlgorithm::Ed25519).ok_or(SignatureError::MissingSigningKey)?;
|
||||
|
||||
@@ -623,9 +623,6 @@ impl PartialEq for ReadOnlyDevice {
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod test {
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
use wasm_bindgen_test::wasm_bindgen_test;
|
||||
use matrix_sdk_test::async_test;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use ruma::{encryption::DeviceKeys, user_id, DeviceKeyAlgorithm};
|
||||
|
||||
@@ -30,6 +30,9 @@
|
||||
compile_error!("indexeddb_cryptostore only works for wasm32 target");
|
||||
|
||||
|
||||
#[cfg(feature = "backups_v1")]
|
||||
#[cfg_attr(feature = "docs", doc(cfg(backups_v1)))]
|
||||
pub mod backups;
|
||||
mod error;
|
||||
mod file_encryption;
|
||||
mod gossiping;
|
||||
@@ -72,7 +75,7 @@ pub use matrix_qrcode;
|
||||
pub(crate) use olm::ReadOnlyAccount;
|
||||
pub use olm::{CrossSigningStatus, EncryptionSettings};
|
||||
pub use requests::{
|
||||
IncomingResponse, KeysQueryRequest, OutgoingRequest, OutgoingRequests,
|
||||
IncomingResponse, KeysBackupRequest, KeysQueryRequest, OutgoingRequest, OutgoingRequests,
|
||||
OutgoingVerificationRequest, RoomMessageRequest, ToDeviceRequest, UploadSigningKeysRequest,
|
||||
};
|
||||
pub use store::{CrossSigningKeyExport, CryptoStoreError, SecretImportError};
|
||||
|
||||
@@ -46,11 +46,14 @@ use ruma::{
|
||||
secret::request::SecretName,
|
||||
AnyMessageEventContent, AnyRoomEvent, AnyToDeviceEvent, EventContent,
|
||||
},
|
||||
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, EventEncryptionAlgorithm, RoomId, UInt, UserId,
|
||||
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, UInt,
|
||||
UserId,
|
||||
};
|
||||
use serde_json::Value;
|
||||
use tracing::{debug, error, info, trace, warn};
|
||||
|
||||
#[cfg(feature = "backups_v1")]
|
||||
use crate::backups::BackupMachine;
|
||||
#[cfg(feature = "sled_cryptostore")]
|
||||
use crate::store::sled::SledStore;
|
||||
use crate::{
|
||||
@@ -104,6 +107,9 @@ pub struct OlmMachine {
|
||||
/// State machine handling public user identities and devices, keeping track
|
||||
/// of when a key query needs to be done and handling one.
|
||||
identity_manager: IdentityManager,
|
||||
/// A state machine that handles creating room key backups.
|
||||
#[cfg(feature = "backups_v1")]
|
||||
backup_machine: BackupMachine,
|
||||
}
|
||||
|
||||
#[cfg(not(tarpaulin_include))]
|
||||
@@ -180,6 +186,9 @@ impl OlmMachine {
|
||||
let identity_manager =
|
||||
IdentityManager::new(user_id.clone(), device_id.clone(), store.clone());
|
||||
|
||||
#[cfg(feature = "backups_v1")]
|
||||
let backup_machine = BackupMachine::new(account.clone(), store.clone(), None);
|
||||
|
||||
OlmMachine {
|
||||
user_id,
|
||||
device_id,
|
||||
@@ -191,6 +200,8 @@ impl OlmMachine {
|
||||
verification_machine,
|
||||
key_request_machine,
|
||||
identity_manager,
|
||||
#[cfg(feature = "backups_v1")]
|
||||
backup_machine,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -371,6 +382,10 @@ impl OlmMachine {
|
||||
IncomingResponse::RoomMessage(_) => {
|
||||
self.verification_machine.mark_request_as_sent(request_id);
|
||||
}
|
||||
IncomingResponse::KeysBackup(_) => {
|
||||
#[cfg(feature = "backups_v1")]
|
||||
self.backup_machine.mark_request_as_sent(*request_id).await?;
|
||||
}
|
||||
};
|
||||
|
||||
Ok(())
|
||||
@@ -651,18 +666,18 @@ impl OlmMachine {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// #[cfg(test)]
|
||||
// pub(crate) async fn create_inbound_session(
|
||||
// &self,
|
||||
// room_id: &RoomId,
|
||||
// ) -> OlmResult<InboundGroupSession> {
|
||||
// let (_, session) = self
|
||||
// .group_session_manager
|
||||
// .create_outbound_group_session(room_id, EncryptionSettings::default())
|
||||
// .await?;
|
||||
#[cfg(test)]
|
||||
pub(crate) async fn create_inbound_session(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> OlmResult<InboundGroupSession> {
|
||||
let (_, session) = self
|
||||
.group_session_manager
|
||||
.create_outbound_group_session(room_id, EncryptionSettings::default())
|
||||
.await?;
|
||||
|
||||
// Ok(session)
|
||||
// }
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
/// Encrypt a room message for the given room.
|
||||
///
|
||||
@@ -1453,6 +1468,62 @@ impl OlmMachine {
|
||||
) -> Result<CrossSigningStatus, SecretImportError> {
|
||||
self.store.import_cross_signing_keys(export).await
|
||||
}
|
||||
|
||||
async fn sign_account(
|
||||
&self,
|
||||
message: &str,
|
||||
signatures: &mut BTreeMap<UserId, BTreeMap<DeviceKeyId, String>>,
|
||||
) {
|
||||
let device_key_id = DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.device_id());
|
||||
let signature = self.account.sign(message).await;
|
||||
|
||||
signatures.entry(self.user_id().to_owned()).or_default().insert(device_key_id, signature);
|
||||
}
|
||||
|
||||
async fn sign_master(
|
||||
&self,
|
||||
message: &str,
|
||||
signatures: &mut BTreeMap<UserId, BTreeMap<DeviceKeyId, String>>,
|
||||
) -> Result<(), crate::SignatureError> {
|
||||
let identity = &*self.user_identity.lock().await;
|
||||
|
||||
let master_key: DeviceIdBox = identity
|
||||
.master_public_key()
|
||||
.await
|
||||
.and_then(|m| m.get_first_key().map(|k| k.to_owned()))
|
||||
.ok_or(crate::SignatureError::MissingSigningKey)?
|
||||
.into();
|
||||
|
||||
let device_key_id = DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &master_key);
|
||||
let signature = identity.sign(message).await?;
|
||||
|
||||
signatures.entry(self.user_id().to_owned()).or_default().insert(device_key_id, signature);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sign the given message using our device key and if available cross
|
||||
/// signing master key.
|
||||
pub async fn sign(&self, message: &str) -> BTreeMap<UserId, BTreeMap<DeviceKeyId, String>> {
|
||||
let mut signatures: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
|
||||
|
||||
self.sign_account(message, &mut signatures).await;
|
||||
|
||||
if let Err(e) = self.sign_master(message, &mut signatures).await {
|
||||
warn!(error =? e, "Couldn't sign the message using the cross signing master key")
|
||||
}
|
||||
|
||||
signatures
|
||||
}
|
||||
|
||||
/// Get a reference to the backup related state machine.
|
||||
///
|
||||
/// This state machine can be used to incrementally backup all room keys to
|
||||
/// the server.
|
||||
#[cfg(feature = "backups_v1")]
|
||||
pub fn backup_machine(&self) -> &BackupMachine {
|
||||
&self.backup_machine
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -686,6 +686,19 @@ impl ReadOnlyAccount {
|
||||
self.inner.lock().await.sign(string)
|
||||
}
|
||||
|
||||
#[cfg(feature = "backups_v1")]
|
||||
pub(crate) fn is_signed(&self, json: &mut Value) -> Result<(), SignatureError> {
|
||||
let signing_key = self.identity_keys.ed25519();
|
||||
let utility = crate::olm::Utility::new();
|
||||
|
||||
utility.verify_json(
|
||||
&self.user_id,
|
||||
&DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.device_id()),
|
||||
signing_key,
|
||||
json,
|
||||
)
|
||||
}
|
||||
|
||||
/// Store the account as a base64 encoded string.
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -806,7 +819,8 @@ impl ReadOnlyAccount {
|
||||
) -> Result<SignatureUploadRequest, SignatureError> {
|
||||
let public_key =
|
||||
master_key.get_first_key().ok_or(SignatureError::MissingSigningKey)?.to_string();
|
||||
let mut cross_signing_key = master_key.into();
|
||||
let mut cross_signing_key: CrossSigningKey = master_key.into();
|
||||
cross_signing_key.signatures.clear();
|
||||
self.sign_cross_signing_key(&mut cross_signing_key).await?;
|
||||
|
||||
let mut signed_keys = BTreeMap::new();
|
||||
|
||||
@@ -12,7 +12,15 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::BTreeMap, convert::TryFrom, fmt, mem, sync::Arc};
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
convert::TryFrom,
|
||||
fmt, mem,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering::SeqCst},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use matrix_sdk_common::locks::Mutex;
|
||||
pub use olm_rs::{
|
||||
@@ -39,7 +47,7 @@ use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use super::{ExportedGroupSessionKey, ExportedRoomKey, GroupSessionKey};
|
||||
use super::{BackedUpRoomKey, ExportedGroupSessionKey, ExportedRoomKey, GroupSessionKey};
|
||||
use crate::error::{EventError, MegolmResult};
|
||||
|
||||
// TODO add creation times to the inbound group sessions so we can export
|
||||
@@ -61,7 +69,7 @@ pub struct InboundGroupSession {
|
||||
pub(crate) room_id: Arc<RoomId>,
|
||||
forwarding_chains: Arc<Vec<String>>,
|
||||
imported: bool,
|
||||
backed_up: bool,
|
||||
backed_up: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl InboundGroupSession {
|
||||
@@ -105,7 +113,7 @@ impl InboundGroupSession {
|
||||
room_id: room_id.clone().into(),
|
||||
forwarding_chains: Vec::new().into(),
|
||||
imported: false,
|
||||
backed_up: false,
|
||||
backed_up: AtomicBool::new(false).into(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -123,6 +131,25 @@ impl InboundGroupSession {
|
||||
Self::try_from(exported_session.into())
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn from_backup(
|
||||
room_id: &RoomId,
|
||||
backup: BackedUpRoomKey,
|
||||
) -> Result<Self, OlmGroupSessionError> {
|
||||
let session = OlmInboundGroupSession::new(&backup.session_key.0)?;
|
||||
let session_id = session.session_id();
|
||||
|
||||
Self::from_export(ExportedRoomKey {
|
||||
algorithm: backup.algorithm,
|
||||
room_id: room_id.to_owned(),
|
||||
sender_key: backup.sender_key,
|
||||
session_id,
|
||||
session_key: backup.session_key,
|
||||
sender_claimed_keys: backup.sender_claimed_keys,
|
||||
forwarding_curve25519_key_chain: backup.forwarding_curve25519_key_chain,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a new inbound group session from a forwarded room key content.
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -157,7 +184,7 @@ impl InboundGroupSession {
|
||||
room_id: content.room_id.clone().into(),
|
||||
forwarding_chains: forwarding_chains.into(),
|
||||
imported: true,
|
||||
backed_up: false,
|
||||
backed_up: AtomicBool::new(false).into(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -177,7 +204,7 @@ impl InboundGroupSession {
|
||||
room_id: (&*self.room_id).clone(),
|
||||
forwarding_chains: self.forwarding_key_chain().to_vec(),
|
||||
imported: self.imported,
|
||||
backed_up: self.backed_up,
|
||||
backed_up: self.backed_up(),
|
||||
history_visibility: self.history_visibility.as_ref().clone(),
|
||||
}
|
||||
}
|
||||
@@ -195,6 +222,21 @@ impl InboundGroupSession {
|
||||
&self.sender_key
|
||||
}
|
||||
|
||||
/// Has the session been backed up to the server.
|
||||
pub fn backed_up(&self) -> bool {
|
||||
self.backed_up.load(SeqCst)
|
||||
}
|
||||
|
||||
/// Reset the backup state of the inbound group session.
|
||||
pub(crate) fn reset_backup_state(&self) {
|
||||
self.backed_up.store(false, SeqCst)
|
||||
}
|
||||
|
||||
#[cfg(feature = "backups_v1")]
|
||||
pub(crate) fn mark_as_backed_up(&self) {
|
||||
self.backed_up.store(true, SeqCst)
|
||||
}
|
||||
|
||||
/// Get the map of signing keys this session was received from.
|
||||
pub fn signing_keys(&self) -> &BTreeMap<DeviceKeyAlgorithm, String> {
|
||||
&self.signing_keys
|
||||
@@ -256,7 +298,7 @@ impl InboundGroupSession {
|
||||
signing_keys: pickle.signing_key.into(),
|
||||
room_id: pickle.room_id.into(),
|
||||
forwarding_chains: pickle.forwarding_chains.into(),
|
||||
backed_up: pickle.backed_up,
|
||||
backed_up: AtomicBool::from(pickle.backed_up).into(),
|
||||
imported: pickle.imported,
|
||||
})
|
||||
}
|
||||
@@ -291,6 +333,11 @@ impl InboundGroupSession {
|
||||
self.inner.lock().await.decrypt(message)
|
||||
}
|
||||
|
||||
#[cfg(feature = "backups_v1")]
|
||||
pub(crate) async fn to_backup(&self) -> BackedUpRoomKey {
|
||||
self.export().await.into()
|
||||
}
|
||||
|
||||
/// Decrypt an event from a room timeline.
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -413,16 +460,16 @@ impl TryFrom<ExportedRoomKey> for InboundGroupSession {
|
||||
let first_known_index = session.first_known_index();
|
||||
|
||||
Ok(InboundGroupSession {
|
||||
inner: Arc::new(Mutex::new(session)),
|
||||
inner: Mutex::new(session).into(),
|
||||
session_id: key.session_id.into(),
|
||||
sender_key: key.sender_key.into(),
|
||||
history_visibility: None.into(),
|
||||
first_known_index,
|
||||
signing_keys: Arc::new(key.sender_claimed_keys),
|
||||
room_id: Arc::new(key.room_id),
|
||||
forwarding_chains: Arc::new(key.forwarding_curve25519_key_chain),
|
||||
signing_keys: key.sender_claimed_keys.into(),
|
||||
room_id: key.room_id.into(),
|
||||
forwarding_chains: key.forwarding_curve25519_key_chain.into(),
|
||||
imported: true,
|
||||
backed_up: false,
|
||||
backed_up: AtomicBool::from(false).into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ pub struct GroupSessionKey(pub String);
|
||||
#[zeroize(drop)]
|
||||
pub struct ExportedGroupSessionKey(pub String);
|
||||
|
||||
/// An exported version of a `InboundGroupSession`
|
||||
/// An exported version of an `InboundGroupSession`
|
||||
///
|
||||
/// This can be used to share the `InboundGroupSession` in an exported file.
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
|
||||
@@ -71,6 +71,28 @@ pub struct ExportedRoomKey {
|
||||
pub forwarding_curve25519_key_chain: Vec<String>,
|
||||
}
|
||||
|
||||
/// A backed up version of an `InboundGroupSession`
|
||||
///
|
||||
/// This can be used to backup the `InboundGroupSession` to the server.
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
|
||||
pub struct BackedUpRoomKey {
|
||||
/// The encryption algorithm that the session uses.
|
||||
pub algorithm: EventEncryptionAlgorithm,
|
||||
|
||||
/// The Curve25519 key of the device which initiated the session originally.
|
||||
pub sender_key: String,
|
||||
|
||||
/// The key for the session.
|
||||
pub session_key: ExportedGroupSessionKey,
|
||||
|
||||
/// The Ed25519 key of the device which initiated the session originally.
|
||||
pub sender_claimed_keys: BTreeMap<DeviceKeyAlgorithm, String>,
|
||||
|
||||
/// Chain of Curve25519 keys through which this session was forwarded, via
|
||||
/// m.forwarded_room_key events.
|
||||
pub forwarding_curve25519_key_chain: Vec<String>,
|
||||
}
|
||||
|
||||
impl TryInto<ToDeviceForwardedRoomKeyEventContent> for ExportedRoomKey {
|
||||
type Error = ();
|
||||
|
||||
@@ -104,6 +126,18 @@ impl TryInto<ToDeviceForwardedRoomKeyEventContent> for ExportedRoomKey {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ExportedRoomKey> for BackedUpRoomKey {
|
||||
fn from(k: ExportedRoomKey) -> Self {
|
||||
Self {
|
||||
algorithm: k.algorithm,
|
||||
sender_key: k.sender_key,
|
||||
session_key: k.session_key,
|
||||
sender_claimed_keys: k.sender_claimed_keys,
|
||||
forwarding_curve25519_key_chain: k.forwarding_curve25519_key_chain,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ToDeviceForwardedRoomKeyEventContent> for ExportedRoomKey {
|
||||
/// Convert the content of a forwarded room key into a exported room key.
|
||||
fn from(forwarded_key: ToDeviceForwardedRoomKeyEventContent) -> Self {
|
||||
@@ -125,6 +159,9 @@ impl From<ToDeviceForwardedRoomKeyEventContent> for ExportedRoomKey {
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
use wasm_bindgen_test::wasm_bindgen_test;
|
||||
use matrix_sdk_test::async_test;
|
||||
use std::{
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
|
||||
@@ -458,6 +458,17 @@ impl PrivateCrossSigningIdentity {
|
||||
Ok(SignatureUploadRequest::new(signed_keys))
|
||||
}
|
||||
|
||||
pub(crate) async fn sign(&self, message: &str) -> Result<String, SignatureError> {
|
||||
Ok(self
|
||||
.master_key
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.ok_or(SignatureError::MissingSigningKey)?
|
||||
.sign(message)
|
||||
.await)
|
||||
}
|
||||
|
||||
/// Create a new identity for the given Olm Account.
|
||||
///
|
||||
/// Returns the new identity, the upload signing keys request and a
|
||||
@@ -770,6 +781,9 @@ mod test {
|
||||
|
||||
let master = user_signing.sign_user(&bob_public).await.unwrap();
|
||||
|
||||
let num_signatures: usize = master.signatures.iter().map(|(_, u)| u.len()).sum();
|
||||
assert_eq!(num_signatures, 1, "We're only uploading our own signature");
|
||||
|
||||
bob_public.master_key = master.into();
|
||||
|
||||
user_signing.public_key.verify_master_key(bob_public.master_key()).unwrap();
|
||||
|
||||
@@ -155,6 +155,10 @@ impl MasterSigning {
|
||||
Ok(Self { inner, public_key: pickle.public_key.into() })
|
||||
}
|
||||
|
||||
pub async fn sign(&self, message: &str) -> String {
|
||||
self.inner.sign(message).await.0
|
||||
}
|
||||
|
||||
pub async fn sign_subkey<'a>(&self, subkey: &mut CrossSigningKey) {
|
||||
let subkey_without_signatures = json!({
|
||||
"user_id": subkey.user_id.clone(),
|
||||
@@ -164,7 +168,7 @@ impl MasterSigning {
|
||||
|
||||
let message = serde_json::to_string(&subkey_without_signatures)
|
||||
.expect("Can't serialize cross signing subkey");
|
||||
let signature = self.inner.sign(&message).await;
|
||||
let signature = self.sign(&message).await;
|
||||
|
||||
subkey
|
||||
.signatures
|
||||
@@ -176,7 +180,7 @@ impl MasterSigning {
|
||||
self.inner.public_key().as_str().into(),
|
||||
)
|
||||
.to_string(),
|
||||
signature.0,
|
||||
signature,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -203,10 +207,10 @@ impl UserSigning {
|
||||
&self,
|
||||
user: &ReadOnlyUserIdentity,
|
||||
) -> Result<CrossSigningKey, SignatureError> {
|
||||
let signature = self.sign_user_helper(user).await?;
|
||||
let signatures = self.sign_user_helper(user).await?;
|
||||
let mut master_key: CrossSigningKey = user.master_key().to_owned().into();
|
||||
|
||||
master_key.signatures.extend(signature);
|
||||
master_key.signatures = signatures;
|
||||
|
||||
Ok(master_key)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ use std::{collections::BTreeMap, iter, sync::Arc, time::Duration};
|
||||
use matrix_sdk_common::uuid::Uuid;
|
||||
use ruma::{
|
||||
api::client::r0::{
|
||||
backup::{add_backup_keys::Response as KeysBackupResponse, RoomKeyBackup},
|
||||
keys::{
|
||||
claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse},
|
||||
get_keys::Response as KeysQueryResponse,
|
||||
@@ -197,6 +198,8 @@ pub enum OutgoingRequests {
|
||||
/// A room message request, usually for sending in-room interactive
|
||||
/// verification events.
|
||||
RoomMessage(RoomMessageRequest),
|
||||
/// A request that will back up a batch of room keys to the server.
|
||||
KeysBackup(KeysBackupRequest),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -215,6 +218,12 @@ impl From<KeysQueryRequest> for OutgoingRequests {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<KeysBackupRequest> for OutgoingRequests {
|
||||
fn from(r: KeysBackupRequest) -> Self {
|
||||
OutgoingRequests::KeysBackup(r)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<KeysClaimRequest> for OutgoingRequests {
|
||||
fn from(r: KeysClaimRequest) -> Self {
|
||||
Self::KeysClaim(r)
|
||||
@@ -278,6 +287,8 @@ pub enum IncomingResponse<'a> {
|
||||
SignatureUpload(&'a SignatureUploadResponse),
|
||||
/// A room message response, usually for interactive verifications.
|
||||
RoomMessage(&'a RoomMessageResponse),
|
||||
/// Response for the server-side room key backup request.
|
||||
KeysBackup(&'a KeysBackupResponse),
|
||||
}
|
||||
|
||||
impl<'a> From<&'a KeysUploadResponse> for IncomingResponse<'a> {
|
||||
@@ -286,6 +297,12 @@ impl<'a> From<&'a KeysUploadResponse> for IncomingResponse<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a KeysBackupResponse> for IncomingResponse<'a> {
|
||||
fn from(response: &'a KeysBackupResponse) -> Self {
|
||||
IncomingResponse::KeysBackup(response)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a KeysQueryResponse> for IncomingResponse<'a> {
|
||||
fn from(response: &'a KeysQueryResponse) -> Self {
|
||||
IncomingResponse::KeysQuery(response)
|
||||
@@ -356,6 +373,16 @@ pub struct RoomMessageRequest {
|
||||
pub content: AnyMessageEventContent,
|
||||
}
|
||||
|
||||
/// A request that will back up a batch of room keys to the server.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct KeysBackupRequest {
|
||||
/// The backup version that these room keys should be part of.
|
||||
pub version: String,
|
||||
/// The map from room id to a backed up room key that we're going to upload
|
||||
/// to the server.
|
||||
pub rooms: BTreeMap<RoomId, RoomKeyBackup>,
|
||||
}
|
||||
|
||||
/// An enum over the different outgoing verification based requests.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum OutgoingVerificationRequest {
|
||||
|
||||
@@ -111,6 +111,11 @@ impl GroupSessionStore {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get the number of `InboundGroupSession`s we have.
|
||||
pub fn count(&self) -> usize {
|
||||
self.entries.iter().map(|d| d.value().values().map(|t| t.len()).sum::<usize>()).sum()
|
||||
}
|
||||
|
||||
/// Get a inbound group session from our store.
|
||||
///
|
||||
/// # Arguments
|
||||
|
||||
@@ -23,7 +23,8 @@ use ruma::{DeviceId, DeviceIdBox, RoomId, UserId};
|
||||
|
||||
use super::{
|
||||
caches::{DeviceStore, GroupSessionStore, SessionStore},
|
||||
Changes, CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session,
|
||||
BackupKeys, Changes, CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, RoomKeyCounts,
|
||||
Session,
|
||||
};
|
||||
use crate::{
|
||||
gossiping::{GossipRequest, SecretInfo},
|
||||
@@ -163,6 +164,34 @@ impl CryptoStore for MemoryStore {
|
||||
Ok(self.inbound_group_sessions.get_all())
|
||||
}
|
||||
|
||||
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
|
||||
let backed_up =
|
||||
self.get_inbound_group_sessions().await?.into_iter().filter(|s| s.backed_up()).count();
|
||||
|
||||
Ok(RoomKeyCounts { total: self.inbound_group_sessions.count(), backed_up })
|
||||
}
|
||||
|
||||
async fn inbound_group_sessions_for_backup(
|
||||
&self,
|
||||
limit: usize,
|
||||
) -> Result<Vec<InboundGroupSession>> {
|
||||
Ok(self
|
||||
.get_inbound_group_sessions()
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter(|s| !s.backed_up())
|
||||
.take(limit)
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn reset_backup_state(&self) -> Result<()> {
|
||||
for session in self.get_inbound_group_sessions().await? {
|
||||
session.reset_backup_state()
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_outbound_group_sessions(
|
||||
&self,
|
||||
_: &RoomId,
|
||||
@@ -271,6 +300,10 @@ impl CryptoStore for MemoryStore {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn load_backup_keys(&self) -> Result<BackupKeys> {
|
||||
Ok(BackupKeys::default())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -107,11 +107,15 @@ pub(crate) struct Store {
|
||||
verification_machine: VerificationMachine,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
#[derive(Default, Debug)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct Changes {
|
||||
pub account: Option<ReadOnlyAccount>,
|
||||
pub private_identity: Option<PrivateCrossSigningIdentity>,
|
||||
#[cfg(feature = "backups_v1")]
|
||||
pub backup_version: Option<String>,
|
||||
#[cfg(feature = "backups_v1")]
|
||||
pub recovery_key: Option<crate::backups::RecoveryKey>,
|
||||
pub sessions: Vec<Session>,
|
||||
pub message_hashes: Vec<OlmMessageHash>,
|
||||
pub inbound_group_sessions: Vec<InboundGroupSession>,
|
||||
@@ -170,6 +174,26 @@ impl DeviceChanges {
|
||||
}
|
||||
}
|
||||
|
||||
/// Struct holding info about how many room keys the store has.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct RoomKeyCounts {
|
||||
/// The total number of room keys the store has.
|
||||
pub total: usize,
|
||||
/// The number of backed up room keys the store has.
|
||||
pub backed_up: usize,
|
||||
}
|
||||
|
||||
/// Stored versions of the backup keys.
|
||||
#[derive(Default, Debug)]
|
||||
pub struct BackupKeys {
|
||||
/// The recovery key, the one used to decrypt backed up room keys.
|
||||
#[cfg(feature = "backups_v1")]
|
||||
pub recovery_key: Option<crate::backups::RecoveryKey>,
|
||||
/// The version that we are using for backups.
|
||||
#[cfg(feature = "backups_v1")]
|
||||
pub backup_version: Option<String>,
|
||||
}
|
||||
|
||||
/// A struct containing private cross signing keys that can be backed up or
|
||||
/// uploaded to the secret store.
|
||||
#[derive(Zeroize)]
|
||||
@@ -431,7 +455,18 @@ impl Store {
|
||||
| SecretName::CrossSigningSelfSigningKey => {
|
||||
self.identity.lock().await.export_secret(secret_name).await
|
||||
}
|
||||
SecretName::RecoveryKey => None,
|
||||
SecretName::RecoveryKey => {
|
||||
#[cfg(feature = "backups_v1")]
|
||||
if let Some(key) = self.load_backup_keys().await.unwrap().recovery_key {
|
||||
let exported = key.to_base64();
|
||||
Some(exported)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "backups_v1"))]
|
||||
None
|
||||
}
|
||||
name => {
|
||||
warn!(secret =? name, "Unknown secret was requested");
|
||||
None
|
||||
@@ -496,7 +531,12 @@ impl Store {
|
||||
self.save_changes(changes).await?;
|
||||
}
|
||||
}
|
||||
SecretName::RecoveryKey => (),
|
||||
SecretName::RecoveryKey => {
|
||||
// We don't import the recovery key here since we'll want to
|
||||
// check if the public key matches to the latest version on the
|
||||
// server. We instead leave the key in the event and let the
|
||||
// user import it later.
|
||||
}
|
||||
name => {
|
||||
warn!(secret =? name, "Tried to import an unknown secret");
|
||||
}
|
||||
@@ -630,6 +670,22 @@ pub trait CryptoStore: AsyncTraitDeps {
|
||||
/// Get all the inbound group sessions we have stored.
|
||||
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>>;
|
||||
|
||||
/// Get the number inbound group sessions we have and how many of them are
|
||||
/// backed up.
|
||||
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts>;
|
||||
|
||||
/// Get all the inbound group sessions we have not backed up yet.
|
||||
async fn inbound_group_sessions_for_backup(
|
||||
&self,
|
||||
limit: usize,
|
||||
) -> Result<Vec<InboundGroupSession>>;
|
||||
|
||||
/// Reset the backup state of all the stored inbound group sessions.
|
||||
async fn reset_backup_state(&self) -> Result<()>;
|
||||
|
||||
/// Get the backup keys we have stored.
|
||||
async fn load_backup_keys(&self) -> Result<BackupKeys>;
|
||||
|
||||
/// Get the outbound group sessions we have stored that is used for the
|
||||
/// given room.
|
||||
async fn get_outbound_group_sessions(
|
||||
|
||||
@@ -29,14 +29,14 @@ use ruma::{
|
||||
pub use sled::Error;
|
||||
use sled::{
|
||||
transaction::{ConflictableTransactionError, TransactionError},
|
||||
Config, Db, Transactional, Tree,
|
||||
Config, Db, IVec, Transactional, Tree,
|
||||
};
|
||||
use tracing::trace;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{
|
||||
caches::SessionStore, Changes, CryptoStore, CryptoStoreError, InboundGroupSession, PickleKey,
|
||||
ReadOnlyAccount, Result, Session,
|
||||
caches::SessionStore, BackupKeys, Changes, CryptoStore, CryptoStoreError, InboundGroupSession,
|
||||
PickleKey, ReadOnlyAccount, Result, RoomKeyCounts, Session,
|
||||
};
|
||||
use crate::{
|
||||
gossiping::{GossipRequest, SecretInfo},
|
||||
@@ -205,9 +205,46 @@ impl SledStore {
|
||||
self.account_info.read().unwrap().clone()
|
||||
}
|
||||
|
||||
fn upgrade_database(db: &Db) -> Result<()> {
|
||||
let version = db
|
||||
.get("version")?
|
||||
async fn reset_backup_state(&self) -> Result<()> {
|
||||
let mut pickles: Vec<(IVec, PickledInboundGroupSession)> = self
|
||||
.inbound_group_sessions
|
||||
.iter()
|
||||
.map(|p| {
|
||||
let item = p?;
|
||||
Ok((
|
||||
item.0,
|
||||
serde_json::from_slice(&item.1).map_err(CryptoStoreError::Serialization)?,
|
||||
))
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
|
||||
for (_, pickle) in &mut pickles {
|
||||
pickle.backed_up = false;
|
||||
}
|
||||
|
||||
let ret: Result<(), TransactionError<serde_json::Error>> =
|
||||
self.inbound_group_sessions.transaction(|inbound_sessions| {
|
||||
for (key, pickle) in &pickles {
|
||||
inbound_sessions.insert(
|
||||
key,
|
||||
serde_json::to_vec(&pickle).map_err(ConflictableTransactionError::Abort)?,
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
});
|
||||
|
||||
ret?;
|
||||
|
||||
self.inner.flush_async().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn upgrade(&self) -> Result<()> {
|
||||
let version = self
|
||||
.inner
|
||||
.get("store_version")?
|
||||
.map(|v| {
|
||||
let (version_bytes, _) = v.split_at(std::mem::size_of::<u8>());
|
||||
u8::from_be_bytes(version_bytes.try_into().unwrap_or_default())
|
||||
@@ -225,17 +262,17 @@ impl SledStore {
|
||||
if version == 0 {
|
||||
// We changed the schema but migrating this isn't important since we
|
||||
// rotate the group sessions relatively often anyways so we just
|
||||
// drop it.
|
||||
db.drop_tree("outbound_group_sessions")?;
|
||||
// clear the tree.
|
||||
self.outbound_group_sessions.clear()?;
|
||||
}
|
||||
|
||||
db.insert("version", DATABASE_VERSION.to_be_bytes().as_ref())?;
|
||||
self.inner.insert("store_version", DATABASE_VERSION.to_be_bytes().as_ref())?;
|
||||
self.inner.flush()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn open_helper(db: Db, path: Option<PathBuf>, passphrase: Option<&str>) -> Result<Self> {
|
||||
Self::upgrade_database(&db)?;
|
||||
let account = db.open_tree("account")?;
|
||||
let private_identity = db.open_tree("private_identity")?;
|
||||
|
||||
@@ -263,7 +300,7 @@ impl SledStore {
|
||||
.expect("Can't create default pickle key")
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
let database = Self {
|
||||
account_info: RwLock::new(None).into(),
|
||||
path,
|
||||
inner: db,
|
||||
@@ -283,7 +320,11 @@ impl SledStore {
|
||||
tracked_users,
|
||||
olm_hashes,
|
||||
identities,
|
||||
})
|
||||
};
|
||||
|
||||
database.upgrade()?;
|
||||
|
||||
Ok(database)
|
||||
}
|
||||
|
||||
fn get_or_create_pickle_key(passphrase: &str, database: &Db) -> Result<PickleKey> {
|
||||
@@ -361,6 +402,9 @@ impl SledStore {
|
||||
None
|
||||
};
|
||||
|
||||
#[cfg(feature = "backups_v1")]
|
||||
let recovery_key_pickle = changes.recovery_key.map(|r| r.pickle(self.get_pickle_key()));
|
||||
|
||||
let device_changes = changes.devices;
|
||||
let mut session_changes = HashMap::new();
|
||||
|
||||
@@ -399,6 +443,8 @@ impl SledStore {
|
||||
let identity_changes = changes.identities;
|
||||
let olm_hashes = changes.message_hashes;
|
||||
let key_requests = changes.key_requests;
|
||||
#[cfg(feature = "backups_v1")]
|
||||
let backup_version = changes.backup_version;
|
||||
|
||||
let ret: Result<(), TransactionError<serde_json::Error>> = (
|
||||
&self.account,
|
||||
@@ -441,6 +487,22 @@ impl SledStore {
|
||||
)?;
|
||||
}
|
||||
|
||||
#[cfg(feature = "backups_v1")]
|
||||
if let Some(r) = &recovery_key_pickle {
|
||||
account.insert(
|
||||
"recovery_key_v1".encode(),
|
||||
serde_json::to_vec(r).map_err(ConflictableTransactionError::Abort)?,
|
||||
)?;
|
||||
}
|
||||
|
||||
#[cfg(feature = "backups_v1")]
|
||||
if let Some(b) = &backup_version {
|
||||
account.insert(
|
||||
"backup_version_v1".encode(),
|
||||
serde_json::to_vec(b).map_err(ConflictableTransactionError::Abort)?,
|
||||
)?;
|
||||
}
|
||||
|
||||
for device in device_changes.new.iter().chain(&device_changes.changed) {
|
||||
let key = (device.user_id().as_str(), device.device_id().as_str()).encode();
|
||||
let device = serde_json::to_vec(&device)
|
||||
@@ -655,6 +717,57 @@ impl CryptoStore for SledStore {
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
|
||||
let pickles: Vec<PickledInboundGroupSession> = self
|
||||
.inbound_group_sessions
|
||||
.iter()
|
||||
.map(|p| {
|
||||
let item = p?;
|
||||
serde_json::from_slice(&item.1).map_err(CryptoStoreError::Serialization)
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
|
||||
let total = pickles.len();
|
||||
let backed_up = pickles.into_iter().filter(|p| p.backed_up).count();
|
||||
|
||||
Ok(RoomKeyCounts { total, backed_up })
|
||||
}
|
||||
|
||||
async fn inbound_group_sessions_for_backup(
|
||||
&self,
|
||||
limit: usize,
|
||||
) -> Result<Vec<InboundGroupSession>> {
|
||||
let pickles: Vec<InboundGroupSession> = self
|
||||
.inbound_group_sessions
|
||||
.iter()
|
||||
.map(|p| {
|
||||
let item = p?;
|
||||
serde_json::from_slice(&item.1).map_err(CryptoStoreError::from)
|
||||
})
|
||||
.filter_map(|p: Result<PickledInboundGroupSession, CryptoStoreError>| match p {
|
||||
Ok(p) => {
|
||||
if !p.backed_up {
|
||||
Some(
|
||||
InboundGroupSession::from_pickle(p, self.get_pickle_mode())
|
||||
.map_err(CryptoStoreError::from),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
Err(p) => Some(Err(p)),
|
||||
})
|
||||
.take(limit)
|
||||
.collect::<Result<_>>()?;
|
||||
|
||||
Ok(pickles)
|
||||
}
|
||||
|
||||
async fn reset_backup_state(&self) -> Result<()> {
|
||||
self.reset_backup_state().await
|
||||
}
|
||||
|
||||
async fn get_outbound_group_sessions(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
@@ -670,14 +783,14 @@ impl CryptoStore for SledStore {
|
||||
!self.users_for_key_query_cache.is_empty()
|
||||
}
|
||||
|
||||
fn tracked_users(&self) -> HashSet<UserId> {
|
||||
self.tracked_users_cache.to_owned().iter().map(|u| u.clone()).collect()
|
||||
}
|
||||
|
||||
fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||
self.users_for_key_query_cache.iter().map(|u| u.clone()).collect()
|
||||
}
|
||||
|
||||
fn tracked_users(&self) -> HashSet<UserId> {
|
||||
self.tracked_users_cache.to_owned().iter().map(|u| u.clone()).collect()
|
||||
}
|
||||
|
||||
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
|
||||
let already_added = self.tracked_users_cache.insert(user.clone());
|
||||
|
||||
@@ -796,6 +909,32 @@ impl CryptoStore for SledStore {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn load_backup_keys(&self) -> Result<BackupKeys> {
|
||||
let version = self
|
||||
.account
|
||||
.get("backup_version_v1".encode())?
|
||||
.map(|v| serde_json::from_slice(&v))
|
||||
.transpose()?;
|
||||
|
||||
#[cfg(feature = "backups_v1")]
|
||||
let recovery_key = {
|
||||
self.account
|
||||
.get("recovery_key_v1".encode())?
|
||||
.map(|p| serde_json::from_slice(&p))
|
||||
.transpose()?
|
||||
.map(|p| {
|
||||
crate::backups::RecoveryKey::from_pickle(p, self.get_pickle_key())
|
||||
.map_err(|_| CryptoStoreError::UnpicklingError)
|
||||
})
|
||||
.transpose()?
|
||||
};
|
||||
|
||||
#[cfg(not(feature = "backups_v1"))]
|
||||
let recovery_key = None;
|
||||
|
||||
Ok(BackupKeys { backup_version: version, recovery_key })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -775,11 +775,12 @@ impl TryFrom<OutgoingRequest> for OutgoingContent {
|
||||
|
||||
fn try_from(value: OutgoingRequest) -> Result<Self, Self::Error> {
|
||||
match value.request() {
|
||||
crate::OutgoingRequests::KeysUpload(_) => Err("Invalid request type".to_owned()),
|
||||
crate::OutgoingRequests::KeysQuery(_) => Err("Invalid request type".to_owned()),
|
||||
crate::OutgoingRequests::KeysUpload(_)
|
||||
| crate::OutgoingRequests::KeysQuery(_)
|
||||
| crate::OutgoingRequests::KeysBackup(_)
|
||||
| crate::OutgoingRequests::SignatureUpload(_)
|
||||
| crate::OutgoingRequests::KeysClaim(_) => Err("Invalid request type".to_owned()),
|
||||
crate::OutgoingRequests::ToDeviceRequest(r) => Self::try_from(r.clone()),
|
||||
crate::OutgoingRequests::SignatureUpload(_) => Err("Invalid request type".to_owned()),
|
||||
crate::OutgoingRequests::KeysClaim(_) => Err("Invalid request type".to_owned()),
|
||||
crate::OutgoingRequests::RoomMessage(r) => Ok(Self::from(r.clone())),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -535,7 +535,7 @@ impl IdentitiesBeingVerified {
|
||||
};
|
||||
|
||||
if should_request_secrets {
|
||||
let secret_requests = self.request_missing_secrets().await;
|
||||
let secret_requests = self.request_missing_secrets().await?;
|
||||
changes.key_requests = secret_requests;
|
||||
}
|
||||
|
||||
@@ -547,9 +547,16 @@ impl IdentitiesBeingVerified {
|
||||
.unwrap_or(VerificationResult::Ok))
|
||||
}
|
||||
|
||||
async fn request_missing_secrets(&self) -> Vec<GossipRequest> {
|
||||
let secrets = self.private_identity.get_missing_secrets().await;
|
||||
GossipMachine::request_missing_secrets(self.user_id(), secrets)
|
||||
async fn request_missing_secrets(&self) -> Result<Vec<GossipRequest>, CryptoStoreError> {
|
||||
#[allow(unused_mut)]
|
||||
let mut secrets = self.private_identity.get_missing_secrets().await;
|
||||
|
||||
#[cfg(feature = "backups_v1")]
|
||||
if self.store.inner.load_backup_keys().await?.recovery_key.is_none() {
|
||||
secrets.push(ruma::events::secret::request::SecretName::RecoveryKey);
|
||||
}
|
||||
|
||||
Ok(GossipMachine::request_missing_secrets(self.user_id(), secrets))
|
||||
}
|
||||
|
||||
async fn mark_identity_as_verified(
|
||||
@@ -667,9 +674,6 @@ impl IdentitiesBeingVerified {
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod test {
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
use wasm_bindgen_test::wasm_bindgen_test;
|
||||
use matrix_sdk_test::async_test;
|
||||
use std::convert::TryInto;
|
||||
|
||||
use ruma::{
|
||||
|
||||
@@ -110,7 +110,7 @@ features = ["wasm-bindgen"]
|
||||
[dev-dependencies]
|
||||
anyhow = "1.0"
|
||||
dirs = "3.0.2"
|
||||
futures = { version = "0.3.15", default-features = false }
|
||||
futures = { version = "0.3.15", default-features = false, features = ["executor"] }
|
||||
lazy_static = "1.4.0"
|
||||
matches = "0.1.8"
|
||||
matrix-sdk-test = { version = "0.4.0", path = "../matrix-sdk-test" }
|
||||
|
||||
@@ -26,15 +26,14 @@ use std::{
|
||||
use anymap2::any::CloneAnySendSync;
|
||||
use dashmap::DashMap;
|
||||
use futures_core::stream::Stream;
|
||||
use futures_timer::Delay as sleep;
|
||||
use matrix_sdk_base::{
|
||||
deserialized_responses::{JoinedRoom, LeftRoom, SyncResponse},
|
||||
deserialized_responses::SyncResponse,
|
||||
media::{MediaEventContent, MediaFormat, MediaRequest, MediaThumbnailSize, MediaType},
|
||||
BaseClient, Session, Store,
|
||||
};
|
||||
use matrix_sdk_common::{
|
||||
instant::{Duration, Instant},
|
||||
locks::{Mutex, RwLock},
|
||||
locks::{Mutex, RwLock, RwLockReadGuard},
|
||||
};
|
||||
use mime::{self, Mime};
|
||||
use ruma::{
|
||||
@@ -109,27 +108,31 @@ pub enum LoopCtrl {
|
||||
/// All of the state is held in an `Arc` so the `Client` can be cloned freely.
|
||||
#[derive(Clone)]
|
||||
pub struct Client {
|
||||
pub(crate) inner: Arc<ClientInner>,
|
||||
}
|
||||
|
||||
pub(crate) struct ClientInner {
|
||||
/// The URL of the homeserver to connect to.
|
||||
homeserver: Arc<RwLock<Url>>,
|
||||
/// The underlying HTTP client.
|
||||
http_client: HttpClient,
|
||||
/// User session data.
|
||||
pub(crate) base_client: BaseClient,
|
||||
base_client: BaseClient,
|
||||
/// Locks making sure we only have one group session sharing request in
|
||||
/// flight per room.
|
||||
#[cfg(feature = "encryption")]
|
||||
pub(crate) group_session_locks: Arc<DashMap<RoomId, Arc<Mutex<()>>>>,
|
||||
pub(crate) group_session_locks: DashMap<RoomId, Arc<Mutex<()>>>,
|
||||
#[cfg(feature = "encryption")]
|
||||
/// Lock making sure we're only doing one key claim request at a time.
|
||||
pub(crate) key_claim_lock: Arc<Mutex<()>>,
|
||||
pub(crate) members_request_locks: Arc<DashMap<RoomId, Arc<Mutex<()>>>>,
|
||||
pub(crate) typing_notice_times: Arc<DashMap<RoomId, Instant>>,
|
||||
pub(crate) key_claim_lock: Mutex<()>,
|
||||
pub(crate) members_request_locks: DashMap<RoomId, Arc<Mutex<()>>>,
|
||||
pub(crate) typing_notice_times: DashMap<RoomId, Instant>,
|
||||
/// Event handlers. See `register_event_handler`.
|
||||
pub(crate) event_handlers: Arc<RwLock<EventHandlerMap>>,
|
||||
event_handlers: RwLock<EventHandlerMap>,
|
||||
/// Custom event handler context. See `register_event_handler_context`.
|
||||
pub(crate) event_handler_data: Arc<StdRwLock<AnyMap>>,
|
||||
event_handler_data: StdRwLock<AnyMap>,
|
||||
/// Notification handlers. See `register_notification_handler`.
|
||||
notification_handlers: Arc<RwLock<Vec<NotificationHandlerFn>>>,
|
||||
notification_handlers: RwLock<Vec<NotificationHandlerFn>>,
|
||||
/// Whether the client should operate in application service style mode.
|
||||
/// This is low-level functionality. For an high-level API check the
|
||||
/// `matrix_sdk_appservice` crate.
|
||||
@@ -142,7 +145,7 @@ pub struct Client {
|
||||
/// synchronization, e.g. if we send out a request to create a room, we can
|
||||
/// wait for the sync to get the data to fetch a room object from the state
|
||||
/// store.
|
||||
pub(crate) sync_beat: Arc<event_listener::Event>,
|
||||
pub(crate) sync_beat: event_listener::Event,
|
||||
}
|
||||
|
||||
#[cfg(not(tarpaulin_include))]
|
||||
@@ -186,7 +189,7 @@ impl Client {
|
||||
let http_client =
|
||||
HttpClient::new(client, homeserver.clone(), session, config.request_config);
|
||||
|
||||
Ok(Self {
|
||||
let inner = Arc::new(ClientInner {
|
||||
homeserver,
|
||||
http_client,
|
||||
base_client,
|
||||
@@ -201,8 +204,10 @@ impl Client {
|
||||
notification_handlers: Default::default(),
|
||||
appservice_mode: config.appservice_mode,
|
||||
use_discovery_response: config.use_discovery_response,
|
||||
sync_beat: event_listener::Event::new().into(),
|
||||
})
|
||||
sync_beat: event_listener::Event::new(),
|
||||
});
|
||||
|
||||
Ok(Self { inner })
|
||||
}
|
||||
|
||||
/// Create a new [`Client`] using homeserver auto discovery.
|
||||
@@ -226,12 +231,12 @@ impl Client {
|
||||
/// // First let's try to construct an user id, presumably from user input.
|
||||
/// let alice = UserId::try_from("@alice:example.org")?;
|
||||
///
|
||||
/// // Now let's try to discover the homeserver and create client object.
|
||||
/// // Now let's try to discover the homeserver and create a client object.
|
||||
/// let client = Client::new_from_user_id(&alice).await?;
|
||||
///
|
||||
/// // Finally let's try to login.
|
||||
/// client.login(alice, "password", None, None).await?;
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
/// ```
|
||||
///
|
||||
/// [spec]: https://spec.matrix.org/unstable/client-server-api/#well-known-uri
|
||||
@@ -268,6 +273,24 @@ impl Client {
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
pub(crate) fn base_client(&self) -> &BaseClient {
|
||||
&self.inner.base_client
|
||||
}
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
pub(crate) async fn olm_machine(&self) -> Option<matrix_sdk_base::crypto::OlmMachine> {
|
||||
self.base_client().olm_machine().await
|
||||
}
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
pub(crate) async fn mark_request_as_sent(
|
||||
&self,
|
||||
request_id: &matrix_sdk_base::uuid::Uuid,
|
||||
response: impl Into<matrix_sdk_base::crypto::IncomingResponse<'_>>,
|
||||
) -> Result<(), matrix_sdk_base::Error> {
|
||||
self.base_client().mark_request_as_sent(request_id, response).await
|
||||
}
|
||||
|
||||
fn homeserver_from_user_id(user_id: &UserId) -> Result<Url> {
|
||||
let homeserver = format!("https://{}", user_id.server_name());
|
||||
#[allow(unused_mut)]
|
||||
@@ -290,7 +313,7 @@ impl Client {
|
||||
///
|
||||
/// * `homeserver_url` - The new URL to use.
|
||||
pub async fn set_homeserver(&self, homeserver_url: Url) {
|
||||
let mut homeserver = self.homeserver.write().await;
|
||||
let mut homeserver = self.inner.homeserver.write().await;
|
||||
*homeserver = homeserver_url;
|
||||
}
|
||||
|
||||
@@ -324,23 +347,23 @@ impl Client {
|
||||
|
||||
/// Is the client logged in.
|
||||
pub async fn logged_in(&self) -> bool {
|
||||
self.base_client.logged_in().await
|
||||
self.inner.base_client.logged_in().await
|
||||
}
|
||||
|
||||
/// The Homeserver of the client.
|
||||
pub async fn homeserver(&self) -> Url {
|
||||
self.homeserver.read().await.clone()
|
||||
self.inner.homeserver.read().await.clone()
|
||||
}
|
||||
|
||||
/// Get the user id of the current owner of the client.
|
||||
pub async fn user_id(&self) -> Option<UserId> {
|
||||
let session = self.base_client.session().read().await;
|
||||
let session = self.inner.base_client.session().read().await;
|
||||
session.as_ref().cloned().map(|s| s.user_id)
|
||||
}
|
||||
|
||||
/// Get the device id that identifies the current session.
|
||||
pub async fn device_id(&self) -> Option<DeviceIdBox> {
|
||||
let session = self.base_client.session().read().await;
|
||||
let session = self.inner.base_client.session().read().await;
|
||||
session.as_ref().map(|s| s.device_id.clone())
|
||||
}
|
||||
|
||||
@@ -351,7 +374,7 @@ impl Client {
|
||||
/// Can be used with [`Client::restore_login`] to restore a previously
|
||||
/// logged in session.
|
||||
pub async fn session(&self) -> Option<Session> {
|
||||
self.base_client.session().read().await.clone()
|
||||
self.inner.base_client.session().read().await.clone()
|
||||
}
|
||||
|
||||
/// Fetches the display name of the owner of the client.
|
||||
@@ -469,7 +492,7 @@ impl Client {
|
||||
|
||||
/// Get a reference to the store.
|
||||
pub fn store(&self) -> &Store {
|
||||
self.base_client.store()
|
||||
self.inner.base_client.store()
|
||||
}
|
||||
|
||||
/// Sets the mxc avatar url of the client's owner. The avatar gets unset if
|
||||
@@ -611,32 +634,41 @@ impl Client {
|
||||
<H::Future as Future>::Output: EventHandlerResult,
|
||||
{
|
||||
let event_type = H::ID.1;
|
||||
self.event_handlers.write().await.entry(H::ID).or_default().push(Box::new(move |data| {
|
||||
let maybe_fut = serde_json::from_str(data.raw.get())
|
||||
.map(|ev| handler.clone().handle_event(ev, data));
|
||||
self.inner.event_handlers.write().await.entry(H::ID).or_default().push(Box::new(
|
||||
move |data| {
|
||||
let maybe_fut = serde_json::from_str(data.raw.get())
|
||||
.map(|ev| handler.clone().handle_event(ev, data));
|
||||
|
||||
Box::pin(async move {
|
||||
match maybe_fut {
|
||||
Ok(Some(fut)) => {
|
||||
fut.await.print_error(event_type);
|
||||
}
|
||||
Ok(None) => {
|
||||
error!("Event handler for {} has an invalid context argument", event_type);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Failed to deserialize `{}` event, skipping event handler.\n\
|
||||
Box::pin(async move {
|
||||
match maybe_fut {
|
||||
Ok(Some(fut)) => {
|
||||
fut.await.print_error(event_type);
|
||||
}
|
||||
Ok(None) => {
|
||||
error!(
|
||||
"Event handler for {} has an invalid context argument",
|
||||
event_type
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Failed to deserialize `{}` event, skipping event handler.\n\
|
||||
Deserialization error: {}",
|
||||
event_type, e,
|
||||
);
|
||||
event_type, e,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}));
|
||||
})
|
||||
},
|
||||
));
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) async fn event_handlers(&self) -> RwLockReadGuard<'_, EventHandlerMap> {
|
||||
self.inner.event_handlers.read().await
|
||||
}
|
||||
|
||||
/// Add an arbitrary value for use as event handler context.
|
||||
///
|
||||
/// The value can be obtained in an event handler by adding an argument of
|
||||
@@ -678,10 +710,18 @@ impl Client {
|
||||
where
|
||||
T: Clone + Send + Sync + 'static,
|
||||
{
|
||||
self.event_handler_data.write().unwrap().insert(ctx);
|
||||
self.inner.event_handler_data.write().unwrap().insert(ctx);
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn event_handler_context<T>(&self) -> Option<T>
|
||||
where
|
||||
T: Clone + Send + Sync + 'static,
|
||||
{
|
||||
let map = self.inner.event_handler_data.read().unwrap();
|
||||
map.get::<T>().cloned()
|
||||
}
|
||||
|
||||
/// Register a handler for a notification.
|
||||
///
|
||||
/// Similar to [`Client::register_event_handler`], but only allows functions
|
||||
@@ -692,13 +732,19 @@ impl Client {
|
||||
H: Fn(Notification, room::Room, Client) -> Fut + Send + Sync + 'static,
|
||||
Fut: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
self.notification_handlers.write().await.push(Box::new(
|
||||
self.inner.notification_handlers.write().await.push(Box::new(
|
||||
move |notification, room, client| Box::pin((handler)(notification, room, client)),
|
||||
));
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) async fn notification_handlers(
|
||||
&self,
|
||||
) -> RwLockReadGuard<'_, Vec<NotificationHandlerFn>> {
|
||||
self.inner.notification_handlers.read().await
|
||||
}
|
||||
|
||||
/// Get all the rooms the client knows about.
|
||||
///
|
||||
/// This will return the list of joined, invited, and left rooms.
|
||||
@@ -849,7 +895,7 @@ impl Client {
|
||||
/// "Logged in as {}, got device_id {} and access_token {}",
|
||||
/// user, response.device_id, response.access_token
|
||||
/// );
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
/// ```
|
||||
///
|
||||
/// [`restore_login`]: #method.restore_login
|
||||
@@ -1166,7 +1212,7 @@ impl Client {
|
||||
///
|
||||
/// * `response` - A successful login response.
|
||||
async fn receive_login_response(&self, response: &login::Response) -> Result<()> {
|
||||
if self.use_discovery_response {
|
||||
if self.inner.use_discovery_response {
|
||||
if let Some(well_known) = &response.well_known {
|
||||
if let Ok(homeserver) = Url::parse(&well_known.homeserver.base_url) {
|
||||
self.set_homeserver(homeserver).await;
|
||||
@@ -1174,7 +1220,7 @@ impl Client {
|
||||
}
|
||||
}
|
||||
|
||||
self.base_client.receive_login_response(response).await?;
|
||||
self.inner.base_client.receive_login_response(response).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1192,9 +1238,52 @@ impl Client {
|
||||
/// * `session` - A session that the user already has from a
|
||||
/// previous login call.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```no_run
|
||||
/// use matrix_sdk::{Client, Session, ruma::{DeviceIdBox, user_id}};
|
||||
/// # use url::Url;
|
||||
/// # use futures::executor::block_on;
|
||||
/// # block_on(async {
|
||||
///
|
||||
/// let homeserver = Url::parse("http://example.com")?;
|
||||
/// let client = Client::new(homeserver).await?;
|
||||
///
|
||||
/// let session = Session {
|
||||
/// access_token: "My-Token".to_owned(),
|
||||
/// user_id: user_id!("@example:localhost"),
|
||||
/// device_id: DeviceIdBox::from("MYDEVICEID"),
|
||||
/// };
|
||||
///
|
||||
/// client.restore_login(session).await?;
|
||||
/// # anyhow::Result::<()>::Ok(()) });
|
||||
/// ```
|
||||
///
|
||||
/// The `Session` object can also be created from the response the
|
||||
/// [`Client::login()`] method returns:
|
||||
///
|
||||
/// ```no_run
|
||||
/// use matrix_sdk::{Client, Session, ruma::{DeviceIdBox, user_id}};
|
||||
/// # use url::Url;
|
||||
/// # use futures::executor::block_on;
|
||||
/// # block_on(async {
|
||||
///
|
||||
/// let homeserver = Url::parse("http://example.com")?;
|
||||
/// let client = Client::new(homeserver).await?;
|
||||
///
|
||||
/// let session: Session = client
|
||||
/// .login("example", "my-password", None, None)
|
||||
/// .await?
|
||||
/// .into();
|
||||
///
|
||||
/// // Persist the `Session` so it can later be used to restore the login.
|
||||
/// client.restore_login(session).await?;
|
||||
/// # anyhow::Result::<()>::Ok(()) });
|
||||
/// ```
|
||||
///
|
||||
/// [`login`]: #method.login
|
||||
pub async fn restore_login(&self, session: Session) -> Result<()> {
|
||||
Ok(self.base_client.restore_login(session).await?)
|
||||
Ok(self.inner.base_client.restore_login(session).await?)
|
||||
}
|
||||
|
||||
/// Register a user to the server.
|
||||
@@ -1241,8 +1330,8 @@ impl Client {
|
||||
let homeserver = self.homeserver().await;
|
||||
info!("Registering to {}", homeserver);
|
||||
|
||||
let config = if self.appservice_mode {
|
||||
Some(self.http_client.request_config.force_auth())
|
||||
let config = if self.inner.appservice_mode {
|
||||
Some(self.inner.http_client.request_config.force_auth())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -1307,14 +1396,14 @@ impl Client {
|
||||
filter_name: &str,
|
||||
definition: FilterDefinition<'_>,
|
||||
) -> Result<String> {
|
||||
if let Some(filter) = self.base_client.get_filter(filter_name).await? {
|
||||
if let Some(filter) = self.inner.base_client.get_filter(filter_name).await? {
|
||||
Ok(filter)
|
||||
} else {
|
||||
let user_id = self.user_id().await.ok_or(Error::AuthenticationRequired)?;
|
||||
let request = FilterUploadRequest::new(&user_id, definition);
|
||||
let response = self.send(request, None).await?;
|
||||
|
||||
self.base_client.receive_filter_upload(filter_name, &response).await?;
|
||||
self.inner.base_client.receive_filter_upload(filter_name, &response).await?;
|
||||
|
||||
Ok(response.filter_id)
|
||||
}
|
||||
@@ -1468,7 +1557,7 @@ impl Client {
|
||||
/// for room in response.chunk {
|
||||
/// println!("Found room {:?}", room);
|
||||
/// }
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
/// ```
|
||||
pub async fn public_rooms_filtered(
|
||||
&self,
|
||||
@@ -1526,8 +1615,8 @@ impl Client {
|
||||
content_type: Some(content_type.essence_str()),
|
||||
});
|
||||
|
||||
let request_config = self.http_client.request_config.timeout(timeout);
|
||||
Ok(self.http_client.upload(request, Some(request_config)).await?)
|
||||
let request_config = self.inner.http_client.request_config.timeout(timeout);
|
||||
Ok(self.inner.http_client.upload(request, Some(request_config)).await?)
|
||||
}
|
||||
|
||||
/// Send an arbitrary request to the server, without updating client state.
|
||||
@@ -1568,7 +1657,7 @@ impl Client {
|
||||
///
|
||||
/// // Check the corresponding Response struct to find out what types are
|
||||
/// // returned
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
/// ```
|
||||
pub async fn send<Request>(
|
||||
&self,
|
||||
@@ -1579,7 +1668,7 @@ impl Client {
|
||||
Request: OutgoingRequest + Debug,
|
||||
HttpError: From<FromHttpResponseError<Request::EndpointError>>,
|
||||
{
|
||||
Ok(self.http_client.send(request, config).await?)
|
||||
Ok(self.inner.http_client.send(request, config).await?)
|
||||
}
|
||||
|
||||
/// Get information of all our own devices.
|
||||
@@ -1603,7 +1692,7 @@ impl Client {
|
||||
/// device.display_name.as_deref().unwrap_or("")
|
||||
/// );
|
||||
/// }
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
/// ```
|
||||
pub async fn devices(&self) -> HttpResult<get_devices::Response> {
|
||||
let request = get_devices::Request::new();
|
||||
@@ -1656,7 +1745,7 @@ impl Client {
|
||||
/// .await?;
|
||||
/// }
|
||||
/// }
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
pub async fn delete_devices(
|
||||
&self,
|
||||
devices: &[DeviceIdBox],
|
||||
@@ -1668,135 +1757,6 @@ impl Client {
|
||||
self.send(request, None).await
|
||||
}
|
||||
|
||||
pub(crate) async fn process_sync(
|
||||
&self,
|
||||
response: sync_events::Response,
|
||||
) -> Result<SyncResponse> {
|
||||
let response = self.base_client.receive_sync_response(response).await?;
|
||||
let SyncResponse {
|
||||
next_batch: _,
|
||||
rooms,
|
||||
presence,
|
||||
account_data,
|
||||
to_device: _,
|
||||
device_lists: _,
|
||||
device_one_time_keys_count: _,
|
||||
ambiguity_changes: _,
|
||||
notifications,
|
||||
} = &response;
|
||||
|
||||
self.handle_sync_events(EventKind::GlobalAccountData, &None, &account_data.events).await?;
|
||||
self.handle_sync_events(EventKind::Presence, &None, &presence.events).await?;
|
||||
|
||||
for (room_id, room_info) in &rooms.join {
|
||||
let room = self.get_room(room_id);
|
||||
if room.is_none() {
|
||||
error!("Can't call event handler, room {} not found", room_id);
|
||||
continue;
|
||||
}
|
||||
|
||||
let JoinedRoom { unread_notifications: _, timeline, state, account_data, ephemeral } =
|
||||
room_info;
|
||||
|
||||
self.handle_sync_events(EventKind::EphemeralRoomData, &room, &ephemeral.events).await?;
|
||||
self.handle_sync_events(EventKind::RoomAccountData, &room, &account_data.events)
|
||||
.await?;
|
||||
self.handle_sync_state_events(&room, &state.events).await?;
|
||||
self.handle_sync_timeline_events(&room, &timeline.events).await?;
|
||||
}
|
||||
|
||||
for (room_id, room_info) in &rooms.leave {
|
||||
let room = self.get_room(room_id);
|
||||
if room.is_none() {
|
||||
error!("Can't call event handler, room {} not found", room_id);
|
||||
continue;
|
||||
}
|
||||
|
||||
let LeftRoom { timeline, state, account_data } = room_info;
|
||||
|
||||
self.handle_sync_events(EventKind::RoomAccountData, &room, &account_data.events)
|
||||
.await?;
|
||||
self.handle_sync_state_events(&room, &state.events).await?;
|
||||
self.handle_sync_timeline_events(&room, &timeline.events).await?;
|
||||
}
|
||||
|
||||
for (room_id, room_info) in &rooms.invite {
|
||||
let room = self.get_room(room_id);
|
||||
if room.is_none() {
|
||||
error!("Can't call event handler, room {} not found", room_id);
|
||||
continue;
|
||||
}
|
||||
|
||||
// FIXME: Destructure room_info
|
||||
self.handle_sync_events(
|
||||
EventKind::StrippedState,
|
||||
&room,
|
||||
&room_info.invite_state.events,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Construct notification event handler futures
|
||||
let mut futures = Vec::new();
|
||||
for handler in &*self.notification_handlers.read().await {
|
||||
for (room_id, room_notifications) in notifications {
|
||||
let room = match self.get_room(room_id) {
|
||||
Some(room) => room,
|
||||
None => {
|
||||
warn!("Can't call notification handler, room {} not found", room_id);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
futures.extend(room_notifications.iter().map(|notification| {
|
||||
(handler)(notification.clone(), room.clone(), self.clone())
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// Run the notification handler futures with the
|
||||
// `self.notification_handlers` lock no longer being held, in order.
|
||||
for fut in futures {
|
||||
fut.await;
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn sync_loop_helper(
|
||||
&self,
|
||||
sync_settings: &mut crate::config::SyncSettings<'_>,
|
||||
) -> Result<SyncResponse> {
|
||||
let response = self.sync_once(sync_settings.clone()).await;
|
||||
|
||||
match response {
|
||||
Ok(r) => {
|
||||
sync_settings.token = Some(r.next_batch.clone());
|
||||
Ok(r)
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Received an invalid response: {}", e);
|
||||
sleep::new(Duration::from_secs(1)).await;
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn delay_sync(last_sync_time: &mut Option<Instant>) {
|
||||
let now = Instant::now();
|
||||
|
||||
// If the last sync happened less than a second ago, sleep for a
|
||||
// while to not hammer out requests if the server doesn't respect
|
||||
// the sync timeout.
|
||||
if let Some(t) = last_sync_time {
|
||||
if now - *t <= Duration::from_secs(1) {
|
||||
sleep::new(Duration::from_secs(1)).await;
|
||||
}
|
||||
}
|
||||
|
||||
*last_sync_time = Some(now);
|
||||
}
|
||||
|
||||
/// Synchronize the client's state with the latest state on the server.
|
||||
///
|
||||
/// ## Syncing Events
|
||||
@@ -1874,7 +1834,7 @@ impl Client {
|
||||
/// // Now keep on syncing forever. `sync()` will use the stored sync token
|
||||
/// // from our `sync_once()` call automatically.
|
||||
/// client.sync(SyncSettings::default()).await;
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
/// ```
|
||||
///
|
||||
/// [`sync`]: #method.sync
|
||||
@@ -1901,9 +1861,9 @@ impl Client {
|
||||
timeout: sync_settings.timeout,
|
||||
});
|
||||
|
||||
let request_config = self.http_client.request_config.timeout(
|
||||
let request_config = self.inner.http_client.request_config.timeout(
|
||||
sync_settings.timeout.unwrap_or_else(|| Duration::from_secs(0))
|
||||
+ self.http_client.request_config.timeout,
|
||||
+ self.inner.http_client.request_config.timeout,
|
||||
);
|
||||
|
||||
let response = self.send(request, Some(request_config)).await?;
|
||||
@@ -1914,7 +1874,7 @@ impl Client {
|
||||
error!(error =? e, "Error while sending outgoing E2EE requests");
|
||||
};
|
||||
|
||||
self.sync_beat.notify(usize::MAX);
|
||||
self.inner.sync_beat.notify(usize::MAX);
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
@@ -1931,7 +1891,7 @@ impl Client {
|
||||
/// method to react to individual events. If you instead wish to handle
|
||||
/// events in a bulk manner the [`Client::sync_with_callback`] and
|
||||
/// [`Client::sync_stream`] methods can be used instead. Those two methods
|
||||
/// repeadetly return the whole sync response.
|
||||
/// repeatedly return the whole sync response.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -1965,7 +1925,7 @@ impl Client {
|
||||
/// // Now keep on syncing forever. `sync()` will use the latest sync token
|
||||
/// // automatically.
|
||||
/// client.sync(SyncSettings::default()).await;
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
/// ```
|
||||
///
|
||||
/// [argument docs]: #method.sync_once
|
||||
@@ -2094,7 +2054,7 @@ impl Client {
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
/// ```
|
||||
#[instrument]
|
||||
pub async fn sync_stream<'a>(
|
||||
@@ -2119,7 +2079,7 @@ impl Client {
|
||||
/// Get the current, if any, sync token of the client.
|
||||
/// This will be None if the client didn't sync at least once.
|
||||
pub async fn sync_token(&self) -> Option<String> {
|
||||
self.base_client.sync_token().await
|
||||
self.inner.base_client.sync_token().await
|
||||
}
|
||||
|
||||
/// Get a media file's content.
|
||||
@@ -2138,7 +2098,7 @@ impl Client {
|
||||
use_cache: bool,
|
||||
) -> Result<Vec<u8>> {
|
||||
let content = if use_cache {
|
||||
self.base_client.store().get_media_content(request).await?
|
||||
self.inner.base_client.store().get_media_content(request).await?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -2182,7 +2142,7 @@ impl Client {
|
||||
};
|
||||
|
||||
if use_cache {
|
||||
self.base_client.store().add_media_content(request, content.clone()).await?;
|
||||
self.inner.base_client.store().add_media_content(request, content.clone()).await?;
|
||||
}
|
||||
|
||||
Ok(content)
|
||||
@@ -2195,7 +2155,7 @@ impl Client {
|
||||
///
|
||||
/// * `request` - The `MediaRequest` of the content.
|
||||
pub async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
|
||||
Ok(self.base_client.store().remove_media_content(request).await?)
|
||||
Ok(self.inner.base_client.store().remove_media_content(request).await?)
|
||||
}
|
||||
|
||||
/// Delete all the media content corresponding to the given
|
||||
@@ -2205,7 +2165,7 @@ impl Client {
|
||||
///
|
||||
/// * `uri` - The `MxcUri` of the files.
|
||||
pub async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
|
||||
Ok(self.base_client.store().remove_media_content_for_uri(uri).await?)
|
||||
Ok(self.inner.base_client.store().remove_media_content_for_uri(uri).await?)
|
||||
}
|
||||
|
||||
/// Get the file of the given media event content.
|
||||
@@ -2724,7 +2684,7 @@ pub(crate) mod test {
|
||||
.add_state_event(EventsJson::PowerLevels)
|
||||
.build_sync_response();
|
||||
|
||||
client.base_client.receive_sync_response(response).await.unwrap();
|
||||
client.inner.base_client.receive_sync_response(response).await.unwrap();
|
||||
let room_id = room_id!("!SVkFJHzfwvuaIEawgC:localhost");
|
||||
|
||||
assert_eq!(client.homeserver().await, Url::parse(&mockito::server_url()).unwrap());
|
||||
|
||||
@@ -90,7 +90,7 @@ impl ClientConfig {
|
||||
/// let client_config = ClientConfig::new()
|
||||
/// .proxy("http://localhost:8080")?;
|
||||
///
|
||||
/// # matrix_sdk::Result::Ok(())
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(())
|
||||
/// ```
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub fn proxy(mut self, proxy: &str) -> Result<Self> {
|
||||
@@ -105,7 +105,7 @@ impl ClientConfig {
|
||||
}
|
||||
|
||||
/// Set a custom HTTP user agent for the client.
|
||||
pub fn user_agent(mut self, user_agent: &str) -> std::result::Result<Self, InvalidHeaderValue> {
|
||||
pub fn user_agent(mut self, user_agent: &str) -> Result<Self, InvalidHeaderValue> {
|
||||
self.user_agent = Some(HeaderValue::from_str(user_agent)?);
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{ops::Deref, result::Result as StdResult};
|
||||
use std::ops::Deref;
|
||||
|
||||
use matrix_sdk_base::crypto::{
|
||||
store::CryptoStoreError, Device as BaseDevice, LocalTrust, ReadOnlyDevice,
|
||||
@@ -250,7 +250,7 @@ impl Device {
|
||||
/// }
|
||||
/// # anyhow::Result::<()>::Ok(()) });
|
||||
/// ```
|
||||
pub async fn verify(&self) -> std::result::Result<(), ManualVerifyError> {
|
||||
pub async fn verify(&self) -> Result<(), ManualVerifyError> {
|
||||
let request = self.inner.verify().await?;
|
||||
self.client.send(request, None).await?;
|
||||
|
||||
@@ -391,10 +391,7 @@ impl Device {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `trust_state` - The new trust state that should be set for the device.
|
||||
pub async fn set_local_trust(
|
||||
&self,
|
||||
trust_state: LocalTrust,
|
||||
) -> StdResult<(), CryptoStoreError> {
|
||||
pub async fn set_local_trust(&self, trust_state: LocalTrust) -> Result<(), CryptoStoreError> {
|
||||
self.inner.set_local_trust(trust_state).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{result::Result, sync::Arc};
|
||||
use std::sync::Arc;
|
||||
|
||||
use matrix_sdk_base::{
|
||||
crypto::{
|
||||
|
||||
@@ -266,6 +266,7 @@ use matrix_sdk_base::{
|
||||
use matrix_sdk_common::{instant::Duration, uuid::Uuid};
|
||||
use ruma::{
|
||||
api::client::r0::{
|
||||
backup::add_backup_keys::Response as KeysBackupResponse,
|
||||
keys::{get_keys, upload_keys, upload_signing_keys::Request as UploadSigningKeysRequest},
|
||||
message::send_message_event,
|
||||
to_device::send_event_to_device::{
|
||||
@@ -294,7 +295,7 @@ impl Client {
|
||||
/// called the fingerprint of the device.
|
||||
#[cfg(feature = "encryption")]
|
||||
pub async fn ed25519_key(&self) -> Option<String> {
|
||||
self.base_client.olm_machine().await.map(|o| o.identity_keys().ed25519().to_owned())
|
||||
self.olm_machine().await.map(|o| o.identity_keys().ed25519().to_owned())
|
||||
}
|
||||
|
||||
/// Get the status of the private cross signing keys.
|
||||
@@ -303,7 +304,7 @@ impl Client {
|
||||
/// stored locally.
|
||||
#[cfg(feature = "encryption")]
|
||||
pub async fn cross_signing_status(&self) -> Option<CrossSigningStatus> {
|
||||
if let Some(machine) = self.base_client.olm_machine().await {
|
||||
if let Some(machine) = self.olm_machine().await {
|
||||
Some(machine.cross_signing_status().await)
|
||||
} else {
|
||||
None
|
||||
@@ -316,13 +317,13 @@ impl Client {
|
||||
/// capable devices up to date.
|
||||
#[cfg(feature = "encryption")]
|
||||
pub async fn tracked_users(&self) -> HashSet<UserId> {
|
||||
self.base_client.olm_machine().await.map(|o| o.tracked_users()).unwrap_or_default()
|
||||
self.olm_machine().await.map(|o| o.tracked_users()).unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get a verification object with the given flow id.
|
||||
#[cfg(feature = "encryption")]
|
||||
pub async fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
|
||||
let olm = self.base_client.olm_machine().await?;
|
||||
let olm = self.olm_machine().await?;
|
||||
olm.get_verification(user_id, flow_id).map(|v| match v {
|
||||
matrix_sdk_base::crypto::Verification::SasV1(s) => {
|
||||
SasVerification { inner: s, client: self.clone() }.into()
|
||||
@@ -342,7 +343,7 @@ impl Client {
|
||||
user_id: &UserId,
|
||||
flow_id: impl AsRef<str>,
|
||||
) -> Option<VerificationRequest> {
|
||||
let olm = self.base_client.olm_machine().await?;
|
||||
let olm = self.olm_machine().await?;
|
||||
|
||||
olm.get_verification_request(user_id, flow_id)
|
||||
.map(|r| VerificationRequest { inner: r, client: self.clone() })
|
||||
@@ -387,7 +388,7 @@ impl Client {
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
) -> StdResult<Option<Device>, CryptoStoreError> {
|
||||
let device = self.base_client.get_device(user_id, device_id).await?;
|
||||
let device = self.base_client().get_device(user_id, device_id).await?;
|
||||
|
||||
Ok(device.map(|d| Device { inner: d, client: self.clone() }))
|
||||
}
|
||||
@@ -424,7 +425,7 @@ impl Client {
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
) -> StdResult<UserDevices, CryptoStoreError> {
|
||||
let devices = self.base_client.get_user_devices(user_id).await?;
|
||||
let devices = self.base_client().get_user_devices(user_id).await?;
|
||||
|
||||
Ok(UserDevices { inner: devices, client: self.clone() })
|
||||
}
|
||||
@@ -467,7 +468,7 @@ impl Client {
|
||||
) -> StdResult<Option<crate::encryption::identities::UserIdentity>, CryptoStoreError> {
|
||||
use crate::encryption::identities::UserIdentity;
|
||||
|
||||
if let Some(olm) = self.base_client.olm_machine().await {
|
||||
if let Some(olm) = self.olm_machine().await {
|
||||
let identity = olm.get_identity(user_id).await?;
|
||||
|
||||
Ok(identity.map(|i| match i {
|
||||
@@ -525,7 +526,7 @@ impl Client {
|
||||
/// # anyhow::Result::<()>::Ok(()) });
|
||||
#[cfg(feature = "encryption")]
|
||||
pub async fn bootstrap_cross_signing(&self, auth_data: Option<AuthData<'_>>) -> Result<()> {
|
||||
let olm = self.base_client.olm_machine().await.ok_or(Error::AuthenticationRequired)?;
|
||||
let olm = self.olm_machine().await.ok_or(Error::AuthenticationRequired)?;
|
||||
|
||||
let (request, signature_request) = olm.bootstrap_cross_signing(false).await?;
|
||||
|
||||
@@ -600,7 +601,7 @@ impl Client {
|
||||
passphrase: &str,
|
||||
predicate: impl FnMut(&matrix_sdk_base::crypto::olm::InboundGroupSession) -> bool,
|
||||
) -> Result<()> {
|
||||
let olm = self.base_client.olm_machine().await.ok_or(Error::AuthenticationRequired)?;
|
||||
let olm = self.olm_machine().await.ok_or(Error::AuthenticationRequired)?;
|
||||
|
||||
let keys = olm.export_keys(predicate).await?;
|
||||
let passphrase = zeroize::Zeroizing::new(passphrase.to_owned());
|
||||
@@ -660,7 +661,7 @@ impl Client {
|
||||
path: PathBuf,
|
||||
passphrase: &str,
|
||||
) -> StdResult<RoomKeyImportResult, RoomKeyImportError> {
|
||||
let olm = self.base_client.olm_machine().await.ok_or(RoomKeyImportError::StoreClosed)?;
|
||||
let olm = self.olm_machine().await.ok_or(RoomKeyImportError::StoreClosed)?;
|
||||
let passphrase = zeroize::Zeroizing::new(passphrase.to_owned());
|
||||
|
||||
let decrypt = move || {
|
||||
@@ -683,7 +684,7 @@ impl Client {
|
||||
) -> serde_json::Result<RoomEvent> {
|
||||
use ruma::serde::JsonObject;
|
||||
|
||||
if let Some(machine) = self.base_client.olm_machine().await {
|
||||
if let Some(machine) = self.olm_machine().await {
|
||||
if let AnyRoomEvent::Message(event) = event {
|
||||
if let AnyMessageEvent::RoomEncrypted(_) = event {
|
||||
let room_id = event.room_id();
|
||||
@@ -726,7 +727,7 @@ impl Client {
|
||||
let request = assign!(get_keys::Request::new(), { device_keys });
|
||||
|
||||
let response = self.send(request, None).await?;
|
||||
self.base_client.mark_request_as_sent(request_id, &response).await?;
|
||||
self.mark_request_as_sent(request_id, &response).await?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
@@ -843,7 +844,7 @@ impl Client {
|
||||
if let Some(room) = self.get_joined_room(&response.room_id) {
|
||||
Ok(Some(room))
|
||||
} else {
|
||||
self.sync_beat.listen().wait_timeout(SYNC_WAIT_TIME);
|
||||
self.inner.sync_beat.listen().wait_timeout(SYNC_WAIT_TIME);
|
||||
Ok(self.get_joined_room(&response.room_id))
|
||||
}
|
||||
}
|
||||
@@ -859,11 +860,11 @@ impl Client {
|
||||
&self,
|
||||
users: impl Iterator<Item = &UserId>,
|
||||
) -> Result<()> {
|
||||
let _lock = self.key_claim_lock.lock().await;
|
||||
let _lock = self.inner.key_claim_lock.lock().await;
|
||||
|
||||
if let Some((request_id, request)) = self.base_client.get_missing_sessions(users).await? {
|
||||
if let Some((request_id, request)) = self.base_client().get_missing_sessions(users).await? {
|
||||
let response = self.send(request, None).await?;
|
||||
self.base_client.mark_request_as_sent(&request_id, &response).await?;
|
||||
self.mark_request_as_sent(&request_id, &response).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -892,7 +893,7 @@ impl Client {
|
||||
);
|
||||
|
||||
let response = self.send(request.clone(), None).await?;
|
||||
self.base_client.mark_request_as_sent(request_id, &response).await?;
|
||||
self.mark_request_as_sent(request_id, &response).await?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
@@ -970,25 +971,41 @@ impl Client {
|
||||
}
|
||||
OutgoingRequests::ToDeviceRequest(request) => {
|
||||
let response = self.send_to_device(request).await?;
|
||||
self.base_client.mark_request_as_sent(r.request_id(), &response).await?;
|
||||
self.mark_request_as_sent(r.request_id(), &response).await?;
|
||||
}
|
||||
OutgoingRequests::SignatureUpload(request) => {
|
||||
let response = self.send(request.clone(), None).await?;
|
||||
self.base_client.mark_request_as_sent(r.request_id(), &response).await?;
|
||||
self.mark_request_as_sent(r.request_id(), &response).await?;
|
||||
}
|
||||
OutgoingRequests::RoomMessage(request) => {
|
||||
let response = self.room_send_helper(request).await?;
|
||||
self.base_client.mark_request_as_sent(r.request_id(), &response).await?;
|
||||
self.mark_request_as_sent(r.request_id(), &response).await?;
|
||||
}
|
||||
OutgoingRequests::KeysClaim(request) => {
|
||||
let response = self.send(request.clone(), None).await?;
|
||||
self.base_client.mark_request_as_sent(r.request_id(), &response).await?;
|
||||
self.mark_request_as_sent(r.request_id(), &response).await?;
|
||||
}
|
||||
OutgoingRequests::KeysBackup(request) => {
|
||||
let response = self.send_backup_request(request).await?;
|
||||
self.mark_request_as_sent(r.request_id(), &response).await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_backup_request(
|
||||
&self,
|
||||
request: &matrix_sdk_base::crypto::KeysBackupRequest,
|
||||
) -> Result<KeysBackupResponse> {
|
||||
let request = ruma::api::client::r0::backup::add_backup_keys::Request::new(
|
||||
&request.version,
|
||||
request.rooms.to_owned(),
|
||||
);
|
||||
|
||||
Ok(self.send(request, None).await?)
|
||||
}
|
||||
|
||||
pub(crate) async fn send_outgoing_requests(&self) -> Result<()> {
|
||||
const MAX_CONCURRENT_REQUESTS: usize = 20;
|
||||
|
||||
@@ -998,7 +1015,7 @@ impl Client {
|
||||
warn!("Error while claiming one-time keys {:?}", e);
|
||||
}
|
||||
|
||||
let outgoing_requests = stream::iter(self.base_client.outgoing_requests().await?)
|
||||
let outgoing_requests = stream::iter(self.base_client().outgoing_requests().await?)
|
||||
.map(|r| self.send_outgoing_request(r));
|
||||
|
||||
let requests = outgoing_requests.buffer_unordered(MAX_CONCURRENT_REQUESTS);
|
||||
|
||||
@@ -71,7 +71,7 @@ impl QrVerification {
|
||||
///
|
||||
/// The [`to_bytes()`](#method.to_bytes) method can be used to instead
|
||||
/// output the raw bytes that should be encoded as a QR code.
|
||||
pub fn to_qr_code(&self) -> std::result::Result<QrCode, EncodingError> {
|
||||
pub fn to_qr_code(&self) -> Result<QrCode, EncodingError> {
|
||||
self.inner.to_qr_code()
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ impl QrVerification {
|
||||
///
|
||||
/// The [`to_qr_code()`](#method.to_qr_code) method can be used to instead
|
||||
/// output a `QrCode` object that can be rendered.
|
||||
pub fn to_bytes(&self) -> std::result::Result<Vec<u8>, EncodingError> {
|
||||
pub fn to_bytes(&self) -> Result<Vec<u8>, EncodingError> {
|
||||
self.inner.to_bytes()
|
||||
}
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ use thiserror::Error;
|
||||
use url::ParseError as UrlParseError;
|
||||
|
||||
/// Result type of the matrix-sdk.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
||||
/// Result type of a pure HTTP request.
|
||||
pub type HttpResult<T> = std::result::Result<T, HttpError>;
|
||||
|
||||
@@ -186,8 +186,7 @@ pub struct Ctx<T>(pub T);
|
||||
|
||||
impl<T: Clone + Send + Sync + 'static> EventHandlerContext for Ctx<T> {
|
||||
fn from_data(data: &EventHandlerData<'_>) -> Option<Self> {
|
||||
let anymap = data.client.event_handler_data.read().unwrap();
|
||||
Some(Ctx(anymap.get::<T>()?.clone()))
|
||||
data.client.event_handler_context::<T>().map(Ctx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -330,8 +329,7 @@ impl Client {
|
||||
|
||||
// Construct event handler futures
|
||||
let futures: Vec<_> = self
|
||||
.event_handlers
|
||||
.read()
|
||||
.event_handlers()
|
||||
.await
|
||||
.get(&event_handler_id)
|
||||
.into_iter()
|
||||
|
||||
@@ -53,8 +53,8 @@ pub mod event_handler;
|
||||
mod http_client;
|
||||
/// High-level room API
|
||||
pub mod room;
|
||||
/// High-level room API
|
||||
mod room_member;
|
||||
mod sync;
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
pub mod encryption;
|
||||
|
||||
@@ -168,14 +168,17 @@ impl Common {
|
||||
|
||||
pub(crate) async fn request_members(&self) -> Result<Option<MembersResponse>> {
|
||||
if let Some(mutex) =
|
||||
self.client.members_request_locks.get(self.inner.room_id()).map(|m| m.clone())
|
||||
self.client.inner.members_request_locks.get(self.inner.room_id()).map(|m| m.clone())
|
||||
{
|
||||
mutex.lock().await;
|
||||
|
||||
Ok(None)
|
||||
} else {
|
||||
let mutex = Arc::new(Mutex::new(()));
|
||||
self.client.members_request_locks.insert(self.inner.room_id().clone(), mutex.clone());
|
||||
self.client
|
||||
.inner
|
||||
.members_request_locks
|
||||
.insert(self.inner.room_id().clone(), mutex.clone());
|
||||
|
||||
let _guard = mutex.lock().await;
|
||||
|
||||
@@ -183,9 +186,9 @@ impl Common {
|
||||
let response = self.client.send(request, None).await?;
|
||||
|
||||
let response =
|
||||
self.client.base_client.receive_members(self.inner.room_id(), &response).await?;
|
||||
self.client.base_client().receive_members(self.inner.room_id(), &response).await?;
|
||||
|
||||
self.client.members_request_locks.remove(self.inner.room_id());
|
||||
self.client.inner.members_request_locks.remove(self.inner.room_id());
|
||||
|
||||
Ok(Some(response))
|
||||
}
|
||||
|
||||
@@ -170,37 +170,39 @@ impl Joined {
|
||||
/// if let Some(room) = client.get_joined_room(&room_id) {
|
||||
/// room.typing_notice(true).await?
|
||||
/// }
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
/// ```
|
||||
pub async fn typing_notice(&self, typing: bool) -> Result<()> {
|
||||
// Only send a request to the homeserver if the old timeout has elapsed
|
||||
// or the typing notice changed state within the
|
||||
// TYPING_NOTICE_TIMEOUT
|
||||
let send =
|
||||
if let Some(typing_time) = self.client.typing_notice_times.get(self.inner.room_id()) {
|
||||
if typing_time.elapsed() > TYPING_NOTICE_RESEND_TIMEOUT {
|
||||
// We always reactivate the typing notice if typing is true or
|
||||
// we may need to deactivate it if it's
|
||||
// currently active if typing is false
|
||||
typing || typing_time.elapsed() <= TYPING_NOTICE_TIMEOUT
|
||||
} else {
|
||||
// Only send a request when we need to deactivate typing
|
||||
!typing
|
||||
}
|
||||
let send = if let Some(typing_time) =
|
||||
self.client.inner.typing_notice_times.get(self.inner.room_id())
|
||||
{
|
||||
if typing_time.elapsed() > TYPING_NOTICE_RESEND_TIMEOUT {
|
||||
// We always reactivate the typing notice if typing is true or
|
||||
// we may need to deactivate it if it's
|
||||
// currently active if typing is false
|
||||
typing || typing_time.elapsed() <= TYPING_NOTICE_TIMEOUT
|
||||
} else {
|
||||
// Typing notice is currently deactivated, therefore, send a request
|
||||
// only when it's about to be activated
|
||||
typing
|
||||
};
|
||||
// Only send a request when we need to deactivate typing
|
||||
!typing
|
||||
}
|
||||
} else {
|
||||
// Typing notice is currently deactivated, therefore, send a request
|
||||
// only when it's about to be activated
|
||||
typing
|
||||
};
|
||||
|
||||
if send {
|
||||
let typing = if typing {
|
||||
self.client
|
||||
.inner
|
||||
.typing_notice_times
|
||||
.insert(self.inner.room_id().clone(), Instant::now());
|
||||
Typing::Yes(TYPING_NOTICE_TIMEOUT)
|
||||
} else {
|
||||
self.client.typing_notice_times.remove(self.inner.room_id());
|
||||
self.client.inner.typing_notice_times.remove(self.inner.room_id());
|
||||
Typing::No
|
||||
};
|
||||
|
||||
@@ -278,7 +280,7 @@ impl Joined {
|
||||
/// if let Some(room) = client.get_joined_room(&room_id) {
|
||||
/// room.enable_encryption().await?
|
||||
/// }
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
/// ```
|
||||
pub async fn enable_encryption(&self) -> Result<()> {
|
||||
use ruma::{
|
||||
@@ -294,7 +296,7 @@ impl Joined {
|
||||
// TODO do we want to return an error here if we time out? This
|
||||
// could be quite useful if someone wants to enable encryption and
|
||||
// send a message right after it's enabled.
|
||||
self.client.sync_beat.listen().wait_timeout(SYNC_WAIT_TIME);
|
||||
self.client.inner.sync_beat.listen().wait_timeout(SYNC_WAIT_TIME);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -311,7 +313,7 @@ impl Joined {
|
||||
// TODO expose this publicly so people can pre-share a group session if
|
||||
// e.g. a user starts to type a message for a room.
|
||||
if let Some(mutex) =
|
||||
self.client.group_session_locks.get(self.inner.room_id()).map(|m| m.clone())
|
||||
self.client.inner.group_session_locks.get(self.inner.room_id()).map(|m| m.clone())
|
||||
{
|
||||
// If a group session share request is already going on,
|
||||
// await the release of the lock.
|
||||
@@ -320,7 +322,10 @@ impl Joined {
|
||||
// Otherwise create a new lock and share the group
|
||||
// session.
|
||||
let mutex = Arc::new(Mutex::new(()));
|
||||
self.client.group_session_locks.insert(self.inner.room_id().clone(), mutex.clone());
|
||||
self.client
|
||||
.inner
|
||||
.group_session_locks
|
||||
.insert(self.inner.room_id().clone(), mutex.clone());
|
||||
|
||||
let _guard = mutex.lock().await;
|
||||
|
||||
@@ -334,13 +339,13 @@ impl Joined {
|
||||
|
||||
let response = self.share_group_session().await;
|
||||
|
||||
self.client.group_session_locks.remove(self.inner.room_id());
|
||||
self.client.inner.group_session_locks.remove(self.inner.room_id());
|
||||
|
||||
// If one of the responses failed invalidate the group
|
||||
// session as using it would end up in undecryptable
|
||||
// messages.
|
||||
if let Err(r) = response {
|
||||
self.client.base_client.invalidate_group_session(self.inner.room_id()).await?;
|
||||
self.client.base_client().invalidate_group_session(self.inner.room_id()).await?;
|
||||
return Err(r);
|
||||
}
|
||||
}
|
||||
@@ -357,12 +362,12 @@ impl Joined {
|
||||
#[cfg(feature = "encryption")]
|
||||
async fn share_group_session(&self) -> Result<()> {
|
||||
let mut requests =
|
||||
self.client.base_client.share_group_session(self.inner.room_id()).await?;
|
||||
self.client.base_client().share_group_session(self.inner.room_id()).await?;
|
||||
|
||||
for request in requests.drain(..) {
|
||||
let response = self.client.send_to_device(&request).await?;
|
||||
|
||||
self.client.base_client.mark_request_as_sent(&request.txn_id, &response).await?;
|
||||
self.client.mark_request_as_sent(&request.txn_id, &response).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -449,7 +454,7 @@ impl Joined {
|
||||
/// if let Some(room) = client.get_joined_room(&room_id) {
|
||||
/// room.send(content, Some(txn_id)).await?;
|
||||
/// }
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
/// ```
|
||||
///
|
||||
/// [`SyncMessageEvent`]: ruma::events::SyncMessageEvent
|
||||
@@ -517,7 +522,7 @@ impl Joined {
|
||||
/// if let Some(room) = client.get_joined_room(&room_id) {
|
||||
/// room.send_raw(content, "m.room.message", None).await?;
|
||||
/// }
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
/// ```
|
||||
///
|
||||
/// [`SyncMessageEvent`]: ruma::events::SyncMessageEvent
|
||||
@@ -554,8 +559,7 @@ impl Joined {
|
||||
|
||||
self.preshare_group_session().await?;
|
||||
|
||||
let olm =
|
||||
self.client.base_client.olm_machine().await.expect("Olm machine wasn't started");
|
||||
let olm = self.client.olm_machine().await.expect("Olm machine wasn't started");
|
||||
|
||||
let encrypted_content =
|
||||
olm.encrypt_raw(self.inner.room_id(), content, event_type).await?;
|
||||
@@ -631,7 +635,7 @@ impl Joined {
|
||||
/// None,
|
||||
/// ).await?;
|
||||
/// }
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
/// ```
|
||||
pub async fn send_attachment<R: Read>(
|
||||
&self,
|
||||
@@ -698,7 +702,7 @@ impl Joined {
|
||||
/// if let Some(room) = client.get_joined_room(&room_id) {
|
||||
/// room.send_state_event(content, "").await?;
|
||||
/// }
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
/// ```
|
||||
pub async fn send_state_event(
|
||||
&self,
|
||||
@@ -791,7 +795,7 @@ impl Joined {
|
||||
/// let reason = Some("Indecent material");
|
||||
/// room.redact(&event_id, reason, None).await?;
|
||||
/// }
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
/// ```
|
||||
pub async fn redact(
|
||||
&self,
|
||||
@@ -834,7 +838,7 @@ impl Joined {
|
||||
///
|
||||
/// room.set_tag("u.work", tag_info ).await?;
|
||||
/// }
|
||||
/// # matrix_sdk::Result::Ok(()) });
|
||||
/// # Result::<_, matrix_sdk::Error>::Ok(()) });
|
||||
/// ```
|
||||
pub async fn set_tag(&self, tag: &str, tag_info: TagInfo) -> HttpResult<create_tag::Response> {
|
||||
let user_id = self.client.user_id().await.ok_or(HttpError::AuthenticationRequired)?;
|
||||
|
||||
144
crates/matrix-sdk/src/sync.rs
Normal file
144
crates/matrix-sdk/src/sync.rs
Normal file
@@ -0,0 +1,144 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use futures_timer::Delay as sleep;
|
||||
use matrix_sdk_base::{
|
||||
deserialized_responses::{JoinedRoom, LeftRoom, SyncResponse},
|
||||
instant::Instant,
|
||||
};
|
||||
use ruma::api::client::r0::sync::sync_events;
|
||||
use tracing::{error, warn};
|
||||
|
||||
use crate::{event_handler::EventKind, Client, Result};
|
||||
|
||||
/// Internal functionality related to getting events from the server
|
||||
/// (`sync_events` endpoint)
|
||||
impl Client {
|
||||
pub(crate) async fn process_sync(
|
||||
&self,
|
||||
response: sync_events::Response,
|
||||
) -> Result<SyncResponse> {
|
||||
let response = self.base_client().receive_sync_response(response).await?;
|
||||
let SyncResponse {
|
||||
next_batch: _,
|
||||
rooms,
|
||||
presence,
|
||||
account_data,
|
||||
to_device: _,
|
||||
device_lists: _,
|
||||
device_one_time_keys_count: _,
|
||||
ambiguity_changes: _,
|
||||
notifications,
|
||||
} = &response;
|
||||
|
||||
self.handle_sync_events(EventKind::GlobalAccountData, &None, &account_data.events).await?;
|
||||
self.handle_sync_events(EventKind::Presence, &None, &presence.events).await?;
|
||||
|
||||
for (room_id, room_info) in &rooms.join {
|
||||
let room = self.get_room(room_id);
|
||||
if room.is_none() {
|
||||
error!("Can't call event handler, room {} not found", room_id);
|
||||
continue;
|
||||
}
|
||||
|
||||
let JoinedRoom { unread_notifications: _, timeline, state, account_data, ephemeral } =
|
||||
room_info;
|
||||
|
||||
self.handle_sync_events(EventKind::EphemeralRoomData, &room, &ephemeral.events).await?;
|
||||
self.handle_sync_events(EventKind::RoomAccountData, &room, &account_data.events)
|
||||
.await?;
|
||||
self.handle_sync_state_events(&room, &state.events).await?;
|
||||
self.handle_sync_timeline_events(&room, &timeline.events).await?;
|
||||
}
|
||||
|
||||
for (room_id, room_info) in &rooms.leave {
|
||||
let room = self.get_room(room_id);
|
||||
if room.is_none() {
|
||||
error!("Can't call event handler, room {} not found", room_id);
|
||||
continue;
|
||||
}
|
||||
|
||||
let LeftRoom { timeline, state, account_data } = room_info;
|
||||
|
||||
self.handle_sync_events(EventKind::RoomAccountData, &room, &account_data.events)
|
||||
.await?;
|
||||
self.handle_sync_state_events(&room, &state.events).await?;
|
||||
self.handle_sync_timeline_events(&room, &timeline.events).await?;
|
||||
}
|
||||
|
||||
for (room_id, room_info) in &rooms.invite {
|
||||
let room = self.get_room(room_id);
|
||||
if room.is_none() {
|
||||
error!("Can't call event handler, room {} not found", room_id);
|
||||
continue;
|
||||
}
|
||||
|
||||
// FIXME: Destructure room_info
|
||||
self.handle_sync_events(
|
||||
EventKind::StrippedState,
|
||||
&room,
|
||||
&room_info.invite_state.events,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Construct notification event handler futures
|
||||
let mut futures = Vec::new();
|
||||
for handler in &*self.notification_handlers().await {
|
||||
for (room_id, room_notifications) in notifications {
|
||||
let room = match self.get_room(room_id) {
|
||||
Some(room) => room,
|
||||
None => {
|
||||
warn!("Can't call notification handler, room {} not found", room_id);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
futures.extend(room_notifications.iter().map(|notification| {
|
||||
(handler)(notification.clone(), room.clone(), self.clone())
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// Run the notification handler futures with the
|
||||
// `self.notification_handlers` lock no longer being held, in order.
|
||||
for fut in futures {
|
||||
fut.await;
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub(crate) async fn sync_loop_helper(
|
||||
&self,
|
||||
sync_settings: &mut crate::config::SyncSettings<'_>,
|
||||
) -> Result<SyncResponse> {
|
||||
let response = self.sync_once(sync_settings.clone()).await;
|
||||
|
||||
match response {
|
||||
Ok(r) => {
|
||||
sync_settings.token = Some(r.next_batch.clone());
|
||||
Ok(r)
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Received an invalid response: {}", e);
|
||||
sleep::new(Duration::from_secs(1)).await;
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn delay_sync(last_sync_time: &mut Option<Instant>) {
|
||||
let now = Instant::now();
|
||||
|
||||
// If the last sync happened less than a second ago, sleep for a
|
||||
// while to not hammer out requests if the server doesn't respect
|
||||
// the sync timeout.
|
||||
if let Some(t) = last_sync_time {
|
||||
if now - *t <= Duration::from_secs(1) {
|
||||
sleep::new(Duration::from_secs(1)).await;
|
||||
}
|
||||
}
|
||||
|
||||
*last_sync_time = Some(now);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user