diff --git a/bindings/matrix-sdk-ffi/src/session_verification.rs b/bindings/matrix-sdk-ffi/src/session_verification.rs index b40f951b3..2e173f607 100644 --- a/bindings/matrix-sdk-ffi/src/session_verification.rs +++ b/bindings/matrix-sdk-ffi/src/session_verification.rs @@ -4,7 +4,7 @@ use futures_util::StreamExt; use matrix_sdk::{ encryption::{ identities::UserIdentity, - verification::{SasState, SasVerification, VerificationRequest}, + verification::{SasState, SasVerification, VerificationRequest, VerificationRequestState}, Encryption, }, ruma::events::{key::verification::VerificationMethod, AnyToDeviceEvent}, @@ -107,7 +107,13 @@ impl SessionVerificationController { .await .map_err(anyhow::Error::from)?; - *self.verification_request.write().unwrap() = Some(verification_request); + *self.verification_request.write().unwrap() = Some(verification_request.clone()); + + RUNTIME.spawn(Self::listen_to_verification_request_changes( + verification_request, + self.sas_verification.clone(), + self.delegate.clone(), + )); Ok(()) } @@ -125,7 +131,7 @@ impl SessionVerificationController { } let delegate = self.delegate.clone(); - RUNTIME.spawn(Self::listen_to_changes(delegate, verification)); + RUNTIME.spawn(Self::listen_to_sas_verification_changes(verification, delegate)); } _ => { if let Some(delegate) = &*self.delegate.read().unwrap() { @@ -203,60 +209,57 @@ impl SessionVerificationController { ); } } - // TODO: Use the changes stream for this as well once we expose - // VerificationRequest::changes() in the main crate. - AnyToDeviceEvent::KeyVerificationStart(event) => { - if !self.is_transaction_id_valid(event.content.transaction_id.to_string()) { - return; - } - - let Some(verification) = self - .encryption - .get_verification( - self.user_identity.user_id(), - event.content.transaction_id.as_str(), - ) - .await - else { - return; - }; - - let Some(sas_verification) = verification.sas() else { return }; - - *self.sas_verification.write().unwrap() = Some(sas_verification.clone()); - - if sas_verification.accept().await.is_ok() { - if let Some(delegate) = &*self.delegate.read().unwrap() { - delegate.did_start_sas_verification() - } - - let delegate = self.delegate.clone(); - RUNTIME.spawn(Self::listen_to_changes(delegate, sas_verification)); - } else if let Some(delegate) = &*self.delegate.read().unwrap() { - delegate.did_fail() - } - } - AnyToDeviceEvent::KeyVerificationReady(event) => { - if !self.is_transaction_id_valid(event.content.transaction_id.to_string()) { - return; - } - - if let Some(delegate) = &*self.delegate.read().unwrap() { - delegate.did_accept_verification_request() - } - } _ => (), } } - fn is_transaction_id_valid(&self, transaction_id: String) -> bool { - match &*self.verification_request.read().unwrap() { - Some(verification) => verification.flow_id() == transaction_id, - None => false, + async fn listen_to_verification_request_changes( + verification_request: VerificationRequest, + sas_verification: Arc>>, + delegate: Delegate, + ) { + let mut stream = verification_request.changes(); + + while let Some(state) = stream.next().await { + match state { + VerificationRequestState::Transitioned { verification } => { + let Some(verification) = verification.sas() else { + error!("Invalid, non-sas verification flow. Returning."); + return; + }; + + *sas_verification.write().unwrap() = Some(verification.clone()); + + if verification.accept().await.is_ok() { + if let Some(delegate) = &*delegate.read().unwrap() { + delegate.did_start_sas_verification() + } + + let delegate = delegate.clone(); + RUNTIME.spawn(Self::listen_to_sas_verification_changes( + verification, + delegate, + )); + } else if let Some(delegate) = &*delegate.read().unwrap() { + delegate.did_fail() + } + } + VerificationRequestState::Ready { .. } => { + if let Some(delegate) = &*delegate.read().unwrap() { + delegate.did_accept_verification_request() + } + } + VerificationRequestState::Cancelled(..) => { + if let Some(delegate) = &*delegate.read().unwrap() { + delegate.did_cancel(); + } + } + _ => {} + } } } - async fn listen_to_changes(delegate: Delegate, sas: SasVerification) { + async fn listen_to_sas_verification_changes(sas: SasVerification, delegate: Delegate) { let mut stream = sas.changes(); while let Some(state) = stream.next().await {