summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
authorOlivier Wilkinson (reivilibre) <oliverw@matrix.org>2023-02-21 14:47:40 +0000
committerOlivier Wilkinson (reivilibre) <oliverw@matrix.org>2023-02-21 14:47:40 +0000
commite0f9a514c61c5458aca122026a3bfab3ec4ccf05 (patch)
tree20d2597633a94c48bc71375590b0cf06afd68486 /tests/rest
parentUse changelog from release branch (diff)
parent1.78.0rc1 (diff)
downloadsynapse-e0f9a514c61c5458aca122026a3bfab3ec4ccf05.tar.xz
Merge branch 'release-v1.78' into matrix-org-hotfixes
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/admin/test_media.py9
-rw-r--r--tests/rest/admin/test_server_notice.py4
-rw-r--r--tests/rest/admin/test_user.py9
-rw-r--r--tests/rest/admin/test_username_available.py15
-rw-r--r--tests/rest/client/test_account.py2
-rw-r--r--tests/rest/client/test_auth.py17
-rw-r--r--tests/rest/client/test_filter.py4
-rw-r--r--tests/rest/client/test_presence.py10
-rw-r--r--tests/rest/client/test_register.py7
-rw-r--r--tests/rest/client/test_report_event.py12
-rw-r--r--tests/rest/client/test_retention.py6
-rw-r--r--tests/rest/client/test_rooms.py12
-rw-r--r--tests/rest/client/test_shadow_banned.py6
-rw-r--r--tests/rest/client/test_third_party_rules.py2
-rw-r--r--tests/rest/client/test_upgrade_room.py2
-rw-r--r--tests/rest/client/utils.py58
-rw-r--r--tests/rest/media/v1/test_media_storage.py49
17 files changed, 142 insertions, 82 deletions
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py

index aadb31ca83..db77a45ae3 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py
@@ -213,7 +213,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.admin_user_tok = self.login("admin", "pass") self.filepaths = MediaFilePaths(hs.config.media.media_store_path) - self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name + self.url = "/_synapse/admin/v1/media/delete" + self.legacy_url = "/_synapse/admin/v1/media/%s/delete" % self.server_name # Move clock up to somewhat realistic time self.reactor.advance(1000000000) @@ -332,11 +333,13 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - def test_delete_media_never_accessed(self) -> None: + @parameterized.expand([(True,), (False,)]) + def test_delete_media_never_accessed(self, use_legacy_url: bool) -> None: """ Tests that media deleted if it is older than `before_ts` and never accessed `last_access_ts` is `NULL` and `created_ts` < `before_ts` """ + url = self.legacy_url if use_legacy_url else self.url # upload and do not access server_and_media_id = self._create_media() @@ -351,7 +354,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): now_ms = self.clock.time_msec() channel = self.make_request( "POST", - self.url + "?before_ts=" + str(now_ms), + url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index a2f347f666..f71ff46d87 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py
@@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Sequence from twisted.test.proto_helpers import MemoryReactor @@ -558,7 +558,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): def _check_invite_and_join_status( self, user_id: str, expected_invites: int, expected_memberships: int - ) -> List[RoomsForUser]: + ) -> Sequence[RoomsForUser]: """Check invite and room membership status of a user. Args diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 5c1ced355f..f5b213219f 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -2913,7 +2913,8 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") event_builder_factory = self.hs.get_event_builder_factory() event_creation_handler = self.hs.get_event_creation_handler() - storage_controllers = self.hs.get_storage_controllers() + persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None # Create two rooms, one with a local user only and one with both a local # and remote user. @@ -2934,11 +2935,13 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_creation_handler.create_new_client_event(builder) ) - self.get_success(storage_controllers.persistence.persist_event(event, context)) + context = self.get_success(unpersisted_context.persist(event)) + + self.get_success(persistence.persist_event(event, context)) # Now get rooms url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index 30f12f1bff..6c04e6c56c 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py
@@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -33,9 +35,14 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - async def check_username(username: str) -> bool: - if username == "allowed": - return True + async def check_username( + localpart: str, + guest_access_token: Optional[str] = None, + assigned_user_id: Optional[str] = None, + inhibit_user_in_use_error: bool = False, + ) -> None: + if localpart == "allowed": + return raise SynapseError( 400, "User ID already taken.", @@ -43,7 +50,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): ) handler = self.hs.get_registration_handler() - handler.check_username = check_username + handler.check_username = check_username # type: ignore[assignment] def test_username_available(self) -> None: """ diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 88f255c9ee..e2ee1a1766 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py
@@ -1193,7 +1193,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): return {} # Register a mock that will return the expected result depending on the remote. - self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) + self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[assignment] # Check that we've got the correct response from the client-side endpoint. self._test_status( diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 208ec44829..a144610078 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py
@@ -34,7 +34,7 @@ from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER -from tests.server import FakeChannel, make_request +from tests.server import FakeChannel from tests.unittest import override_config, skip_unless @@ -43,6 +43,9 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker): super().__init__(hs) self.recaptcha_attempts: List[Tuple[dict, str]] = [] + def is_enabled(self) -> bool: + return True + def check_auth(self, authdict: dict, clientip: str) -> Any: self.recaptcha_attempts.append((authdict, clientip)) return succeed(True) @@ -1319,16 +1322,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase): channel = self.submit_logout_token(logout_token) self.assertEqual(channel.code, 200) - # Now try to exchange the login token - channel = make_request( - self.hs.get_reactor(), - self.site, - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, - ) - # It should have failed - self.assertEqual(channel.code, 403) + # Now try to exchange the login token, it should fail. + self.helper.login_via_token(login_token, 403) @override_config( { diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index afc8d641be..830762fd53 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py
@@ -63,14 +63,14 @@ class FilterTestCase(unittest.HomeserverTestCase): def test_add_filter_non_local_user(self) -> None: _is_mine = self.hs.is_mine - self.hs.is_mine = lambda target_user: False + self.hs.is_mine = lambda target_user: False # type: ignore[assignment] channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, ) - self.hs.is_mine = _is_mine + self.hs.is_mine = _is_mine # type: ignore[assignment] self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index b3738a0304..67e16880e6 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py
@@ -36,14 +36,14 @@ class PresenceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - presence_handler = Mock(spec=PresenceHandler) - presence_handler.set_state.return_value = make_awaitable(None) + self.presence_handler = Mock(spec=PresenceHandler) + self.presence_handler.set_state.return_value = make_awaitable(None) hs = self.setup_test_homeserver( "red", federation_http_client=None, federation_client=Mock(), - presence_handler=presence_handler, + presence_handler=self.presence_handler, ) return hs @@ -61,7 +61,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, HTTPStatus.OK) - self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) + self.assertEqual(self.presence_handler.set_state.call_count, 1) @unittest.override_config({"use_presence": False}) def test_put_presence_disabled(self) -> None: @@ -76,4 +76,4 @@ class PresenceTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, HTTPStatus.OK) - self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) + self.assertEqual(self.presence_handler.set_state.call_count, 0) diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 11cf3939d8..4c561f9525 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py
@@ -151,7 +151,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") def test_POST_guest_registration(self) -> None: - self.hs.config.key.macaroon_secret_key = "test" + self.hs.config.key.macaroon_secret_key = b"test" self.hs.config.registration.allow_guest_access = True channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") @@ -1166,12 +1166,15 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): """ user_id = self.register_user("kermit_delta", "user") - self.hs.config.account_validity.startup_job_max_delta = self.max_delta + self.hs.config.account_validity.account_validity_startup_job_max_delta = ( + self.max_delta + ) now_ms = self.hs.get_clock().time_msec() self.get_success(self.store._set_expiration_date_when_missing()) res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + assert res is not None self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) self.assertLessEqual(res, now_ms + self.validity_period) diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index 7cb1017a4a..1250685d39 100644 --- a/tests/rest/client/test_report_event.py +++ b/tests/rest/client/test_report_event.py
@@ -73,6 +73,18 @@ class ReportEventTestCase(unittest.HomeserverTestCase): data = {"reason": None, "score": None} self._assert_status(400, data) + def test_cannot_report_nonexistent_event(self) -> None: + """ + Tests that we don't accept event reports for events which do not exist. + """ + channel = self.make_request( + "POST", + f"rooms/{self.room_id}/report/$nonsenseeventid:test", + {"reason": "i am very sad"}, + access_token=self.other_user_tok, + ) + self.assertEqual(404, channel.code, msg=channel.result["body"]) + def _assert_status(self, response_status: int, data: JsonDict) -> None: channel = self.make_request( "POST", self.report_path, data, access_token=self.other_user_tok diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 9c8c1889d3..d3e06bf6b3 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py
@@ -136,6 +136,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Send a first event, which should be filtered out at the end of the test. resp = self.helper.send(room_id=room_id, body="1", tok=self.token) first_event_id = resp.get("event_id") + assert isinstance(first_event_id, str) # Advance the time by 2 days. We're using the default retention policy, therefore # after this the first event will still be valid. @@ -144,6 +145,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Send another event, which shouldn't get filtered out. resp = self.helper.send(room_id=room_id, body="2", tok=self.token) valid_event_id = resp.get("event_id") + assert isinstance(valid_event_id, str) # Advance the time by another 2 days. After this, the first event should be # outdated but not the second one. @@ -229,7 +231,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Check that we can still access state events that were sent before the event that # has been purged. - self.get_event(room_id, create_event.event_id) + self.get_event(room_id, bool(create_event)) def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict: event = self.get_success(self.store.get_event(event_id, allow_none=True)) @@ -238,7 +240,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.assertIsNone(event) return {} - self.assertIsNotNone(event) + assert event is not None time_now = self.clock.time_msec() serialized = self.serializer.serialize_event(event, time_now) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 9222cab198..cfad182b2f 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py
@@ -3382,8 +3382,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) - self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock - self.hs.get_identity_handler().lookup_3pid = Mock( + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] + self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] return_value=make_awaitable(None), ) @@ -3443,8 +3443,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) - self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock - self.hs.get_identity_handler().lookup_3pid = Mock( + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] + self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] return_value=make_awaitable(None), ) @@ -3563,8 +3563,10 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase): ) event.internal_metadata.outlier = True + persistence = self._storage_controllers.persistence + assert persistence is not None self.get_success( - self._storage_controllers.persistence.persist_event( + persistence.persist_event( event, EventContext.for_outlier(self._storage_controllers) ) ) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index c807a37bc2..8d2cdf8751 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py
@@ -84,7 +84,7 @@ class RoomTestCase(_ShadowBannedBase): def test_invite_3pid(self) -> None: """Ensure that a 3PID invite does not attempt to contact the identity server.""" identity_handler = self.hs.get_identity_handler() - identity_handler.lookup_3pid = Mock( + identity_handler.lookup_3pid = Mock( # type: ignore[assignment] side_effect=AssertionError("This should not get called") ) @@ -222,7 +222,7 @@ class RoomTestCase(_ShadowBannedBase): event_source.get_new_events( user=UserID.from_string(self.other_user_id), from_key=0, - limit=None, + limit=10, room_ids=[room_id], is_guest=False, ) @@ -286,6 +286,7 @@ class ProfileTestCase(_ShadowBannedBase): self.banned_user_id, ) ) + assert event is not None self.assertEqual( event.content, {"membership": "join", "displayname": original_display_name} ) @@ -321,6 +322,7 @@ class ProfileTestCase(_ShadowBannedBase): self.banned_user_id, ) ) + assert event is not None self.assertEqual( event.content, {"membership": "join", "displayname": original_display_name} ) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 3325d43a2f..5fa3440691 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py
@@ -425,7 +425,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): async def test_fn( event: EventBase, state_events: StateMap[EventBase] ) -> Tuple[bool, Optional[JsonDict]]: - if event.is_state and event.type == EventTypes.PowerLevels: + if event.is_state() and event.type == EventTypes.PowerLevels: await api.create_and_send_event_into_room( { "room_id": event.room_id, diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 5ec343dd7f..0b4c691318 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py
@@ -84,7 +84,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): self.room_id, EventTypes.Tombstone, "" ) ) - self.assertIsNotNone(tombstone_event) + assert tombstone_event is not None self.assertEqual(new_room_id, tombstone_event.content["replacement_room"]) # Check that the new room exists. diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 8d6f2b6ff9..9532e5ddc1 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py
@@ -36,6 +36,7 @@ from urllib.parse import urlencode import attr from typing_extensions import Literal +from twisted.test.proto_helpers import MemoryReactorClock from twisted.web.resource import Resource from twisted.web.server import Site @@ -67,6 +68,7 @@ class RestHelper: """ hs: HomeServer + reactor: MemoryReactorClock site: Site auth_user_id: Optional[str] @@ -142,7 +144,7 @@ class RestHelper: path = path + "?access_token=%s" % tok channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "POST", path, @@ -216,7 +218,7 @@ class RestHelper: data["reason"] = reason channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "POST", path, @@ -313,7 +315,7 @@ class RestHelper: data.update(extra_data or {}) channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "PUT", path, @@ -394,7 +396,7 @@ class RestHelper: path = path + "?access_token=%s" % tok channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "PUT", path, @@ -433,7 +435,7 @@ class RestHelper: path = path + f"?access_token={tok}" channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", path, @@ -488,7 +490,7 @@ class RestHelper: if body is not None: content = json.dumps(body).encode("utf8") - channel = make_request(self.hs.get_reactor(), self.site, method, path, content) + channel = make_request(self.reactor, self.site, method, path, content) assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, @@ -573,8 +575,8 @@ class RestHelper: image_length = len(image_data) path = "/_matrix/media/r0/upload?filename=%s" % (filename,) channel = make_request( - self.hs.get_reactor(), - FakeSite(resource, self.hs.get_reactor()), + self.reactor, + FakeSite(resource, self.reactor), "POST", path, content=image_data, @@ -603,7 +605,7 @@ class RestHelper: expect_code: The return code to expect from attempting the whoami request """ channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", "account/whoami", @@ -642,7 +644,7 @@ class RestHelper: ) -> Tuple[JsonDict, FakeAuthorizationGrant]: """Log in (as a new user) via OIDC - Returns the result of the final token login. + Returns the result of the final token login and the fake authorization grant. Requires that "oidc_config" in the homeserver config be set appropriately (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a @@ -672,10 +674,28 @@ class RestHelper: assert m, channel.text_body login_token = m.group(1) - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token and device id. + return self.login_via_token(login_token, expected_status), grant + + def login_via_token( + self, + login_token: str, + expected_status: int = 200, + ) -> JsonDict: + """Submit the matrix login token to the login API, which gives us our + matrix access token and device id.Log in (as a new user) via OIDC + + Returns the result of the token login. + + Requires that "oidc_config" in the homeserver config be set appropriately + (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a + "public_base_url". + + Also requires the login servlet and the OIDC callback resource to be mounted at + the normal places. + """ + channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "POST", "/login", @@ -684,7 +704,7 @@ class RestHelper: assert ( channel.code == expected_status ), f"unexpected status in response: {channel.code}" - return channel.json_body, grant + return channel.json_body def auth_via_oidc( self, @@ -805,7 +825,7 @@ class RestHelper: with fake_serer.patch_homeserver(hs=self.hs): # now hit the callback URI with the right params and a made-up code channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", callback_uri, @@ -849,7 +869,7 @@ class RestHelper: # is the easiest way of figuring out what the Host header ought to be set to # to keep Synapse happy. channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", uri, @@ -867,7 +887,7 @@ class RestHelper: location = get_location(channel) parts = urllib.parse.urlsplit(location) channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", urllib.parse.urlunsplit(("", "") + parts[2:]), @@ -900,9 +920,7 @@ class RestHelper: + urllib.parse.urlencode({"session": ui_auth_session_id}) ) # hit the redirect url (which will issue a cookie and state) - channel = make_request( - self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint - ) + channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint) # that should serve a confirmation page assert channel.code == HTTPStatus.OK, channel.text_body channel.extract_cookies(cookies) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index d18fc13c21..17a3b06a8e 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py
@@ -16,7 +16,7 @@ import shutil import tempfile from binascii import unhexlify from io import BytesIO -from typing import Any, BinaryIO, Dict, List, Optional, Union +from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union from unittest.mock import Mock from urllib import parse @@ -32,6 +32,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.events import EventBase from synapse.events.spamcheck import load_legacy_spam_checkers +from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable from synapse.module_api import ModuleApi from synapse.rest import admin @@ -41,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend from synapse.server import HomeServer -from synapse.types import RoomAlias +from synapse.types import JsonDict, RoomAlias from synapse.util import Clock from tests import unittest @@ -201,36 +202,46 @@ class _TestImage: ], ) class MediaRepoTests(unittest.HomeserverTestCase): - + test_image: ClassVar[_TestImage] hijack_auth = True user_id = "@test:user" def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.fetches = [] + self.fetches: List[ + Tuple[ + "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]", + str, + str, + Optional[QueryParams], + ] + ] = [] def get_file( destination: str, path: str, output_stream: BinaryIO, - args: Optional[Dict[str, Union[str, List[str]]]] = None, + args: Optional[QueryParams] = None, + retry_on_dns_fail: bool = True, max_size: Optional[int] = None, - ) -> Deferred: - """ - Returns tuple[int,dict,str,int] of file length, response headers, - absolute URI, and response code. - """ + ignore_backoff: bool = False, + ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": + """A mock for MatrixFederationHttpClient.get_file.""" - def write_to(r): + def write_to( + r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]] + ) -> Tuple[int, Dict[bytes, List[bytes]]]: data, response = r output_stream.write(data) return response - d = Deferred() - d.addCallback(write_to) + d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred() self.fetches.append((d, destination, path, args)) - return make_deferred_yieldable(d) + # Note that this callback changes the value held by d. + d_after_callback = d.addCallback(write_to) + return make_deferred_yieldable(d_after_callback) + # Mock out the homeserver's MatrixFederationHttpClient client = Mock() client.get_file = get_file @@ -461,6 +472,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): # Synapse should regenerate missing thumbnails. origin, media_id = self.media_id.split("/") info = self.get_success(self.store.get_cached_remote_media(origin, media_id)) + assert info is not None file_id = info["filesystem_id"] thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir( @@ -581,7 +593,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): "thumbnail_method": method, "thumbnail_type": self.test_image.content_type, "thumbnail_length": 256, - "filesystem_id": f"thumbnail1{self.test_image.extension}", + "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}", }, { "thumbnail_width": 32, @@ -589,10 +601,10 @@ class MediaRepoTests(unittest.HomeserverTestCase): "thumbnail_method": method, "thumbnail_type": self.test_image.content_type, "thumbnail_length": 256, - "filesystem_id": f"thumbnail2{self.test_image.extension}", + "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}", }, ], - file_id=f"image{self.test_image.extension}", + file_id=f"image{self.test_image.extension.decode()}", url_cache=None, server_name=None, ) @@ -637,6 +649,7 @@ class TestSpamCheckerLegacy: self.config = config self.api = api + @staticmethod def parse_config(config: Dict[str, Any]) -> Dict[str, Any]: return config @@ -748,7 +761,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): async def check_media_file_for_spam( self, file_wrapper: ReadableFileWrapper, file_info: FileInfo - ) -> Union[Codes, Literal["NOT_SPAM"]]: + ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]: buf = BytesIO() await file_wrapper.write_chunks_to(buf.write)