From b5c4fe3f7dfb5fccdd2f02645f1d800b36d463ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Commaille?= Date: Sun, 9 Mar 2025 12:13:26 +0100 Subject: [PATCH] test(sdk): Allow any MockEndpoint to override the expected access token MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kévin Commaille --- crates/matrix-sdk/src/test_utils/mocks/mod.rs | 157 +++++++++--------- .../tests/integration/refresh_token.rs | 4 +- 2 files changed, 84 insertions(+), 77 deletions(-) diff --git a/crates/matrix-sdk/src/test_utils/mocks/mod.rs b/crates/matrix-sdk/src/test_utils/mocks/mod.rs index d2109126c..6e71e9dad 100644 --- a/crates/matrix-sdk/src/test_utils/mocks/mod.rs +++ b/crates/matrix-sdk/src/test_utils/mocks/mod.rs @@ -288,13 +288,12 @@ impl MatrixMockServer { /// # anyhow::Ok(()) }); /// ``` pub fn mock_sync(&self) -> MockEndpoint<'_, SyncEndpoint> { - let mock = Mock::given(method("GET")) - .and(path("/_matrix/client/v3/sync")) - .and(header("authorization", "Bearer 1234")); + let mock = Mock::given(method("GET")).and(path("/_matrix/client/v3/sync")); self.mock_endpoint( mock, SyncEndpoint { sync_response_builder: self.sync_response_builder.clone() }, ) + .expect_default_access_token() } /// Creates a prebuilt mock for sending an event in a room. @@ -336,9 +335,8 @@ impl MatrixMockServer { /// ``` pub fn mock_room_send(&self) -> MockEndpoint<'_, RoomSendEndpoint> { let mock = Mock::given(method("PUT")) - .and(header("authorization", "Bearer 1234")) .and(path_regex(r"^/_matrix/client/v3/rooms/.*/send/.*".to_owned())); - self.mock_endpoint(mock, RoomSendEndpoint) + self.mock_endpoint(mock, RoomSendEndpoint).expect_default_access_token() } /// Creates a prebuilt mock for sending a state event in a room. @@ -385,10 +383,9 @@ impl MatrixMockServer { /// # anyhow::Ok(()) }); /// ``` pub fn mock_room_send_state(&self) -> MockEndpoint<'_, RoomSendStateEndpoint> { - let mock = Mock::given(method("PUT")) - .and(header("authorization", "Bearer 1234")) - .and(path_regex(r"^/_matrix/client/v3/rooms/.*/state/.*/.*")); - self.mock_endpoint(mock, RoomSendStateEndpoint::default()) + let mock = + Mock::given(method("PUT")).and(path_regex(r"^/_matrix/client/v3/rooms/.*/state/.*/.*")); + self.mock_endpoint(mock, RoomSendStateEndpoint::default()).expect_default_access_token() } /// Creates a prebuilt mock for asking whether *a* room is encrypted or not. @@ -418,9 +415,8 @@ impl MatrixMockServer { /// ``` pub fn mock_room_state_encryption(&self) -> MockEndpoint<'_, EncryptionStateEndpoint> { let mock = Mock::given(method("GET")) - .and(header("authorization", "Bearer 1234")) .and(path_regex(r"^/_matrix/client/v3/rooms/.*/state/m.*room.*encryption.?")); - self.mock_endpoint(mock, EncryptionStateEndpoint) + self.mock_endpoint(mock, EncryptionStateEndpoint).expect_default_access_token() } /// Creates a prebuilt mock for setting the room encryption state. @@ -458,9 +454,8 @@ impl MatrixMockServer { /// ``` pub fn mock_set_room_state_encryption(&self) -> MockEndpoint<'_, SetEncryptionStateEndpoint> { let mock = Mock::given(method("PUT")) - .and(header("authorization", "Bearer 1234")) .and(path_regex(r"^/_matrix/client/v3/rooms/.*/state/m.*room.*encryption.?")); - self.mock_endpoint(mock, SetEncryptionStateEndpoint) + self.mock_endpoint(mock, SetEncryptionStateEndpoint).expect_default_access_token() } /// Creates a prebuilt mock for the room redact endpoint. @@ -491,32 +486,29 @@ impl MatrixMockServer { /// ``` pub fn mock_room_redact(&self) -> MockEndpoint<'_, RoomRedactEndpoint> { let mock = Mock::given(method("PUT")) - .and(path_regex(r"^/_matrix/client/v3/rooms/.*/redact/.*?/.*?")) - .and(header("authorization", "Bearer 1234")); - self.mock_endpoint(mock, RoomRedactEndpoint) + .and(path_regex(r"^/_matrix/client/v3/rooms/.*/redact/.*?/.*?")); + self.mock_endpoint(mock, RoomRedactEndpoint).expect_default_access_token() } /// Creates a prebuilt mock for retrieving an event with /room/.../event. pub fn mock_room_event(&self) -> MockEndpoint<'_, RoomEventEndpoint> { - let mock = Mock::given(method("GET")).and(header("authorization", "Bearer 1234")); + let mock = Mock::given(method("GET")); self.mock_endpoint(mock, RoomEventEndpoint { room: None, match_event_id: false }) + .expect_default_access_token() } /// Create a prebuild mock for paginating room message with the `/messages` /// endpoint. pub fn mock_room_messages(&self) -> MockEndpoint<'_, RoomMessagesEndpoint> { - let mock = Mock::given(method("GET")) - .and(path_regex(r"^/_matrix/client/v3/rooms/.*/messages$")) - .and(header("authorization", "Bearer 1234")); - self.mock_endpoint(mock, RoomMessagesEndpoint) + let mock = + Mock::given(method("GET")).and(path_regex(r"^/_matrix/client/v3/rooms/.*/messages$")); + self.mock_endpoint(mock, RoomMessagesEndpoint).expect_default_access_token() } /// Create a prebuilt mock for uploading media. pub fn mock_upload(&self) -> MockEndpoint<'_, UploadEndpoint> { - let mock = Mock::given(method("POST")) - .and(path("/_matrix/media/v3/upload")) - .and(header("authorization", "Bearer 1234")); - self.mock_endpoint(mock, UploadEndpoint) + let mock = Mock::given(method("POST")).and(path("/_matrix/media/v3/upload")); + self.mock_endpoint(mock, UploadEndpoint).expect_default_access_token() } /// Create a prebuilt mock for resolving room aliases. @@ -768,26 +760,23 @@ impl MatrixMockServer { /// # } /// ``` pub fn mock_room_keys_version(&self) -> MockEndpoint<'_, RoomKeysVersionEndpoint> { - let mock = Mock::given(method("GET")) - .and(path_regex(r"_matrix/client/v3/room_keys/version")) - .and(header("authorization", "Bearer 1234")); - self.mock_endpoint(mock, RoomKeysVersionEndpoint) + let mock = + Mock::given(method("GET")).and(path_regex(r"_matrix/client/v3/room_keys/version")); + self.mock_endpoint(mock, RoomKeysVersionEndpoint).expect_default_access_token() } /// Create a prebuilt mock for adding key storage backups via POST pub fn mock_add_room_keys_version(&self) -> MockEndpoint<'_, AddRoomKeysVersionEndpoint> { - let mock = Mock::given(method("POST")) - .and(path_regex(r"_matrix/client/v3/room_keys/version")) - .and(header("authorization", "Bearer 1234")); - self.mock_endpoint(mock, AddRoomKeysVersionEndpoint) + let mock = + Mock::given(method("POST")).and(path_regex(r"_matrix/client/v3/room_keys/version")); + self.mock_endpoint(mock, AddRoomKeysVersionEndpoint).expect_default_access_token() } /// Create a prebuilt mock for adding key storage backups via POST pub fn mock_delete_room_keys_version(&self) -> MockEndpoint<'_, DeleteRoomKeysVersionEndpoint> { let mock = Mock::given(method("DELETE")) - .and(path_regex(r"_matrix/client/v3/room_keys/version/[^/]*")) - .and(header("authorization", "Bearer 1234")); - self.mock_endpoint(mock, DeleteRoomKeysVersionEndpoint) + .and(path_regex(r"_matrix/client/v3/room_keys/version/[^/]*")); + self.mock_endpoint(mock, DeleteRoomKeysVersionEndpoint).expect_default_access_token() } /// Create a prebuilt mock for getting the room members in a room. @@ -945,9 +934,8 @@ impl MatrixMockServer { /// events. pub fn mock_set_room_pinned_events(&self) -> MockEndpoint<'_, SetRoomPinnedEventsEndpoint> { let mock = Mock::given(method("PUT")) - .and(path_regex(r"^/_matrix/client/v3/rooms/.*/state/m.room.pinned_events/.*?")) - .and(header("authorization", "Bearer 1234")); - self.mock_endpoint(mock, SetRoomPinnedEventsEndpoint) + .and(path_regex(r"^/_matrix/client/v3/rooms/.*/state/m.room.pinned_events/.*?")); + self.mock_endpoint(mock, SetRoomPinnedEventsEndpoint).expect_default_access_token() } /// Creates a prebuilt mock for the endpoint used to get information about @@ -958,25 +946,21 @@ impl MatrixMockServer { pub fn mock_who_am_i(&self) -> MockEndpoint<'_, WhoAmIEndpoint> { let mock = Mock::given(method("GET")).and(path_regex(r"^/_matrix/client/v3/account/whoami")); - self.mock_endpoint(mock, WhoAmIEndpoint { expected_access_token: "1234" }) + self.mock_endpoint(mock, WhoAmIEndpoint).expect_default_access_token() } /// Creates a prebuilt mock for the endpoint used to publish end-to-end /// encryption keys. pub fn mock_upload_keys(&self) -> MockEndpoint<'_, UploadKeysEndpoint> { - let mock = Mock::given(method("POST")) - .and(path_regex(r"^/_matrix/client/v3/keys/upload")) - .and(header("authorization", "Bearer 1234")); - self.mock_endpoint(mock, UploadKeysEndpoint) + let mock = Mock::given(method("POST")).and(path_regex(r"^/_matrix/client/v3/keys/upload")); + self.mock_endpoint(mock, UploadKeysEndpoint).expect_default_access_token() } /// Creates a prebuilt mock for the endpoint used to query end-to-end /// encryption keys. pub fn mock_query_keys(&self) -> MockEndpoint<'_, QueryKeysEndpoint> { - let mock = Mock::given(method("POST")) - .and(path_regex(r"^/_matrix/client/v3/keys/query")) - .and(header("authorization", "Bearer 1234")); - self.mock_endpoint(mock, QueryKeysEndpoint) + let mock = Mock::given(method("POST")).and(path_regex(r"^/_matrix/client/v3/keys/query")); + self.mock_endpoint(mock, QueryKeysEndpoint).expect_default_access_token() } /// Creates a prebuilt mock for the endpoint used to discover the URL of a @@ -992,9 +976,8 @@ impl MatrixMockServer { &self, ) -> MockEndpoint<'_, UploadCrossSigningKeysEndpoint> { let mock = Mock::given(method("POST")) - .and(path_regex(r"^/_matrix/client/v3/keys/device_signing/upload")) - .and(header("authorization", "Bearer 1234")); - self.mock_endpoint(mock, UploadCrossSigningKeysEndpoint) + .and(path_regex(r"^/_matrix/client/v3/keys/device_signing/upload")); + self.mock_endpoint(mock, UploadCrossSigningKeysEndpoint).expect_default_access_token() } /// Creates a prebuilt mock for the endpoint used to publish cross-signing @@ -1003,9 +986,8 @@ impl MatrixMockServer { &self, ) -> MockEndpoint<'_, UploadCrossSigningSignaturesEndpoint> { let mock = Mock::given(method("POST")) - .and(path_regex(r"^/_matrix/client/v3/keys/signatures/upload")) - .and(header("authorization", "Bearer 1234")); - self.mock_endpoint(mock, UploadCrossSigningSignaturesEndpoint) + .and(path_regex(r"^/_matrix/client/v3/keys/signatures/upload")); + self.mock_endpoint(mock, UploadCrossSigningSignaturesEndpoint).expect_default_access_token() } } @@ -1162,11 +1144,24 @@ pub struct MockEndpoint<'a, T> { server: &'a MockServer, mock: MockBuilder, endpoint: T, + expected_access_token: ExpectedAccessToken, } impl<'a, T> MockEndpoint<'a, T> { fn new(server: &'a MockServer, mock: MockBuilder, endpoint: T) -> Self { - Self { server, mock, endpoint } + Self { server, mock, endpoint, expected_access_token: ExpectedAccessToken::None } + } + + /// Expect authentication with the default access token on this endpoint. + pub fn expect_default_access_token(mut self) -> Self { + self.expected_access_token = ExpectedAccessToken::Default; + self + } + + /// Expect authentication with the given access token on this endpoint. + pub fn expect_access_token(mut self, access_token: &'static str) -> Self { + self.expected_access_token = ExpectedAccessToken::Custom(access_token); + self } /// Specify how to respond to a query (viz., like @@ -1211,7 +1206,11 @@ impl<'a, T> MockEndpoint<'a, T> { /// # anyhow::Ok(()) }); /// ``` pub fn respond_with(self, func: R) -> MatrixMock<'a> { - MatrixMock { mock: self.mock.respond_with(func), server: self.server } + let mock = self + .expected_access_token + .maybe_match_authorization_header(self.mock) + .respond_with(func); + MatrixMock { mock, server: self.server } } /// Returns a send endpoint that emulates a transient failure, i.e responds @@ -1292,6 +1291,30 @@ impl<'a, T> MockEndpoint<'a, T> { } } +/// The access token to expect on an endpoint. +enum ExpectedAccessToken { + /// We don't expect an access token. + None, + + /// We expect the default access token. + Default, + + /// We expect the given access token. + Custom(&'static str), +} + +impl ExpectedAccessToken { + /// Match an `Authorization` header on the given mock if one is expected. + fn maybe_match_authorization_header(&self, mock: MockBuilder) -> MockBuilder { + let token = match self { + Self::None => return mock, + Self::Default => "1234", + Self::Custom(token) => token, + }; + mock.and(header(http::header::AUTHORIZATION, format!("Bearer {token}"))) + } +} + /// A prebuilt mock for sending a message like event in a room. pub struct RoomSendEndpoint; @@ -2377,26 +2400,11 @@ impl<'a> MockEndpoint<'a, SetRoomPinnedEventsEndpoint> { } /// A prebuilt mock for `GET /account/whoami` request. -pub struct WhoAmIEndpoint { - expected_access_token: &'static str, -} - -impl WhoAmIEndpoint { - fn add_access_token_matcher(&self, mock: MockBuilder) -> MockBuilder { - mock.and(header("authorization", format!("Bearer {}", self.expected_access_token))) - } -} +pub struct WhoAmIEndpoint; impl<'a> MockEndpoint<'a, WhoAmIEndpoint> { - /// Override the access token to expect for this endpoint. - pub fn expected_access_token(mut self, access_token: &'static str) -> Self { - self.endpoint.expected_access_token = access_token; - self - } - /// Returns a successful response with a user ID and device ID. - pub fn ok(mut self) -> MatrixMock<'a> { - self.mock = self.endpoint.add_access_token_matcher(self.mock); + pub fn ok(self) -> MatrixMock<'a> { self.respond_with(ResponseTemplate::new(200).set_body_json(json!({ "user_id": "@joe:example.org", "device_id": "D3V1C31D", @@ -2404,8 +2412,7 @@ impl<'a> MockEndpoint<'a, WhoAmIEndpoint> { } /// Returns an error response with an `M_UNKNOWN_TOKEN`. - pub fn err_unknown_token(mut self) -> MatrixMock<'a> { - self.mock = self.endpoint.add_access_token_matcher(self.mock); + pub fn err_unknown_token(self) -> MatrixMock<'a> { self.respond_with(ResponseTemplate::new(401).set_body_json(json!({ "errcode": "M_UNKNOWN_TOKEN", "error": "Invalid token" diff --git a/crates/matrix-sdk/tests/integration/refresh_token.rs b/crates/matrix-sdk/tests/integration/refresh_token.rs index 1972c0b55..58c251287 100644 --- a/crates/matrix-sdk/tests/integration/refresh_token.rs +++ b/crates/matrix-sdk/tests/integration/refresh_token.rs @@ -567,7 +567,7 @@ async fn test_oauth_refresh_token_handled_success() { // Return an error first so the token is refreshed. server .mock_who_am_i() - .expected_access_token("prev-access-token") + .expect_access_token("prev-access-token") .err_unknown_token() .expect(1) .named("whoami_unknown_token") @@ -619,7 +619,7 @@ async fn test_oauth_refresh_token_handled_failure() { // Return an error first so the token is refreshed. server .mock_who_am_i() - .expected_access_token("prev-access-token") + .expect_access_token("prev-access-token") .err_unknown_token() .expect(1) .named("whoami_unknown_token")