refactor: Introduce more early returns to reduce rightwards drift

This commit is contained in:
Jonas Platte
2022-11-04 13:38:36 +01:00
committed by Jonas Platte
parent e59acfe28c
commit a30e40ed3a
17 changed files with 571 additions and 622 deletions

View File

@@ -449,7 +449,7 @@ impl BackupMachine {
.collect();
for session in &sessions {
session.mark_as_backed_up()
session.mark_as_backed_up();
}
trace!(request_id = ?r.request_id, keys = ?r.sessions, "Marking room keys as backed up");
@@ -470,54 +470,54 @@ impl BackupMachine {
expected = r.request_id.to_string().as_str(),
got = request_id.to_string().as_str(),
"Tried to mark a pending backup as sent but the request id didn't match"
)
);
}
} else {
warn!(
request_id = request_id.to_string().as_str(),
"Tried to mark a pending backup as sent but there isn't a backup pending"
);
}
};
Ok(())
}
async fn backup_helper(&self) -> Result<Option<PendingBackup>, CryptoStoreError> {
if let Some(backup_key) = &*self.backup_key.read().await {
if let Some(version) = backup_key.backup_version() {
let sessions =
self.store.inbound_group_sessions_for_backup(Self::BACKUP_BATCH_SIZE).await?;
if !sessions.is_empty() {
let key_count = sessions.len();
let (backup, session_record) = Self::backup_keys(sessions, backup_key).await;
info!(
key_count = key_count,
keys = ?session_record,
?backup_key,
"Successfully created a room keys backup request"
);
let request = PendingBackup {
request_id: TransactionId::new(),
request: KeysBackupRequest { version, rooms: backup },
sessions: session_record,
};
Ok(Some(request))
} else {
trace!(?backup_key, "No room keys need to be backed up");
Ok(None)
}
} else {
warn!("Trying to backup room keys but the backup key wasn't uploaded");
Ok(None)
}
} else {
let Some(backup_key) = &*self.backup_key.read().await else {
warn!("Trying to backup room keys but no backup key was found");
Ok(None)
return Ok(None);
};
let Some(version) = backup_key.backup_version() else {
warn!("Trying to backup room keys but the backup key wasn't uploaded");
return Ok(None);
};
let sessions =
self.store.inbound_group_sessions_for_backup(Self::BACKUP_BATCH_SIZE).await?;
if sessions.is_empty() {
trace!(?backup_key, "No room keys need to be backed up");
return Ok(None);
}
let key_count = sessions.len();
let (backup, session_record) = Self::backup_keys(sessions, backup_key).await;
info!(
key_count = key_count,
keys = ?session_record,
?backup_key,
"Successfully created a room keys backup request"
);
let request = PendingBackup {
request_id: TransactionId::new(),
request: KeysBackupRequest { version, rooms: backup },
sessions: session_record,
};
Ok(Some(request))
}
/// Backup all the non-backed up room keys we know about

View File

@@ -315,68 +315,7 @@ impl GossipMachine {
let device =
self.store.get_device(&event.sender, &event.content.requesting_device_id).await?;
if let Some(device) = device {
match self.should_share_key(&device, &session).await {
Err(e) => {
if let KeyForwardDecision::ChangedSenderKey = e {
warn!(
user_id = device.user_id().as_str(),
device_id = device.device_id().as_str(),
"Received a key request from a device that changed \
their Curve25519 sender key"
);
} else {
debug!(
user_id = device.user_id().as_str(),
device_id = device.device_id().as_str(),
reason = ?e,
"Received a key request that we won't serve",
);
}
Ok(None)
}
Ok(message_index) => {
info!(
user_id = %device.user_id(),
device_id = %device.device_id(),
session_id = session.session_id(),
room_id = %session.room_id,
?message_index,
"Serving a room key request",
);
match self.forward_room_key(&session, &device, message_index).await {
Ok(s) => Ok(Some(s)),
Err(OlmError::MissingSession) => {
info!(
user_id = %device.user_id(),
device_id = %device.device_id(),
session_id = session.session_id(),
"Key request is missing an Olm session, \
putting the request in the wait queue",
);
self.handle_key_share_without_session(device, event.to_owned().into());
Ok(None)
}
Err(OlmError::SessionExport(e)) => {
warn!(
user_id = %device.user_id(),
device_id = %device.device_id(),
session_id = session.session_id(),
"Can't serve a room key request, the session \
can't be exported into a forwarded room key: \
{:?}",
e
);
Ok(None)
}
Err(e) => Err(e),
}
}
}
} else {
let Some(device) = device else {
warn!(
user_id = %event.sender,
device_id = %event.content.requesting_device_id,
@@ -384,7 +323,66 @@ impl GossipMachine {
);
self.store.update_tracked_user(&event.sender, true).await?;
Ok(None)
return Ok(None);
};
let message_index = match self.should_share_key(&device, &session).await {
Ok(message_index) => message_index,
Err(e) => {
if let KeyForwardDecision::ChangedSenderKey = e {
warn!(
user_id = device.user_id().as_str(),
device_id = device.device_id().as_str(),
"Received a key request from a device that changed \
their Curve25519 sender key"
);
} else {
debug!(
user_id = device.user_id().as_str(),
device_id = device.device_id().as_str(),
reason = ?e,
"Received a key request that we won't serve",
);
}
return Ok(None);
}
};
info!(
user_id = %device.user_id(),
device_id = %device.device_id(),
session_id = session.session_id(),
room_id = %session.room_id,
?message_index,
"Serving a room key request",
);
match self.forward_room_key(&session, &device, message_index).await {
Ok(s) => Ok(Some(s)),
Err(OlmError::MissingSession) => {
info!(
user_id = %device.user_id(),
device_id = %device.device_id(),
session_id = session.session_id(),
"Key request is missing an Olm session, \
putting the request in the wait queue",
);
self.handle_key_share_without_session(device, event.to_owned().into());
Ok(None)
}
Err(OlmError::SessionExport(e)) => {
warn!(
user_id = %device.user_id(),
device_id = %device.device_id(),
session_id = session.session_id(),
"Can't serve a room key request, the session \
can't be exported into a forwarded room key: {e:?}",
);
Ok(None)
}
Err(e) => Err(e),
}
}
@@ -919,25 +917,18 @@ impl GossipMachine {
sender_key: Curve25519PublicKey,
event: &DecryptedForwardedRoomKeyEvent,
) -> Result<Option<InboundGroupSession>, CryptoStoreError> {
if let Some(info) = event.room_key_info() {
if let Some(request) =
self.store.get_secret_request_by_info(&info.clone().into()).await?
{
if self.should_accept_forward(&request, sender_key).await? {
self.accept_forwarded_room_key(&request, sender_key, event).await
} else {
warn!(
sender = %event.sender,
%sender_key,
room_id = %info.room_id(),
session_id = info.session_id(),
"Received a forwarded room key from an unknown device, or \
from a device that the key request recipient doesn't own",
);
let Some(info) = event.room_key_info() else {
warn!(
sender = event.sender.as_str(),
sender_key = sender_key.to_base64(),
algorithm = %event.content.algorithm(),
"Received a forwarded room key with an unsupported algorithm",
);
return Ok(None);
};
Ok(None)
}
} else {
let Some(request) =
self.store.get_secret_request_by_info(&info.clone().into()).await? else {
warn!(
sender = %event.sender,
sender_key = %sender_key,
@@ -947,15 +938,19 @@ impl GossipMachine {
algorithm = %info.algorithm(),
"Received a forwarded room key that we didn't request",
);
return Ok(None);
};
Ok(None)
}
if self.should_accept_forward(&request, sender_key).await? {
self.accept_forwarded_room_key(&request, sender_key, event).await
} else {
warn!(
sender = event.sender.as_str(),
sender_key = sender_key.to_base64(),
algorithm = %event.content.algorithm(),
"Received a forwarded room key with an unsupported algorithm",
sender = %event.sender,
%sender_key,
room_id = %info.room_id(),
session_id = info.session_id(),
"Received a forwarded room key from an unknown device, or \
from a device that the key request recipient doesn't own",
);
Ok(None)

View File

@@ -208,11 +208,8 @@ impl Device {
/// Get the Olm sessions that belong to this device.
pub(crate) async fn get_sessions(&self) -> StoreResult<Option<Arc<Mutex<Vec<Session>>>>> {
if let Some(k) = self.curve25519_key() {
self.verification_machine.store.get_sessions(&k.to_base64()).await
} else {
Ok(None)
}
let Some(k) = self.curve25519_key() else { return Ok(None) };
self.verification_machine.store.get_sessions(&k.to_base64()).await
}
/// Is this device considered to be verified.

View File

@@ -147,14 +147,15 @@ fn verify_signature(
signatures: &Signatures,
canonical_json: &str,
) -> Result<(), SignatureError> {
if let Some(s) = signatures.get(user_id).and_then(|m| m.get(key_id)) {
match s {
Ok(Signature::Ed25519(s)) => Ok(public_key.verify(canonical_json.as_bytes(), s)?),
Ok(Signature::Other(_)) => Err(SignatureError::UnsupportedAlgorithm),
Err(_) => Err(SignatureError::InvalidSignature),
}
} else {
Err(SignatureError::NoSignatureFound)
let s = signatures
.get(user_id)
.and_then(|m| m.get(key_id))
.ok_or(SignatureError::NoSignatureFound)?;
match s {
Ok(Signature::Ed25519(s)) => Ok(public_key.verify(canonical_json.as_bytes(), s)?),
Ok(Signature::Other(_)) => Err(SignatureError::UnsupportedAlgorithm),
Err(_) => Err(SignatureError::InvalidSignature),
}
}

View File

@@ -341,180 +341,186 @@ impl VerificationMachine {
}
};
if let Some(content) = event.verification_content() {
match &content {
AnyVerificationContent::Request(r) => {
info!(
let Some(content) = event.verification_content() else { return Ok(()) };
match &content {
AnyVerificationContent::Request(r) => {
info!(
sender = event.sender().as_str(),
from_device = r.from_device().as_str(),
"Received a new verification request",
);
let Some(timestamp) = event.timestamp() else {
warn!(
sender = event.sender().as_str(),
from_device = r.from_device().as_str(),
"Received a new verification request",
"The key verification request didn't contain a valid timestamp"
);
return Ok(());
};
if let Some(timestamp) = event.timestamp() {
if Self::is_timestamp_valid(timestamp) {
if !event_sent_from_us(&event, r.from_device()) {
let request = VerificationRequest::from_request(
self.verifications.clone(),
self.store.clone(),
event.sender(),
flow_id,
r,
);
self.insert_request(request);
} else {
trace!(
sender = event.sender().as_str(),
from_device = r.from_device().as_str(),
"The received verification request was sent by us, ignoring it",
);
}
} else {
trace!(
sender = event.sender().as_str(),
from_device = r.from_device().as_str(),
?timestamp,
"The received verification request was too old or too far into the future",
);
}
} else {
warn!(
sender = event.sender().as_str(),
from_device = r.from_device().as_str(),
"The key verification request didn't contain a valid timestamp"
);
}
if !Self::is_timestamp_valid(timestamp) {
trace!(
sender = event.sender().as_str(),
from_device = r.from_device().as_str(),
?timestamp,
"The received verification request was too old or too far into the future",
);
return Ok(());
}
AnyVerificationContent::Cancel(c) => {
if let Some(verification) = self.get_request(event.sender(), flow_id.as_str()) {
verification.receive_cancel(event.sender(), c);
}
if let Some(verification) =
self.get_verification(event.sender(), flow_id.as_str())
{
match verification {
Verification::SasV1(sas) => {
// This won't produce an outgoing content
let _ = sas.receive_any_event(event.sender(), &content);
}
#[cfg(feature = "qrcode")]
Verification::QrV1(qr) => qr.receive_cancel(event.sender(), c),
}
}
if event_sent_from_us(&event, r.from_device()) {
trace!(
sender = event.sender().as_str(),
from_device = r.from_device().as_str(),
"The received verification request was sent by us, ignoring it",
);
return Ok(());
}
AnyVerificationContent::Ready(c) => {
if let Some(request) = self.get_request(event.sender(), flow_id.as_str()) {
if request.flow_id() == &flow_id {
request.receive_ready(event.sender(), c);
} else {
flow_id_mismatch();
}
}
}
AnyVerificationContent::Start(c) => {
if let Some(request) = self.get_request(event.sender(), flow_id.as_str()) {
if request.flow_id() == &flow_id {
request.receive_start(event.sender(), c).await?
} else {
flow_id_mismatch();
}
} else if let FlowId::ToDevice(_) = flow_id {
// TODO remove this soon, this has been deprecated by
// MSC3122 https://github.com/matrix-org/matrix-doc/pull/3122
if let Some(device) =
self.store.get_device(event.sender(), c.from_device()).await?
{
let identities = self.store.get_identities(device).await?;
match Sas::from_start_event(flow_id, c, identities, None, false) {
Ok(sas) => {
self.verifications.insert_sas(sas);
}
Err(cancellation) => self.queue_up_content(
event.sender(),
c.from_device(),
cancellation,
None,
),
}
}
}
}
AnyVerificationContent::Accept(_) | AnyVerificationContent::Key(_) => {
if let Some(sas) = self.get_sas(event.sender(), flow_id.as_str()) {
if sas.flow_id() == &flow_id {
if let Some((content, request_info)) =
sas.receive_any_event(event.sender(), &content)
{
self.queue_up_content(
sas.other_user_id(),
sas.other_device_id(),
content,
request_info,
);
}
} else {
flow_id_mismatch();
}
}
}
AnyVerificationContent::Mac(_) => {
if let Some(s) = self.get_sas(event.sender(), flow_id.as_str()) {
if s.flow_id() == &flow_id {
let content = s.receive_any_event(event.sender(), &content);
let request = VerificationRequest::from_request(
self.verifications.clone(),
self.store.clone(),
event.sender(),
flow_id,
r,
);
if s.is_done() {
self.mark_sas_as_done(s, content.map(|(c, _)| c)).await?;
} else {
// Even if we are not done (yet), there might be content to send
// out, e.g. in the case where we are done with our side of the
// verification process, but the other side has not yet sent their
// "done".
if let Some((content, request_id)) = content {
self.queue_up_content(
s.other_user_id(),
s.other_device_id(),
content,
request_id,
);
}
}
} else {
flow_id_mismatch();
}
}
self.insert_request(request);
}
AnyVerificationContent::Cancel(c) => {
if let Some(verification) = self.get_request(event.sender(), flow_id.as_str()) {
verification.receive_cancel(event.sender(), c);
}
AnyVerificationContent::Done(c) => {
if let Some(verification) = self.get_request(event.sender(), flow_id.as_str()) {
verification.receive_done(event.sender(), c);
}
#[allow(clippy::single_match)]
match self.get_verification(event.sender(), flow_id.as_str()) {
Some(Verification::SasV1(sas)) => {
let content = sas.receive_any_event(event.sender(), &content);
if sas.is_done() {
self.mark_sas_as_done(sas, content.map(|(c, _)| c)).await?;
}
if let Some(verification) = self.get_verification(event.sender(), flow_id.as_str())
{
match verification {
Verification::SasV1(sas) => {
// This won't produce an outgoing content
let _ = sas.receive_any_event(event.sender(), &content);
}
#[cfg(feature = "qrcode")]
Some(Verification::QrV1(qr)) => {
let (cancellation, request) = qr.receive_done(c).await?;
if let Some(c) = cancellation {
self.verifications.add_request(c.into())
}
if let Some(s) = request {
self.verifications.add_request(s.into())
}
}
None => (),
Verification::QrV1(qr) => qr.receive_cancel(event.sender(), c),
}
}
}
AnyVerificationContent::Ready(c) => {
let Some(request) = self.get_request(event.sender(), flow_id.as_str()) else {
return Ok(());
};
if request.flow_id() == &flow_id {
request.receive_ready(event.sender(), c);
} else {
flow_id_mismatch();
}
}
AnyVerificationContent::Start(c) => {
if let Some(request) = self.get_request(event.sender(), flow_id.as_str()) {
if request.flow_id() == &flow_id {
request.receive_start(event.sender(), c).await?
} else {
flow_id_mismatch();
}
} else if let FlowId::ToDevice(_) = flow_id {
// TODO remove this soon, this has been deprecated by
// MSC3122 https://github.com/matrix-org/matrix-doc/pull/3122
if let Some(device) =
self.store.get_device(event.sender(), c.from_device()).await?
{
let identities = self.store.get_identities(device).await?;
match Sas::from_start_event(flow_id, c, identities, None, false) {
Ok(sas) => {
self.verifications.insert_sas(sas);
}
Err(cancellation) => self.queue_up_content(
event.sender(),
c.from_device(),
cancellation,
None,
),
}
}
}
}
AnyVerificationContent::Accept(_) | AnyVerificationContent::Key(_) => {
let Some(sas) = self.get_sas(event.sender(), flow_id.as_str()) else {
return Ok(());
};
if sas.flow_id() != &flow_id {
flow_id_mismatch();
return Ok(());
}
let Some((content, request_info)) =
sas.receive_any_event(event.sender(), &content) else { return Ok(()) };
self.queue_up_content(
sas.other_user_id(),
sas.other_device_id(),
content,
request_info,
);
}
AnyVerificationContent::Mac(_) => {
let Some(s) = self.get_sas(event.sender(), flow_id.as_str()) else { return Ok(()) };
if s.flow_id() != &flow_id {
flow_id_mismatch();
return Ok(());
}
let content = s.receive_any_event(event.sender(), &content);
if s.is_done() {
self.mark_sas_as_done(s, content.map(|(c, _)| c)).await?;
} else {
// Even if we are not done (yet), there might be content to
// send out, e.g. in the case where we are done with our
// side of the verification process, but the other side has
// not yet sent their "done".
let Some((content, request_id)) = content else { return Ok(()) };
self.queue_up_content(
s.other_user_id(),
s.other_device_id(),
content,
request_id,
);
}
}
AnyVerificationContent::Done(c) => {
if let Some(verification) = self.get_request(event.sender(), flow_id.as_str()) {
verification.receive_done(event.sender(), c);
}
#[allow(clippy::single_match)]
match self.get_verification(event.sender(), flow_id.as_str()) {
Some(Verification::SasV1(sas)) => {
let content = sas.receive_any_event(event.sender(), &content);
if sas.is_done() {
self.mark_sas_as_done(sas, content.map(|(c, _)| c)).await?;
}
}
#[cfg(feature = "qrcode")]
Some(Verification::QrV1(qr)) => {
let (cancellation, request) = qr.receive_done(c).await?;
if let Some(c) = cancellation {
self.verifications.add_request(c.into())
}
if let Some(s) = request {
self.verifications.add_request(s.into())
}
}
None => {}
}
}
}
Ok(())

View File

@@ -701,45 +701,41 @@ impl IdentitiesBeingVerified {
) -> Result<Option<ReadOnlyDevice>, CryptoStoreError> {
let device = self.store.get_device(self.other_user_id(), self.other_device_id()).await?;
if let Some(device) = device {
if device.keys() == self.device_being_verified.keys() {
if verified_devices.map_or(false, |v| v.contains(&device)) {
trace!(
user_id = device.user_id().as_str(),
device_id = device.device_id().as_str(),
"Marking device as verified.",
);
device.set_trust_state(LocalTrust::Verified);
Ok(Some(device))
} else {
info!(
user_id = device.user_id().as_str(),
device_id = device.device_id().as_str(),
"The interactive verification process didn't verify \
the device",
);
Ok(None)
}
} else {
warn!(
user_id = device.user_id().as_str(),
device_id = device.device_id().as_str(),
"The device keys have changed while an interactive \
verification was going on, not marking the device as verified.",
);
Ok(None)
}
} else {
let Some(device) = device else {
let device = &self.device_being_verified;
info!(
user_id = device.user_id().as_str(),
device_id = device.device_id().as_str(),
"The device was deleted while an interactive verification was \
going on.",
"The device was deleted while an interactive verification was going on.",
);
return Ok(None);
};
if device.keys() != self.device_being_verified.keys() {
warn!(
user_id = device.user_id().as_str(),
device_id = device.device_id().as_str(),
"The device keys have changed while an interactive verification \
was going on, not marking the device as verified.",
);
return Ok(None);
}
if verified_devices.map_or(false, |v| v.contains(&device)) {
trace!(
user_id = device.user_id().as_str(),
device_id = device.device_id().as_str(),
"Marking device as verified.",
);
device.set_trust_state(LocalTrust::Verified);
Ok(Some(device))
} else {
info!(
user_id = device.user_id().as_str(),
device_id = device.device_id().as_str(),
"The interactive verification process didn't verify the device",
);
Ok(None)

View File

@@ -322,8 +322,8 @@ impl VerificationRequest {
&self,
data: QrVerificationData,
) -> Result<Option<QrVerification>, ScanError> {
let fut = if let InnerRequest::Ready(r) = &*self.inner.lock().unwrap() {
Some(QrVerification::from_scan(
let future = if let InnerRequest::Ready(r) = &*self.inner.lock().unwrap() {
QrVerification::from_scan(
r.store.clone(),
r.other_user_id.clone(),
r.state.other_device_id.clone(),
@@ -331,41 +331,37 @@ impl VerificationRequest {
data,
self.we_started,
Some(self.inner.clone().into()),
))
)
} else {
None
return Ok(None);
};
if let Some(future) = fut {
let qr_verification = future.await?;
let qr_verification = future.await?;
// We may have previously started our own QR verification (e.g. two devices
// displaying QR code at the same time), so we need to replace it with the newly
// scanned code.
if self
.verification_cache
.get_qr(qr_verification.other_user_id(), qr_verification.flow_id().as_str())
.is_some()
{
debug!(
user_id = %self.other_user(),
flow_id = self.flow_id().as_str(),
"Replacing existing QR verification"
);
self.verification_cache.replace_qr(qr_verification.clone());
} else {
debug!(
user_id = %self.other_user(),
flow_id = self.flow_id().as_str(),
"Inserting new QR verification"
);
self.verification_cache.insert_qr(qr_verification.clone());
}
Ok(Some(qr_verification))
// We may have previously started our own QR verification (e.g. two devices
// displaying QR code at the same time), so we need to replace it with the newly
// scanned code.
if self
.verification_cache
.get_qr(qr_verification.other_user_id(), qr_verification.flow_id().as_str())
.is_some()
{
debug!(
user_id = %self.other_user(),
flow_id = self.flow_id().as_str(),
"Replacing existing QR verification"
);
self.verification_cache.replace_qr(qr_verification.clone());
} else {
Ok(None)
debug!(
user_id = %self.other_user(),
flow_id = self.flow_id().as_str(),
"Inserting new QR verification"
);
self.verification_cache.insert_qr(qr_verification.clone());
}
Ok(Some(qr_verification))
}
pub(crate) fn from_request(
@@ -541,31 +537,24 @@ impl VerificationRequest {
let cancelled = Cancelled::new(true, code);
let cancel_content = cancelled.as_content(self.flow_id());
if let OutgoingContent::ToDevice(c) = cancel_content {
let recipients: Vec<OwnedDeviceId> = self
.recipient_devices
.iter()
.filter(|&d| filter_device.map_or(true, |device| **d != *device))
.cloned()
.collect();
let OutgoingContent::ToDevice(c) = cancel_content else { return None };
let recip_devices: Vec<OwnedDeviceId> = self
.recipient_devices
.iter()
.filter(|&d| filter_device.map_or(true, |device| **d != *device))
.cloned()
.collect();
// We don't need to notify anyone if no recipients were present
// but we did have a filter device, since this means that only a
// single device received the `m.key.verification.request` and that
// device accepted the request.
if recipients.is_empty() && filter_device.is_some() {
None
} else {
Some(ToDeviceRequest::for_recipients(
self.other_user(),
recipients,
c,
TransactionId::new(),
))
}
} else {
None
if recip_devices.is_empty() && filter_device.is_some() {
// We don't need to notify anyone if no recipients were present but
// we did have a filter device, since this means that only a single
// device received the `m.key.verification.request` and that device
// accepted the request.
return None;
}
let recipient = self.other_user();
Some(ToDeviceRequest::for_recipients(recipient, recip_devices, c, TransactionId::new()))
}
pub(crate) fn receive_ready(&self, sender: &UserId, content: &ReadyContent<'_>) {
@@ -601,17 +590,16 @@ impl VerificationRequest {
) -> Result<(), CryptoStoreError> {
let inner = self.inner.lock().unwrap().clone();
if let InnerRequest::Ready(s) = inner {
s.receive_start(sender, content, self.we_started, self.inner.clone().into()).await?;
} else {
let InnerRequest::Ready(s) = inner else {
warn!(
sender = sender.as_str(),
device_id = content.from_device().as_str(),
"Received a key verification start event but we're not yet in the ready state"
);
}
return Ok(());
};
Ok(())
s.receive_start(sender, content, self.we_started, self.inner.clone().into()).await
}
pub(crate) fn receive_done(&self, sender: &UserId, content: &DoneContent<'_>) {
@@ -628,21 +616,23 @@ impl VerificationRequest {
}
pub(crate) fn receive_cancel(&self, sender: &UserId, content: &CancelContent<'_>) {
if sender == self.other_user() {
trace!(
sender = sender.as_str(),
code = content.cancel_code().as_str(),
"Cancelling a verification request, other user has cancelled"
);
let mut inner = self.inner.lock().unwrap();
inner.cancel(false, content.cancel_code());
if sender != self.other_user() {
return;
}
if self.we_started() {
if let Some(request) =
self.cancel_for_other_devices(content.cancel_code().to_owned(), None)
{
self.verification_cache.add_verification_request(request.into());
}
trace!(
sender = sender.as_str(),
code = content.cancel_code().as_str(),
"Cancelling a verification request, other user has cancelled"
);
let mut inner = self.inner.lock().unwrap();
inner.cancel(false, content.cancel_code());
if self.we_started() {
if let Some(request) =
self.cancel_for_other_devices(content.cancel_code().to_owned(), None)
{
self.verification_cache.add_verification_request(request.into());
}
}
}
@@ -726,14 +716,11 @@ impl InnerRequest {
}
fn accept(&mut self, methods: Vec<VerificationMethod>) -> Option<OutgoingContent> {
if let InnerRequest::Requested(s) = self {
let (state, content) = s.clone().accept(methods);
*self = InnerRequest::Ready(state);
let InnerRequest::Requested(s) = self else { return None };
let (state, content) = s.clone().accept(methods);
*self = InnerRequest::Ready(state);
Some(content)
} else {
None
}
Some(content)
}
fn receive_done(&mut self, content: &DoneContent<'_>) {

View File

@@ -180,20 +180,17 @@ impl InnerSas {
self,
methods: Vec<ShortAuthenticationString>,
) -> Option<(InnerSas, OwnedAcceptContent)> {
if let InnerSas::Started(s) = self {
let sas = s.into_we_accepted(methods);
let content = sas.as_content();
let InnerSas::Started(s) = self else { return None };
let sas = s.into_we_accepted(methods);
let content = sas.as_content();
trace!(
flow_id = sas.verification_flow_id.as_str(),
accepted_protocols = ?sas.state.accepted_protocols,
"Accepted a SAS verification"
);
trace!(
flow_id = sas.verification_flow_id.as_str(),
accepted_protocols = ?sas.state.accepted_protocols,
"Accepted a SAS verification"
);
Some((InnerSas::WeAccepted(sas), content))
} else {
None
}
Some((InnerSas::WeAccepted(sas), content))
}
#[cfg(test)]

View File

@@ -615,30 +615,30 @@ impl SasState<Created> {
) -> Result<SasState<Accepted>, SasState<Cancelled>> {
self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
if let AcceptMethod::SasV1(content) = content.method() {
let accepted_protocols = AcceptedProtocols::try_from(content.clone())
.map_err(|c| self.clone().cancel(true, c))?;
let AcceptMethod::SasV1(content) = content.method() else {
return Err(self.cancel(true, CancelCode::UnknownMethod));
};
let start_content = self.as_content().into();
let accepted_protocols = AcceptedProtocols::try_from(content.clone())
.map_err(|c| self.clone().cancel(true, c))?;
Ok(SasState {
inner: self.inner,
our_public_key: self.our_public_key,
ids: self.ids,
verification_flow_id: self.verification_flow_id,
creation_time: self.creation_time,
last_event_time: Instant::now().into(),
started_from_request: self.started_from_request,
state: Arc::new(Accepted {
start_content,
commitment: content.commitment.clone(),
request_id: TransactionId::new(),
accepted_protocols,
}),
})
} else {
Err(self.cancel(true, CancelCode::UnknownMethod))
}
let start_content = self.as_content().into();
Ok(SasState {
inner: self.inner,
our_public_key: self.our_public_key,
ids: self.ids,
verification_flow_id: self.verification_flow_id,
creation_time: self.creation_time,
last_event_time: Instant::now().into(),
started_from_request: self.started_from_request,
state: Arc::new(Accepted {
start_content,
commitment: content.commitment.clone(),
request_id: TransactionId::new(),
accepted_protocols,
}),
})
}
}
@@ -689,41 +689,44 @@ impl SasState<Started> {
state: Arc::new(Cancelled::new(true, CancelCode::UnknownMethod)),
};
if let StartMethod::SasV1(method_content) = content.method() {
let commitment = calculate_commitment(our_public_key, content);
let state = match content.method() {
StartMethod::SasV1(method_content) => {
let commitment = calculate_commitment(our_public_key, content);
info!(
public_key = our_public_key.to_base64(),
%commitment,
?content,
"Calculated SAS commitment",
);
info!(
public_key = our_public_key.to_base64(),
%commitment,
?content,
"Calculated SAS commitment",
);
if let Ok(accepted_protocols) = AcceptedProtocols::try_from(method_content) {
Ok(SasState {
inner: Arc::new(Mutex::new(Some(sas))),
our_public_key,
let Ok(accepted_protocols) = AcceptedProtocols::try_from(method_content) else {
return Err(canceled());
};
ids: SasIds { account, other_device, other_identity, own_identity },
creation_time: Arc::new(Instant::now()),
last_event_time: Arc::new(Instant::now()),
started_from_request,
verification_flow_id: flow_id,
state: Arc::new(Started {
protocol_definitions: method_content.to_owned(),
accepted_protocols,
commitment,
}),
})
} else {
Err(canceled())
Started {
protocol_definitions: method_content.to_owned(),
accepted_protocols,
commitment,
}
}
} else {
Err(canceled())
}
_ => return Err(canceled()),
};
Ok(SasState {
inner: Arc::new(Mutex::new(Some(sas))),
our_public_key,
ids: SasIds { account, other_device, other_identity, own_identity },
creation_time: Arc::new(Instant::now()),
last_event_time: Arc::new(Instant::now()),
started_from_request,
verification_flow_id: flow_id,
state: Arc::new(state),
})
}
#[cfg(test)]
@@ -813,30 +816,30 @@ impl SasState<Started> {
) -> Result<SasState<Accepted>, SasState<Cancelled>> {
self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
if let AcceptMethod::SasV1(content) = content.method() {
let accepted_protocols = AcceptedProtocols::try_from(content.clone())
.map_err(|c| self.clone().cancel(true, c))?;
let AcceptMethod::SasV1(content) = content.method() else {
return Err(self.cancel(true, CancelCode::UnknownMethod));
};
let start_content = self.as_content().into();
let accepted_protocols = AcceptedProtocols::try_from(content.clone())
.map_err(|c| self.clone().cancel(true, c))?;
Ok(SasState {
inner: self.inner,
our_public_key: self.our_public_key,
ids: self.ids,
verification_flow_id: self.verification_flow_id,
creation_time: self.creation_time,
last_event_time: Instant::now().into(),
started_from_request: self.started_from_request,
state: Arc::new(Accepted {
start_content,
commitment: content.commitment.clone(),
request_id: TransactionId::new(),
accepted_protocols,
}),
})
} else {
Err(self.cancel(true, CancelCode::UnknownMethod))
}
let start_content = self.as_content().into();
Ok(SasState {
inner: self.inner,
our_public_key: self.our_public_key,
ids: self.ids,
verification_flow_id: self.verification_flow_id,
creation_time: self.creation_time,
last_event_time: Instant::now().into(),
started_from_request: self.started_from_request,
state: Arc::new(Accepted {
start_content,
commitment: content.commitment.clone(),
request_id: TransactionId::new(),
accepted_protocols,
}),
})
}
}

View File

@@ -380,7 +380,7 @@ impl ClientBuilder {
if let Some(issuer) = well_known.authentication.map(|auth| auth.issuer) {
authentication_issuer = Url::parse(&issuer).ok();
};
}
well_known.homeserver.base_url
}

View File

@@ -310,11 +310,8 @@ impl Client {
/// The OIDC Provider that is trusted by the homeserver.
pub async fn authentication_issuer(&self) -> Option<Url> {
if let Some(server) = &self.inner.authentication_issuer {
Some(server.read().await.clone())
} else {
None
}
let server = self.inner.authentication_issuer.as_ref()?;
Some(server.read().await.clone())
}
fn session_meta(&self) -> Option<&SessionMeta> {
@@ -1776,8 +1773,11 @@ impl Client {
Request: OutgoingRequest + Debug,
HttpError: From<FromHttpResponseError<Request::EndpointError>>,
{
let homeserver =
if let Some(h) = homeserver { h } else { self.homeserver().await.to_string() };
let homeserver = match homeserver {
Some(hs) => hs,
None => self.homeserver().await.to_string(),
};
self.inner
.http_client
.send(

View File

@@ -479,11 +479,8 @@ impl Encryption {
/// This can be used to check which private cross signing keys we have
/// stored locally.
pub async fn cross_signing_status(&self) -> Option<CrossSigningStatus> {
if let Some(machine) = self.client.olm_machine() {
Some(machine.cross_signing_status().await)
} else {
None
}
let machine = self.client.olm_machine()?;
Some(machine.cross_signing_status().await)
}
/// Get all the tracked users we know about
@@ -562,12 +559,9 @@ impl Encryption {
user_id: &UserId,
device_id: &DeviceId,
) -> Result<Option<Device>, CryptoStoreError> {
if let Some(machine) = self.client.olm_machine() {
let device = machine.get_device(user_id, device_id, None).await?;
Ok(device.map(|d| Device { inner: d, client: self.client.clone() }))
} else {
Ok(None)
}
let Some(machine) = self.client.olm_machine() else { return Ok(None) };
let device = machine.get_device(user_id, device_id, None).await?;
Ok(device.map(|d| Device { inner: d, client: self.client.clone() }))
}
/// Get a map holding all the devices of an user.
@@ -643,20 +637,17 @@ impl Encryption {
) -> Result<Option<crate::encryption::identities::UserIdentity>, CryptoStoreError> {
use crate::encryption::identities::UserIdentity;
if let Some(olm) = self.client.olm_machine() {
let identity = olm.get_identity(user_id, None).await?;
let Some(olm) = self.client.olm_machine() else { return Ok(None) };
let identity = olm.get_identity(user_id, None).await?;
Ok(identity.map(|i| match i {
matrix_sdk_base::crypto::UserIdentities::Own(i) => {
UserIdentity::new_own(self.client.clone(), i)
}
matrix_sdk_base::crypto::UserIdentities::Other(i) => {
UserIdentity::new(self.client.clone(), i, self.client.get_dm_room(user_id))
}
}))
} else {
Ok(None)
}
Ok(identity.map(|i| match i {
matrix_sdk_base::crypto::UserIdentities::Own(i) => {
UserIdentity::new_own(self.client.clone(), i)
}
matrix_sdk_base::crypto::UserIdentities::Other(i) => {
UserIdentity::new(self.client.clone(), i, self.client.get_dm_room(user_id))
}
}))
}
/// Create and upload a new cross signing identity.

View File

@@ -141,26 +141,20 @@ impl VerificationRequest {
/// for the remainder of the verification flow.
#[cfg(feature = "qrcode")]
pub async fn scan_qr_code(&self, data: QrVerificationData) -> Result<Option<QrVerification>> {
if let Some(qr) = self.inner.scan_qr_code(data).await? {
if let Some(request) = qr.reciprocate() {
self.client.send_verification_request(request).await?;
}
Ok(Some(QrVerification { inner: qr, client: self.client.clone() }))
} else {
Ok(None)
let Some(qr) = self.inner.scan_qr_code(data).await? else { return Ok(None) };
if let Some(request) = qr.reciprocate() {
self.client.send_verification_request(request).await?;
}
Ok(Some(QrVerification { inner: qr, client: self.client.clone() }))
}
/// Transition from this verification request into a SAS verification flow.
pub async fn start_sas(&self) -> Result<Option<SasVerification>> {
if let Some((sas, request)) = self.inner.start_sas().await? {
self.client.send_verification_request(request).await?;
let Some((sas, request)) = self.inner.start_sas().await? else { return Ok(None) };
self.client.send_verification_request(request).await?;
Ok(Some(SasVerification { inner: sas, client: self.client.clone() }))
} else {
Ok(None)
}
Ok(Some(SasVerification { inner: sas, client: self.client.clone() }))
}
/// Cancel the verification request

View File

@@ -115,50 +115,47 @@ impl Media {
if use_cache { self.client.store().get_media_content(request).await? } else { None };
if let Some(content) = content {
Ok(content)
} else {
let content: Vec<u8> = match &request.source {
MediaSource::Encrypted(file) => {
let request = get_content::v3::Request::from_url(&file.url)?;
let content: Vec<u8> = self.client.send(request, None).await?.file;
#[cfg(feature = "e2e-encryption")]
let content = {
let mut cursor = std::io::Cursor::new(content);
let mut reader = matrix_sdk_base::crypto::AttachmentDecryptor::new(
&mut cursor,
file.as_ref().clone().into(),
)?;
let mut decrypted = Vec::new();
reader.read_to_end(&mut decrypted)?;
decrypted
};
content
}
MediaSource::Plain(uri) => {
if let MediaFormat::Thumbnail(size) = &request.format {
let request = get_content_thumbnail::v3::Request::from_url(
uri,
size.width,
size.height,
)?;
self.client.send(request, None).await?.file
} else {
let request = get_content::v3::Request::from_url(uri)?;
self.client.send(request, None).await?.file
}
}
};
if use_cache {
self.client.store().add_media_content(request, content.clone()).await?;
}
Ok(content)
return Ok(content);
}
let content: Vec<u8> = match &request.source {
MediaSource::Encrypted(file) => {
let request = get_content::v3::Request::from_url(&file.url)?;
let content: Vec<u8> = self.client.send(request, None).await?.file;
#[cfg(feature = "e2e-encryption")]
let content = {
let mut cursor = std::io::Cursor::new(content);
let mut reader = matrix_sdk_base::crypto::AttachmentDecryptor::new(
&mut cursor,
file.as_ref().clone().into(),
)?;
let mut decrypted = Vec::new();
reader.read_to_end(&mut decrypted)?;
decrypted
};
content
}
MediaSource::Plain(uri) => {
if let MediaFormat::Thumbnail(size) = &request.format {
let request =
get_content_thumbnail::v3::Request::from_url(uri, size.width, size.height)?;
self.client.send(request, None).await?.file
} else {
let request = get_content::v3::Request::from_url(uri)?;
self.client.send(request, None).await?.file
}
}
};
if use_cache {
self.client.store().add_media_content(request, content.clone()).await?;
}
Ok(content)
}
/// Remove a media file's content from the store.
@@ -200,17 +197,11 @@ impl Media {
event_content: impl MediaEventContent,
use_cache: bool,
) -> Result<Option<Vec<u8>>> {
if let Some(source) = event_content.source() {
Ok(Some(
self.get_media_content(
&MediaRequest { source, format: MediaFormat::File },
use_cache,
)
.await?,
))
} else {
Ok(None)
}
let Some(source) = event_content.source() else { return Ok(None) };
let file = self
.get_media_content(&MediaRequest { source, format: MediaFormat::File }, use_cache)
.await?;
Ok(Some(file))
}
/// Remove the file of the given media event content from the cache.
@@ -223,7 +214,7 @@ impl Media {
/// * `event_content` - The media event content.
pub async fn remove_file(&self, event_content: impl MediaEventContent) -> Result<()> {
if let Some(source) = event_content.source() {
self.remove_media_content(&MediaRequest { source, format: MediaFormat::File }).await?
self.remove_media_content(&MediaRequest { source, format: MediaFormat::File }).await?;
}
Ok(())
@@ -253,17 +244,14 @@ impl Media {
size: MediaThumbnailSize,
use_cache: bool,
) -> Result<Option<Vec<u8>>> {
if let Some(source) = event_content.thumbnail_source() {
Ok(Some(
self.get_media_content(
&MediaRequest { source, format: MediaFormat::Thumbnail(size) },
use_cache,
)
.await?,
))
} else {
Ok(None)
}
let Some(source) = event_content.thumbnail_source() else { return Ok(None) };
let thumbnail = self
.get_media_content(
&MediaRequest { source, format: MediaFormat::Thumbnail(size) },
use_cache,
)
.await?;
Ok(Some(thumbnail))
}
/// Remove the thumbnail of the given media event content from the cache.

View File

@@ -160,12 +160,9 @@ impl Common {
/// # })
/// ```
pub async fn avatar(&self, format: MediaFormat) -> Result<Option<Vec<u8>>> {
if let Some(url) = self.avatar_url() {
let request = MediaRequest { source: MediaSource::Plain(url.to_owned()), format };
Ok(Some(self.client.media().get_media_content(&request, true).await?))
} else {
Ok(None)
}
let Some(url) = self.avatar_url() else { return Ok(None) };
let request = MediaRequest { source: MediaSource::Plain(url.to_owned()), format };
Ok(Some(self.client.media().get_media_content(&request, true).await?))
}
/// Sends a request to `/_matrix/client/r0/rooms/{room_id}/messages` and

View File

@@ -60,11 +60,8 @@ impl RoomMember {
/// # })
/// ```
pub async fn avatar(&self, format: MediaFormat) -> Result<Option<Vec<u8>>> {
if let Some(url) = self.avatar_url() {
let request = MediaRequest { source: MediaSource::Plain(url.to_owned()), format };
Ok(Some(self.client.media().get_media_content(&request, true).await?))
} else {
Ok(None)
}
let Some(url) = self.avatar_url() else { return Ok(None) };
let request = MediaRequest { source: MediaSource::Plain(url.to_owned()), format };
Ok(Some(self.client.media().get_media_content(&request, true).await?))
}
}

View File

@@ -509,7 +509,7 @@ impl SlidingSync {
/// Run this stream to receive new updates from the server.
pub async fn stream<'a>(
&self,
) -> Result<impl Stream<Item = Result<UpdateSummary, crate::Error>> + '_, crate::Error> {
) -> Result<impl Stream<Item = Result<UpdateSummary, crate::Error>> + '_> {
let views = self.views.lock_ref().to_vec();
let extensions = self.extensions.clone();
let client = self.client.clone();
@@ -525,7 +525,7 @@ impl SlidingSync {
let mut new_remaining_generators = Vec::new();
let mut new_remaining_views = Vec::new();
for (mut generator, view) in std::iter::zip(remaining_generators, remaining_views) {
for (mut generator, view) in std::iter::zip(remaining_generators, remaining_views) {
if let Some(request) = generator.next() {
requests.push(request);
new_remaining_generators.push(generator);