diff --git a/crates/matrix-sdk/src/test_utils/mocks/mod.rs b/crates/matrix-sdk/src/test_utils/mocks/mod.rs index 0e7f4e98d..e9341b511 100644 --- a/crates/matrix-sdk/src/test_utils/mocks/mod.rs +++ b/crates/matrix-sdk/src/test_utils/mocks/mod.rs @@ -950,22 +950,19 @@ impl MatrixMockServer { MockEndpoint { mock, server: &self.server, endpoint: SetRoomPinnedEventsEndpoint } } - /// Creates a prebuilt mock for the endpoint used to get information about - /// the owner of the default access token. - pub fn mock_who_am_i(&self) -> MockEndpoint<'_, WhoAmIEndpoint> { - self.mock_who_am_i_with_access_token("1234") - } - /// Creates a prebuilt mock for the endpoint used to get information about /// the owner of the given access token. - pub fn mock_who_am_i_with_access_token( - &self, - access_token: &str, - ) -> MockEndpoint<'_, WhoAmIEndpoint> { - let mock = Mock::given(method("GET")) - .and(path_regex(r"^/_matrix/client/v3/account/whoami")) - .and(header("authorization", format!("Bearer {access_token}"))); - MockEndpoint { mock, server: &self.server, endpoint: WhoAmIEndpoint } + /// + /// If no access token is provided, the access token to match is `"1234"`, + /// which matches the default value in the mock data. + pub fn mock_who_am_i(&self) -> MockEndpoint<'_, WhoAmIEndpoint> { + let mock = + Mock::given(method("GET")).and(path_regex(r"^/_matrix/client/v3/account/whoami")); + MockEndpoint { + mock, + server: &self.server, + endpoint: WhoAmIEndpoint { expected_access_token: "1234" }, + } } /// Creates a prebuilt mock for the endpoint used to publish end-to-end @@ -2418,25 +2415,43 @@ impl<'a> MockEndpoint<'a, SetRoomPinnedEventsEndpoint> { } /// A prebuilt mock for `GET /account/whoami` request. -pub struct WhoAmIEndpoint; +pub struct WhoAmIEndpoint { + expected_access_token: &'static str, +} + +impl WhoAmIEndpoint { + fn add_access_token_matcher(&self, mock: MockBuilder) -> MockBuilder { + mock.and(header("authorization", format!("Bearer {}", self.expected_access_token))) + } +} impl<'a> MockEndpoint<'a, WhoAmIEndpoint> { + /// Override the access token to expect for this endpoint. + pub fn expected_access_token(mut self, access_token: &'static str) -> Self { + self.endpoint.expected_access_token = access_token; + self + } + /// Returns a successful response with a user ID and device ID. pub fn ok(self) -> MatrixMock<'a> { - let mock = self.mock.respond_with(ResponseTemplate::new(200).set_body_json(json!({ - "user_id": "@joe:example.org", - "device_id": "D3V1C31D", - }))); + let mock = self.endpoint.add_access_token_matcher(self.mock).respond_with( + ResponseTemplate::new(200).set_body_json(json!({ + "user_id": "@joe:example.org", + "device_id": "D3V1C31D", + })), + ); MatrixMock { server: self.server, mock } } /// Returns an error response with an `M_UNKNOWN_TOKEN`. pub fn err_unknown_token(self) -> MatrixMock<'a> { - let mock = self.mock.respond_with(ResponseTemplate::new(401).set_body_json(json!({ - "errcode": "M_UNKNOWN_TOKEN", - "error": "Invalid token" - }))); + let mock = self.endpoint.add_access_token_matcher(self.mock).respond_with( + ResponseTemplate::new(401).set_body_json(json!({ + "errcode": "M_UNKNOWN_TOKEN", + "error": "Invalid token" + })), + ); MatrixMock { server: self.server, mock } } diff --git a/crates/matrix-sdk/tests/integration/refresh_token.rs b/crates/matrix-sdk/tests/integration/refresh_token.rs index ead506acc..d308d92d7 100644 --- a/crates/matrix-sdk/tests/integration/refresh_token.rs +++ b/crates/matrix-sdk/tests/integration/refresh_token.rs @@ -577,15 +577,19 @@ async fn test_refresh_token_handled_other_error() { #[cfg(feature = "experimental-oidc")] #[async_test] async fn test_oauth_refresh_token_handled_success() { - use matrix_sdk::test_utils::{ - client::oauth::{mock_prev_session_tokens, mock_session, mock_session_tokens}, - mocks::MatrixMockServer, + use matrix_sdk::{ + assert_next_eq_with_timeout, + test_utils::{ + client::oauth::{mock_prev_session_tokens, mock_session, mock_session_tokens}, + mocks::MatrixMockServer, + }, }; let server = MatrixMockServer::new().await; // Return an error first so the token is refreshed. server - .mock_who_am_i_with_access_token("prev-access-token") + .mock_who_am_i() + .expected_access_token("prev-access-token") .err_unknown_token() .expect(1) .named("whoami_unknown_token") @@ -604,17 +608,14 @@ async fn test_oauth_refresh_token_handled_success() { oidc.restore_session(mock_session(mock_prev_session_tokens(), issuer)).await.unwrap(); let mut tokens_stream = oidc.session_tokens_stream().unwrap(); - let tokens_join_handle = spawn(async move { - let tokens = tokens_stream.next().await.unwrap(); - assert_eq!(tokens, mock_session_tokens()); - }); - let mut session_changes = client.subscribe_to_session_changes(); client.whoami().await.unwrap(); - tokens_join_handle.await.unwrap(); + // We receive the new tokens on the stream. + assert_next_eq_with_timeout!(tokens_stream, mock_session_tokens()); + // We get notified once that the tokens were refreshed. assert_eq!(session_changes.try_recv(), Ok(SessionChange::TokensRefreshed)); assert_eq!(session_changes.try_recv(), Err(TryRecvError::Empty)); } @@ -633,7 +634,8 @@ async fn test_oauth_refresh_token_handled_failure() { let server = MatrixMockServer::new().await; // Return an error first so the token is refreshed. server - .mock_who_am_i_with_access_token("prev-access-token") + .mock_who_am_i() + .expected_access_token("prev-access-token") .err_unknown_token() .expect(1) .named("whoami_unknown_token") @@ -653,10 +655,12 @@ async fn test_oauth_refresh_token_handled_failure() { let mut session_changes = client.subscribe_to_session_changes(); + // The request fails with a refresh token error. let res = client.whoami().await; assert_let!(Err(HttpError::RefreshToken(RefreshTokenError::Oidc(oidc_err))) = res); assert_matches!(*oidc_err, OidcError::RefreshToken(_)); + // We get notified once that the token is invalid. assert_eq!(session_changes.try_recv(), Ok(SessionChange::UnknownToken { soft_logout: false })); assert_eq!(session_changes.try_recv(), Err(TryRecvError::Empty)); } @@ -700,10 +704,14 @@ async fn test_oauth_handle_refresh_tokens() { oauth_server.mock_token().ok().mock_once().named("refresh_token").mount().await; oidc.refresh_access_token().await.unwrap(); + + // The tokens were updated. assert_eq!(oidc.session_tokens(), Some(mock_session_tokens())); + // The save session callback was called once. assert_eq!(*num_save_session_callback_calls.lock().unwrap(), 1); + // We get notified once that the tokens were refreshed. assert_eq!(session_changes.try_recv(), Ok(SessionChange::TokensRefreshed)); assert_eq!(session_changes.try_recv(), Err(TryRecvError::Empty)); }