Ensure correct Olm session selection for encrypted messages

In some cases, restoring client state from backups may cause Olm sessions
to become corrupted, resulting in a backward ratchet. As a consequence,
the receiving side is unable to decrypt messages. To address this issue,
the failed side initiates a new Olm session and sends a dummy encrypted
message.

However, this dummy message also creates a new Olm session on the
sender's side. To ensure that both sides use the same new session, we
must sort sessions by their creation timestamp before encrypting new
messages.

Although the session list was sorted correctly previously, we selected
the wrong side of the list, resulting in an outdated session being
selected instead of the newest one. This patch resolves the issue by
selecting the newest Olm session and adds a test to the session getter
logic to prevent reintroduction of this bug.
This commit is contained in:
Damir Jelić
2023-03-31 20:12:02 +02:00
parent aeda2d363c
commit b7a78c865c
2 changed files with 151 additions and 43 deletions

View File

@@ -335,6 +335,11 @@ impl Device {
self.verification_machine.store.get_sessions(&k.to_base64()).await
}
#[cfg(test)]
pub(crate) async fn get_most_recent_session(&self) -> OlmResult<Option<Session>> {
self.inner.get_most_recent_session(self.verification_machine.store.inner()).await
}
/// Is this device considered to be verified.
///
/// This method returns true if either [`is_locally_trusted()`] returns true
@@ -612,6 +617,34 @@ impl ReadOnlyDevice {
}
}
/// Find and return the most recently created Olm [`Session`] we are sharing
/// with this device.
pub(crate) async fn get_most_recent_session(
&self,
store: &DynCryptoStore,
) -> OlmResult<Option<Session>> {
if let Some(sender_key) = self.curve25519_key() {
if let Some(s) = store.get_sessions(&sender_key.to_base64()).await? {
let mut sessions = s.lock().await;
sessions.sort_by_key(|s| s.creation_time);
Ok(sessions.last().cloned())
} else {
Ok(None)
}
} else {
warn!(
user_id = ?self.user_id(),
device_id = ?self.device_id(),
"Trying to find a Olm session of a device, but the device doesn't have a \
Curve25519 key",
);
Err(EventError::MissingSenderKey.into())
}
}
/// Does this device support the olm.v2.curve25519-aes-sha2 encryption
/// algorithm.
#[cfg(feature = "experimental-algorithms")]
@@ -680,44 +713,29 @@ impl ReadOnlyDevice {
event_type: &str,
content: Value,
) -> OlmResult<(Session, Raw<ToDeviceEncryptedEventContent>)> {
let Some(sender_key) = self.curve25519_key() else {
warn!(
let session = self.get_most_recent_session(store).await?;
if let Some(mut session) = session {
let message = session.encrypt(self, event_type, content).await?;
trace!(
user_id = ?self.user_id(),
device_id = ?self.device_id(),
"Trying to encrypt a Megolm session, but the device doesn't \
have a curve25519 key",
session_id = session.session_id(),
"Successfully encrypted a Megolm session",
);
return Err(EventError::MissingSenderKey.into());
};
let session = if let Some(s) = store.get_sessions(&sender_key.to_base64()).await? {
let mut sessions = s.lock().await;
sessions.sort_by_key(|s| s.last_use_time);
sessions.get(0).cloned()
Ok((session, message))
} else {
None
};
let Some(mut session) = session else {
warn!(
"Trying to encrypt a Megolm session for user {} on device {}, \
but no Olm session is found",
self.user_id(),
self.device_id()
);
return Err(OlmError::MissingSession);
};
let message = session.encrypt(self, event_type, content).await?;
trace!(
user_id = ?self.user_id(),
device_id = ?self.device_id(),
session_id = session.session_id(),
"Successfully encrypted a Megolm session",
);
Ok((session, message))
Err(OlmError::MissingSession)
}
}
/// Update a device with a new device keys struct.

View File

@@ -1663,7 +1663,12 @@ pub(crate) mod testing {
#[cfg(test)]
pub(crate) mod tests {
use std::{collections::BTreeMap, iter, sync::Arc};
use std::{
collections::BTreeMap,
iter,
sync::Arc,
time::{Duration, SystemTime},
};
use assert_matches::assert_matches;
use futures::select;
@@ -1695,7 +1700,7 @@ pub(crate) mod tests {
room_id,
serde::Raw,
uint, user_id, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch,
OwnedDeviceKeyId, TransactionId, UserId,
OwnedDeviceKeyId, SecondsSinceUnixEpoch, TransactionId, UserId,
};
use serde_json::json;
use vodozemac::{
@@ -2027,23 +2032,34 @@ pub(crate) mod tests {
assert!(user_sessions.contains_key(alice_device));
}
#[async_test]
async fn test_session_creation() {
let (alice_machine, bob_machine, one_time_keys) = get_machine_pair().await;
let mut bob_keys = BTreeMap::new();
let (device_key_id, one_time_key) = one_time_keys.iter().next().unwrap();
let mut keys = BTreeMap::new();
keys.insert(device_key_id.clone(), one_time_key.clone());
bob_keys.insert(bob_machine.device_id().into(), keys);
let mut one_time_keys = BTreeMap::new();
one_time_keys.insert(bob_machine.user_id().to_owned(), bob_keys);
async fn create_session(
machine: &OlmMachine,
user_id: &UserId,
device_id: &DeviceId,
key_id: OwnedDeviceKeyId,
one_time_key: Raw<OneTimeKey>,
) {
let keys = BTreeMap::from([(key_id, one_time_key)]);
let keys = BTreeMap::from([(device_id.to_owned(), keys)]);
let one_time_keys = BTreeMap::from([(user_id.to_owned(), keys)]);
let response = claim_keys::v3::Response::new(one_time_keys);
alice_machine.receive_keys_claim_response(&response).await.unwrap();
machine.receive_keys_claim_response(&response).await.unwrap();
}
#[async_test]
async fn test_session_creation() {
let (alice_machine, bob_machine, mut one_time_keys) = get_machine_pair().await;
let (device_key_id, one_time_key) = one_time_keys.pop_first().unwrap();
create_session(
&alice_machine,
bob_machine.user_id(),
bob_machine.device_id(),
device_key_id,
one_time_key,
)
.await;
let session = alice_machine
.store
@@ -2055,6 +2071,80 @@ pub(crate) mod tests {
assert!(!session.lock().await.is_empty())
}
#[async_test]
async fn getting_most_recent_session() {
let (alice_machine, bob_machine, mut one_time_keys) = get_machine_pair().await;
let (device_key_id, one_time_key) = one_time_keys.pop_first().unwrap();
let device = alice_machine
.get_device(bob_machine.user_id(), bob_machine.device_id(), None)
.await
.unwrap()
.unwrap();
assert!(device.get_most_recent_session().await.unwrap().is_none());
create_session(
&alice_machine,
bob_machine.user_id(),
bob_machine.device_id(),
device_key_id,
one_time_key.to_owned(),
)
.await;
for _ in 0..10 {
let (device_key_id, one_time_key) = one_time_keys.pop_first().unwrap();
create_session(
&alice_machine,
bob_machine.user_id(),
bob_machine.device_id(),
device_key_id,
one_time_key.to_owned(),
)
.await;
}
// Since the sessions are created quickly in succession and our timestamps have
// a resolution in seconds, it's very likely that we're going to end up
// with the same timestamps, so we manually masage them to be 10s apart.
let session_id = {
let sessions = alice_machine
.store
.get_sessions(&bob_machine.identity_keys().curve25519.to_base64())
.await
.unwrap()
.unwrap();
let mut use_time = SystemTime::now();
let mut sessions = sessions.lock().await;
let mut session_id = None;
// Iterate through the sessions skipping the first and last element so we know
// that the correct session isn't the first nor the last one.
let (_, sessions_slice) = sessions.as_mut_slice().split_last_mut().unwrap();
for session in sessions_slice.iter_mut().skip(1) {
session.creation_time = SecondsSinceUnixEpoch::from_system_time(use_time).unwrap();
use_time += Duration::from_secs(10);
session_id = Some(session.session_id().to_owned());
}
session_id.unwrap()
};
let newest_session = device.get_most_recent_session().await.unwrap().unwrap();
assert_eq!(
newest_session.session_id(),
session_id,
"The session we found is the one that was most recently created"
);
}
#[async_test]
async fn test_olm_encryption() {
let (alice, bob) = get_machine_pair_with_session().await;