Rewrite wait_if_user_pending to fix races

This commit is contained in:
Richard van der Hoff
2023-03-07 14:08:35 +00:00
parent 508f1bc976
commit 16871d1dcd
6 changed files with 354 additions and 12 deletions

164
Cargo.lock generated
View File

@@ -233,6 +233,55 @@ dependencies = [
"tokio",
]
[[package]]
name = "async-executor"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17adb73da160dfb475c183343c8cccd80721ea5a605d3eb57125f0a7b7a92d0b"
dependencies = [
"async-lock",
"async-task",
"concurrent-queue",
"fastrand",
"futures-lite",
"slab",
]
[[package]]
name = "async-global-executor"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1b6f5d7df27bd294849f8eec66ecfc63d11814df7a4f5d74168a2394467b776"
dependencies = [
"async-channel",
"async-executor",
"async-io",
"async-lock",
"blocking",
"futures-lite",
"once_cell",
]
[[package]]
name = "async-io"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c374dda1ed3e7d8f0d9ba58715f924862c63eae6849c92d3a18e7fbde9e2794"
dependencies = [
"async-lock",
"autocfg",
"concurrent-queue",
"futures-lite",
"libc",
"log",
"parking",
"polling",
"slab",
"socket2",
"waker-fn",
"windows-sys 0.42.0",
]
[[package]]
name = "async-lock"
version = "2.6.0"
@@ -249,6 +298,51 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "390a110411bbc7c93b77a736cbd694f64cb06dfa2702173f63169d7a1e1b5298"
[[package]]
name = "async-process"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6381ead98388605d0d9ff86371043b5aa922a3905824244de40dc263a14fcba4"
dependencies = [
"async-io",
"async-lock",
"autocfg",
"blocking",
"cfg-if",
"event-listener",
"futures-lite",
"libc",
"signal-hook",
"windows-sys 0.42.0",
]
[[package]]
name = "async-std"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62565bb4402e926b29953c785397c6dc0391b7b446e45008b0049eb43cec6f5d"
dependencies = [
"async-channel",
"async-global-executor",
"async-io",
"async-lock",
"async-process",
"crossbeam-utils",
"futures-channel",
"futures-core",
"futures-io",
"futures-lite",
"gloo-timers",
"kv-log-macro",
"log",
"memchr",
"once_cell",
"pin-project-lite",
"pin-utils",
"slab",
"wasm-bindgen-futures",
]
[[package]]
name = "async-stream"
version = "0.3.3"
@@ -270,6 +364,12 @@ dependencies = [
"syn",
]
[[package]]
name = "async-task"
version = "4.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a40729d2133846d9ed0ea60a8b9541bccddab49cd30f0715a1da672fe9a2524"
[[package]]
name = "async-trait"
version = "0.1.64"
@@ -290,6 +390,12 @@ dependencies = [
"autocfg",
]
[[package]]
name = "atomic-waker"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "debc29dde2e69f9e47506b525f639ed42300fc014a3e007832592448fa8e4599"
[[package]]
name = "atty"
version = "0.2.14"
@@ -482,6 +588,20 @@ dependencies = [
"generic-array",
]
[[package]]
name = "blocking"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c67b173a56acffd6d2326fb7ab938ba0b00a71480e14902b2591c87bc5741e8"
dependencies = [
"async-channel",
"async-lock",
"async-task",
"atomic-waker",
"fastrand",
"futures-lite",
]
[[package]]
name = "bs58"
version = "0.4.0"
@@ -2396,6 +2516,15 @@ version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "984e109462d46ad18314f10e392c286c3d47bce203088a09012de1015b45b737"
[[package]]
name = "kv-log-macro"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0de8b303297635ad57c9f5059fd9cee7a47f8e8daa09df0fcd07dd39fb22977f"
dependencies = [
"log",
]
[[package]]
name = "lazy-regex"
version = "2.4.1"
@@ -2490,6 +2619,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e"
dependencies = [
"cfg-if",
"value-bag",
]
[[package]]
@@ -2730,6 +2860,7 @@ dependencies = [
"aes",
"anyhow",
"assert_matches",
"async-std",
"async-trait",
"atomic",
"base64 0.21.0",
@@ -3804,6 +3935,20 @@ dependencies = [
"miniz_oxide 0.6.2",
]
[[package]]
name = "polling"
version = "2.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22122d5ec4f9fe1b3916419b76be1e80bcb93f618d071d2edf841b137b2a2bd6"
dependencies = [
"autocfg",
"cfg-if",
"libc",
"log",
"wepoll-ffi",
"windows-sys 0.42.0",
]
[[package]]
name = "poly1305"
version = "0.7.2"
@@ -5825,6 +5970,16 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
[[package]]
name = "value-bag"
version = "1.0.0-alpha.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2209b78d1249f7e6f3293657c9779fe31ced465df091bbd433a1cf88e916ec55"
dependencies = [
"ctor",
"version_check",
]
[[package]]
name = "vcpkg"
version = "0.2.15"
@@ -6053,6 +6208,15 @@ version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9193164d4de03a926d909d3bc7c30543cecb35400c02114792c2cae20d5e2dbb"
[[package]]
name = "wepoll-ffi"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d743fdedc5c64377b5fc2bc036b01c7fd642205a0d96356034ae3404d49eb7fb"
dependencies = [
"cc",
]
[[package]]
name = "which"
version = "4.4.0"

View File

@@ -27,6 +27,7 @@ testing = ["dep:http"]
[dependencies]
aes = "0.8.1"
atomic = "0.5.1"
async-std = { version = "1.12.0", features = ["unstable"] }
async-trait = { workspace = true }
base64 = { workspace = true }
bs58 = { version = "0.4.0", optional = true }

View File

@@ -69,6 +69,9 @@ pub(crate) struct KeysQueryListener {
pub(crate) enum UserKeyQueryResult {
/// A query was pending for the user, and the wait for it finished before
/// the timeout elapsed.
WasPending,
/// No query was pending for the user, so there was nothing to wait for.
WasNotPending,
/// A query was pending, but we gave up waiting
TimeoutExpired,
}
impl KeysQueryListener {

View File

@@ -1259,9 +1259,7 @@ impl OlmMachine {
async fn wait_if_user_pending(&self, user_id: &UserId, timeout: Option<Duration>) {
// A `None` timeout means the caller does not want to wait at all, in which
// case this method returns immediately without doing anything.
if let Some(timeout) = timeout {
let listener = self.identity_manager.listen_for_received_queries();
let _ = listener.wait_if_user_pending(timeout, user_id).await;
self.store.wait_if_user_key_query_pending(timeout, user_id).await;
}
}

View File

@@ -176,11 +176,11 @@ impl SessionManager {
let user_devices = if user_devices.is_empty() {
match self
.keys_query_listener
.wait_if_user_pending(Self::KEYS_QUERY_WAIT_TIME, user_id)
.store
.wait_if_user_key_query_pending(Self::KEYS_QUERY_WAIT_TIME, user_id)
.await
{
Ok(WasPending) => self.store.get_readonly_devices_filtered(user_id).await?,
WasPending => self.store.get_readonly_devices_filtered(user_id).await?,
_ => user_devices,
}
} else {
@@ -404,15 +404,22 @@ mod tests {
use matrix_sdk_common::locks::Mutex;
use matrix_sdk_test::{async_test, response_from_file};
use ruma::{
api::{client::keys::claim_keys::v3::Response as KeyClaimResponse, IncomingResponse},
api::{
client::keys::{
claim_keys::v3::Response as KeyClaimResponse,
get_keys::v3::Response as KeysQueryResponse,
},
IncomingResponse,
},
device_id, user_id, DeviceId, UserId,
};
use serde_json::json;
use tracing::info;
use super::SessionManager;
use crate::{
gossiping::GossipMachine,
identities::{KeysQueryListener, ReadOnlyDevice},
identities::{IdentityManager, KeysQueryListener, ReadOnlyDevice},
olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
session_manager::GroupSessionCache,
store::{IntoCryptoStore, MemoryStore, Store},
@@ -528,6 +535,56 @@ mod tests {
assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());
}
#[async_test]
async fn session_creation_waits_for_keys_query() {
// Checks that `get_missing_sessions` for a newly tracked user produces a key
// claim that includes a device which only becomes known via a later
// `/keys/query` response.
let manager = session_manager().await;
// The identity manager is given a clone of the session manager's store; it is
// used below to start key queries and to feed the responses back in.
let identity_manager = IdentityManager::new(
manager.account.user_id.clone(),
manager.account.device_id.clone(),
manager.store.clone(),
);
// start a keys query request. At this point, we are only interested in our own
// devices.
let (key_query_txn_id, key_query_request) =
identity_manager.users_for_key_query().await.unwrap().pop().unwrap();
info!("Initial key query: {:?}", key_query_request);
// now bob turns up, and we start tracking his devices...
let bob = bob_account();
let bob_device = ReadOnlyDevice::from_account(&bob).await;
manager.store.update_tracked_users(iter::once(bob.user_id())).await.unwrap();
// ... and create the future for an attempt to get the missing sessions.
//
// NOTE(review): this future is neither awaited nor spawned here, and Rust
// futures do nothing until polled, so `get_missing_sessions` does not actually
// start (or block) until the `.await` further down, after both key query
// responses have been processed. If the intent is to exercise the waiting
// path, the call probably needs to be spawned as a task — TODO confirm.
let missing_sessions_future = manager.get_missing_sessions(iter::once(bob.user_id()));
// the initial keys query completes (with no devices for our own user), and we
// start another
let response_json = json!({ "device_keys": { manager.account.user_id(): {}}});
let response =
KeysQueryResponse::try_from_http_response(response_from_file(&response_json)).unwrap();
identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();
let (key_query_txn_id, key_query_request) =
identity_manager.users_for_key_query().await.unwrap().pop().unwrap();
info!("Second key query: {:?}", key_query_request);
// that second request completes with info on bob's device
let response_json = json!({ "device_keys": { bob.user_id(): {
bob_device.device_id(): bob_device.as_device_keys()
}}});
let response =
KeysQueryResponse::try_from_http_response(response_from_file(&response_json)).unwrap();
identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();
// the missing_sessions_future should now finally complete, with a claim
// including bob's device
let (_, keys_claim_request) = missing_sessions_future.await.unwrap().unwrap();
//info!("Key claim: {:?}", keys_claim_request);
let bob_key_claims = keys_claim_request.one_time_keys.get(bob.user_id()).unwrap();
assert!(bob_key_claims.contains_key(bob_device.device_id()));
}
// This test doesn't run on macos because we're modifying the session
// creation time so we can get around the UNWEDGING_INTERVAL.
#[async_test]

View File

@@ -42,9 +42,11 @@ use std::{
collections::{HashMap, HashSet},
fmt::Debug,
ops::Deref,
sync::{atomic::AtomicBool, Arc},
sync::{atomic::AtomicBool, Arc, Weak},
time::Duration,
};
use async_std::sync::{Condvar, Mutex as AsyncStdMutex};
use atomic::Ordering;
use dashmap::DashSet;
use matrix_sdk_common::locks::Mutex;
@@ -80,10 +82,12 @@ mod traits;
pub mod integration_tests;
pub use error::{CryptoStoreError, Result};
use matrix_sdk_common::timeout::timeout;
pub use memorystore::MemoryStore;
pub use traits::{CryptoStore, DynCryptoStore, IntoCryptoStore};
pub use crate::gossiping::{GossipRequest, SecretInfo};
use crate::identities::UserKeyQueryResult;
/// A wrapper for our CryptoStore trait object.
///
@@ -98,7 +102,17 @@ pub(crate) struct Store {
inner: Arc<DynCryptoStore>,
verification_machine: VerificationMachine,
tracked_users_cache: Arc<DashSet<OwnedUserId>>,
users_for_key_query: Arc<Mutex<UsersForKeyQuery>>,
/// Record of the users that are waiting for a /keys/query.
//
// This uses an async_std::sync::Mutex rather than a
// matrix_sdk_common::locks::Mutex because it has to match the Condvar (and tokio lacks a
// working Condvar implementation)
users_for_key_query: Arc<AsyncStdMutex<UsersForKeyQuery>>,
// condition variable that is notified each time an update is received for a user.
users_for_key_query_condvar: Arc<Condvar>,
tracked_user_loading_lock: Arc<Mutex<()>>,
tracked_users_loaded: Arc<AtomicBool>,
}
@@ -118,6 +132,12 @@ struct UsersForKeyQuery {
/// The users pending a lookup, together with the sequence number at which
/// they were added to the list
user_map: HashMap<OwnedUserId, InvalidationSequenceNumber>,
/// A list of tasks waiting for key queries to complete.
///
/// We expect this list to remain fairly short, so don't bother partitioning
/// by user.
tasks_awaiting_key_query: Vec<Weak<KeysQueryWaiter>>,
}
// We use wrapping arithmetic for the sequence numbers, to make sure we never
@@ -133,7 +153,11 @@ type InvalidationSequenceNumber = i64;
impl UsersForKeyQuery {
/// Create a new, empty, `UsersForKeyQuery`
fn new() -> Self {
UsersForKeyQuery { next_sequence_number: 0, user_map: HashMap::new() }
UsersForKeyQuery {
next_sequence_number: 0,
user_map: HashMap::new(),
tasks_awaiting_key_query: Vec::new(),
}
}
/// Record a new user that requires a key query
@@ -159,6 +183,24 @@ impl UsersForKeyQuery {
) -> bool {
let last_invalidation = self.user_map.get(user);
// if there were any jobs waiting for this key query to complete, we can flag
// them as completed and remove them from our list.
// we also clear out any tasks that have been cancelled.
self.tasks_awaiting_key_query.retain(|waiter| {
let Some(waiter) = waiter.upgrade() else {
// the KeysQueryWaiter has been dropped, so it probably timed out and the
// caller went away. We can remove it from our list whether or not it's for this
// user.
return false;
};
if waiter.user == user && waiter.sequence_number.wrapping_sub(query_sequence) <= 0 {
waiter.completed.store(true, Ordering::Relaxed);
false
} else {
true
}
});
if let Some(invalidation_sequence) = last_invalidation {
Span::current().record("invalidation_sequence", invalidation_sequence);
if invalidation_sequence.wrapping_sub(query_sequence) > 0 {
@@ -182,6 +224,44 @@ impl UsersForKeyQuery {
let sequence_number = self.next_sequence_number.wrapping_sub(1);
(self.user_map.keys().cloned().collect(), sequence_number)
}
/// Register the caller as waiting on `user`'s key query, if one is pending.
///
/// A user counts as pending when they have an entry in `user_map`. In that
/// case a new `KeysQueryWaiter` is created, carrying the sequence number
/// currently recorded for the user, and an `Arc` to it is returned; its
/// `completed` flag is what the caller should watch. When the user has no
/// entry, nothing is registered and `None` is returned.
fn maybe_register_waiting_task(&mut self, user: &UserId) -> Option<Arc<KeysQueryWaiter>> {
    // No entry for this user means no key query is pending: nothing to wait for.
    let sequence_number = *self.user_map.get(user)?;

    let waiter = Arc::new(KeysQueryWaiter {
        user: user.to_owned(),
        sequence_number,
        completed: AtomicBool::new(false),
    });

    // Only a weak reference is kept in the list; the returned `Arc` is the sole
    // strong reference, so the entry becomes dead once the caller drops it.
    self.tasks_awaiting_key_query.push(Arc::downgrade(&waiter));

    Some(waiter)
}
}
/// Information on a task which is waiting for a `/keys/query` to complete.
#[derive(Debug)]
struct KeysQueryWaiter {
/// The user that we are waiting for
user: OwnedUserId,
/// The sequence number of the last invalidation of the user's device list
/// when we started waiting (i.e., any `/keys/query` result with the same or
/// greater sequence number will satisfy this waiter)
sequence_number: InvalidationSequenceNumber,
/// Whether the `/keys/query` has completed.
///
/// This is only modified whilst holding the mutex on `users_for_key_query`.
completed: AtomicBool,
}
#[derive(Default, Debug)]
@@ -371,7 +451,8 @@ impl Store {
inner: store,
verification_machine,
tracked_users_cache: DashSet::new().into(),
users_for_key_query: Mutex::new(UsersForKeyQuery::new()).into(),
users_for_key_query: AsyncStdMutex::new(UsersForKeyQuery::new()).into(),
users_for_key_query_condvar: Condvar::new().into(),
tracked_users_loaded: AtomicBool::new(false).into(),
tracked_user_loading_lock: Mutex::new(()).into(),
}
@@ -770,6 +851,8 @@ impl Store {
}
}
self.inner.save_tracked_users(&store_updates).await?;
// wake up any tasks that may have been waiting for updates
self.users_for_key_query_condvar.notify_all();
Ok(())
}
@@ -826,6 +909,42 @@ impl Store {
Ok(self.users_for_key_query.lock().await.users_for_key_query())
}
/// Wait for a `/keys/query` response to be received if one is expected for
/// the given user.
///
/// # Arguments
///
/// * `timeout_duration` - The maximum amount of time to wait for the pending
///   query to complete.
/// * `user` - The user whose pending `/keys/query` we want to wait for.
///
/// # Returns
///
/// * `UserKeyQueryResult::WasNotPending` if no query was pending for the
///   user; in that case the method returns without waiting.
/// * `UserKeyQueryResult::WasPending` if a query was pending and it completed
///   before the timeout elapsed.
/// * `UserKeyQueryResult::TimeoutExpired` if a query was pending but the
///   timeout elapsed before it completed.
pub async fn wait_if_user_key_query_pending(
    &self,
    timeout_duration: Duration,
    user: &UserId,
) -> UserKeyQueryResult {
    let mut users_for_key_query = self.users_for_key_query.lock().await;

    // Registration happens under the lock, so a completion cannot slip in
    // between the "is it pending?" check and the start of the wait.
    let Some(waiter) = users_for_key_query.maybe_register_waiting_task(user) else {
        return UserKeyQueryResult::WasNotPending;
    };

    // Standard condition-variable loop: the flag is re-checked after every
    // wake-up, with the mutex guard handed to the condvar for the duration of
    // each wait and handed back afterwards.
    let wait_for_completion = async {
        while !waiter.completed.load(Ordering::Relaxed) {
            users_for_key_query =
                self.users_for_key_query_condvar.wait(users_for_key_query).await;
        }
    };

    match timeout(Box::pin(wait_for_completion), timeout_duration).await {
        Ok(()) => UserKeyQueryResult::WasPending,
        Err(_) => {
            warn!(
                user_id = ?user,
                "The user has a pending `/keys/query` request which did \
                not finish yet, some devices might be missing."
            );
            UserKeyQueryResult::TimeoutExpired
        }
    }
}
/// See the docs for [`crate::OlmMachine::tracked_users()`].
pub async fn tracked_users(&self) -> Result<HashSet<OwnedUserId>> {
self.load_tracked_users().await?;