summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/admin/test_device.py3
-rw-r--r--tests/rest/admin/test_event_reports.py143
-rw-r--r--tests/rest/admin/test_media.py16
-rw-r--r--tests/rest/admin/test_room.py1
-rw-r--r--tests/rest/admin/test_server_notice.py5
-rw-r--r--tests/rest/admin/test_user.py11
-rw-r--r--tests/rest/admin/test_username_available.py15
-rw-r--r--tests/rest/client/test_account.py6
-rw-r--r--tests/rest/client/test_auth.py19
-rw-r--r--tests/rest/client/test_capabilities.py1
-rw-r--r--tests/rest/client/test_consent.py1
-rw-r--r--tests/rest/client/test_directory.py1
-rw-r--r--tests/rest/client/test_ephemeral_message.py1
-rw-r--r--tests/rest/client/test_events.py3
-rw-r--r--tests/rest/client/test_filter.py5
-rw-r--r--tests/rest/client/test_keys.py141
-rw-r--r--tests/rest/client/test_login.py2
-rw-r--r--tests/rest/client/test_login_token_request.py1
-rw-r--r--tests/rest/client/test_presence.py11
-rw-r--r--tests/rest/client/test_profile.py3
-rw-r--r--tests/rest/client/test_register.py11
-rw-r--r--tests/rest/client/test_relations.py237
-rw-r--r--tests/rest/client/test_rendezvous.py1
-rw-r--r--tests/rest/client/test_report_event.py12
-rw-r--r--tests/rest/client/test_retention.py6
-rw-r--r--tests/rest/client/test_rooms.py30
-rw-r--r--tests/rest/client/test_shadow_banned.py6
-rw-r--r--tests/rest/client/test_sync.py3
-rw-r--r--tests/rest/client/test_third_party_rules.py126
-rw-r--r--tests/rest/client/test_transactions.py55
-rw-r--r--tests/rest/client/test_upgrade_room.py2
-rw-r--r--tests/rest/client/utils.py58
-rw-r--r--tests/rest/media/test_media_retention.py1
-rw-r--r--tests/rest/media/test_url_preview.py (renamed from tests/rest/media/v1/test_url_preview.py)53
-rw-r--r--tests/rest/media/v1/__init__.py13
-rw-r--r--tests/rest/media/v1/test_base.py38
-rw-r--r--tests/rest/media/v1/test_filepath.py595
-rw-r--r--tests/rest/media/v1/test_html_preview.py542
-rw-r--r--tests/rest/media/v1/test_media_storage.py782
-rw-r--r--tests/rest/media/v1/test_oembed.py162
40 files changed, 647 insertions, 2475 deletions
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py

index 03f2112b07..aaa488bced 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py
@@ -28,7 +28,6 @@ from tests import unittest class DeviceRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -291,7 +290,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): class DevicesRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -415,7 +413,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 233eba3516..f189b07769 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py
@@ -78,7 +78,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ Try to get an event report without authentication. """ - channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, {}) self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -473,7 +473,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): """ Try to get event report without authentication. """ - channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, {}) self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -599,3 +599,142 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): self.assertIn("room_id", content["event_json"]) self.assertIn("sender", content["event_json"]) self.assertIn("content", content["event_json"]) + + +class DeleteEventReportTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self._store = hs.get_datastores().main + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + # create report + event_id = self.get_success( + self._store.add_event_report( + "room_id", + "event_id", + self.other_user, + "this makes me sad", + {}, + self.clock.time_msec(), + ) + ) + + self.url = f"/_synapse/admin/v1/event_reports/{event_id}" + + def test_no_auth(self) -> None: + """ + Try to delete event report without authentication. + """ + channel = self.make_request("DELETE", self.url) + + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self) -> None: + """ + If the user is not a server admin, an error 403 is returned. + """ + + channel = self.make_request( + "DELETE", + self.url, + access_token=self.other_user_tok, + ) + + self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_delete_success(self) -> None: + """ + Testing delete a report. + """ + + channel = self.make_request( + "DELETE", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual({}, channel.json_body) + + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + + # check that report was deleted + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_invalid_report_id(self) -> None: + """ + Testing that an invalid `report_id` returns a 400. + """ + + # `report_id` is negative + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/event_reports/-123", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is a non-numerical string + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/event_reports/abcdef", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is undefined + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/event_reports/", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + def test_report_id_not_found(self) -> None: + """ + Testing that a not existing `report_id` returns a 404. + """ + + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/event_reports/123", + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual("Event report not found", channel.json_body["error"]) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index aadb31ca83..6d04911d67 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py
@@ -20,8 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes +from synapse.media.filepath import MediaFilePaths from synapse.rest.client import login, profile, room -from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.server import HomeServer from synapse.util import Clock @@ -34,7 +34,6 @@ INVALID_TIMESTAMP_IN_S = 1893456000 # 2030-01-01 in seconds class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -196,7 +195,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -213,7 +211,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.admin_user_tok = self.login("admin", "pass") self.filepaths = MediaFilePaths(hs.config.media.media_store_path) - self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name + self.url = "/_synapse/admin/v1/media/delete" + self.legacy_url = "/_synapse/admin/v1/media/%s/delete" % self.server_name # Move clock up to somewhat realistic time self.reactor.advance(1000000000) @@ -332,11 +331,13 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - def test_delete_media_never_accessed(self) -> None: + @parameterized.expand([(True,), (False,)]) + def test_delete_media_never_accessed(self, use_legacy_url: bool) -> None: """ Tests that media deleted if it is older than `before_ts` and never accessed `last_access_ts` is `NULL` and `created_ts` < `before_ts` """ + url = self.legacy_url if use_legacy_url else self.url # upload and do not access server_and_media_id = self._create_media() @@ -351,7 +352,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): now_ms = self.clock.time_msec() channel = self.make_request( "POST", - self.url + "?before_ts=" + str(now_ms), + url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -591,7 +592,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -721,7 +721,6 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -818,7 +817,6 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 453a6e979c..9dbb778679 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py
@@ -1990,7 +1990,6 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): class JoinAliasRoomTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, room.register_servlets, diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index a2f347f666..28b999573e 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py
@@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Sequence from twisted.test.proto_helpers import MemoryReactor @@ -28,7 +28,6 @@ from tests.unittest import override_config class ServerNoticeTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -558,7 +557,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): def _check_invite_and_join_status( self, user_id: str, expected_invites: int, expected_memberships: int - ) -> List[RoomsForUser]: + ) -> Sequence[RoomsForUser]: """Check invite and room membership status of a user. Args diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 5c1ced355f..4b8f889a71 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -28,8 +28,8 @@ import synapse.rest.admin from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions +from synapse.media.filepath import MediaFilePaths from synapse.rest.client import devices, login, logout, profile, register, room, sync -from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.server import HomeServer from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock @@ -2913,7 +2913,8 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") event_builder_factory = self.hs.get_event_builder_factory() event_creation_handler = self.hs.get_event_creation_handler() - storage_controllers = self.hs.get_storage_controllers() + persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None # Create two rooms, one with a local user only and one with both a local # and remote user. @@ -2934,11 +2935,13 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_creation_handler.create_new_client_event(builder) ) - self.get_success(storage_controllers.persistence.persist_event(event, context)) + context = self.get_success(unpersisted_context.persist(event)) + + self.get_success(persistence.persist_event(event, context)) # Now get rooms url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index 30f12f1bff..6c04e6c56c 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py
@@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -33,9 +35,14 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - async def check_username(username: str) -> bool: - if username == "allowed": - return True + async def check_username( + localpart: str, + guest_access_token: Optional[str] = None, + assigned_user_id: Optional[str] = None, + inhibit_user_in_use_error: bool = False, + ) -> None: + if localpart == "allowed": + return raise SynapseError( 400, "User ID already taken.", @@ -43,7 +50,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): ) handler = self.hs.get_registration_handler() - handler.check_username = check_username + handler.check_username = check_username # type: ignore[assignment] def test_username_available(self) -> None: """ diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 88f255c9ee..2b05dffc7d 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py
@@ -40,7 +40,6 @@ from tests.unittest import override_config class PasswordResetTestCase(unittest.HomeserverTestCase): - servlets = [ account.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -408,7 +407,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): class DeactivateTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -492,7 +490,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase): class WhoamiTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -567,7 +564,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase): class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): - servlets = [ account.register_servlets, login.register_servlets, @@ -1193,7 +1189,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): return {} # Register a mock that will return the expected result depending on the remote. - self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) + self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[assignment] # Check that we've got the correct response from the client-side endpoint. self._test_status( diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 208ec44829..0d8fe77b88 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py
@@ -34,7 +34,7 @@ from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER -from tests.server import FakeChannel, make_request +from tests.server import FakeChannel from tests.unittest import override_config, skip_unless @@ -43,13 +43,15 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker): super().__init__(hs) self.recaptcha_attempts: List[Tuple[dict, str]] = [] + def is_enabled(self) -> bool: + return True + def check_auth(self, authdict: dict, clientip: str) -> Any: self.recaptcha_attempts.append((authdict, clientip)) return succeed(True) class FallbackAuthTests(unittest.HomeserverTestCase): - servlets = [ auth.register_servlets, register.register_servlets, @@ -57,7 +59,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["enable_registration_captcha"] = True @@ -1319,16 +1320,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase): channel = self.submit_logout_token(logout_token) self.assertEqual(channel.code, 200) - # Now try to exchange the login token - channel = make_request( - self.hs.get_reactor(), - self.site, - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, - ) - # It should have failed - self.assertEqual(channel.code, 403) + # Now try to exchange the login token, it should fail. + self.helper.login_via_token(login_token, 403) @override_config( { diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index d1751e1557..c16e8d43f4 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py
@@ -26,7 +26,6 @@ from tests.unittest import override_config class CapabilitiesTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, capabilities.register_servlets, diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index b1ca81a911..bb845179d3 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py
@@ -38,7 +38,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["form_secret"] = "123abc" diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index 7a88aa2cda..6490e883bf 100644 --- a/tests/rest/client/test_directory.py +++ b/tests/rest/client/test_directory.py
@@ -28,7 +28,6 @@ from tests.unittest import override_config class DirectoryTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, directory.register_servlets, diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index 9fa1f82dfe..f31ebc8021 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py
@@ -26,7 +26,6 @@ from tests import unittest class EphemeralMessageTestCase(unittest.HomeserverTestCase): - user_id = "@user:test" servlets = [ diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
index a9b7db9db2..54df2a252c 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py
@@ -38,7 +38,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["enable_registration_captcha"] = False config["enable_registration"] = True @@ -51,7 +50,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # register an account self.user_id = self.register_user("sid1", "pass") self.token = self.login(self.user_id, "pass") @@ -142,7 +140,6 @@ class GetEventsTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # register an account self.user_id = self.register_user("sid1", "pass") self.token = self.login(self.user_id, "pass") diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index afc8d641be..91678abf13 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py
@@ -25,7 +25,6 @@ PATH_PREFIX = "/_matrix/client/v2_alpha" class FilterTestCase(unittest.HomeserverTestCase): - user_id = "@apple:test" hijack_auth = True EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} @@ -63,14 +62,14 @@ class FilterTestCase(unittest.HomeserverTestCase): def test_add_filter_non_local_user(self) -> None: _is_mine = self.hs.is_mine - self.hs.is_mine = lambda target_user: False + self.hs.is_mine = lambda target_user: False # type: ignore[assignment] channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, ) - self.hs.is_mine = _is_mine + self.hs.is_mine = _is_mine # type: ignore[assignment] self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index 741fecea77..8ee5489057 100644 --- a/tests/rest/client/test_keys.py +++ b/tests/rest/client/test_keys.py
@@ -14,12 +14,21 @@ from http import HTTPStatus +from signedjson.key import ( + encode_verify_key_base64, + generate_signing_key, + get_verify_key, +) +from signedjson.sign import sign_json + from synapse.api.errors import Codes from synapse.rest import admin from synapse.rest.client import keys, login +from synapse.types import JsonDict from tests import unittest from tests.http.server._base import make_request_with_cancellation_test +from tests.unittest import override_config class KeyQueryTestCase(unittest.HomeserverTestCase): @@ -118,3 +127,135 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertIn(bob, channel.json_body["device_keys"]) + + def make_device_keys(self, user_id: str, device_id: str) -> JsonDict: + # We only generate a master key to simplify the test. + master_signing_key = generate_signing_key(device_id) + master_verify_key = encode_verify_key_base64(get_verify_key(master_signing_key)) + + return { + "master_key": sign_json( + { + "user_id": user_id, + "usage": ["master"], + "keys": {"ed25519:" + master_verify_key: master_verify_key}, + }, + user_id, + master_signing_key, + ), + } + + def test_device_signing_with_uia(self) -> None: + """Device signing key upload requires UIA.""" + password = "wonderland" + device_id = "ABCDEFGHI" + alice_id = self.register_user("alice", password) + alice_token = self.login("alice", password, device_id=device_id) + + content = self.make_device_keys(alice_id, device_id) + + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + content, + alice_token, + ) + + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # add UI auth + content["auth"] = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": alice_id}, + "password": password, + "session": session, + } + + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + content, + alice_token, + ) + + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + @override_config({"ui_auth": {"session_timeout": "15m"}}) + def test_device_signing_with_uia_session_timeout(self) -> None: + """Device signing key upload requires UIA buy passes with grace period.""" + password = "wonderland" + device_id = "ABCDEFGHI" + alice_id = self.register_user("alice", password) + alice_token = self.login("alice", password, device_id=device_id) + + content = self.make_device_keys(alice_id, device_id) + + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + content, + alice_token, + ) + + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + @override_config( + { + "experimental_features": {"msc3967_enabled": True}, + "ui_auth": {"session_timeout": "15s"}, + } + ) + def test_device_signing_with_msc3967(self) -> None: + """Device signing key follows MSC3967 behaviour when enabled.""" + password = "wonderland" + device_id = "ABCDEFGHI" + alice_id = self.register_user("alice", password) + alice_token = self.login("alice", password, device_id=device_id) + + keys1 = self.make_device_keys(alice_id, device_id) + + # Initial request should succeed as no existing keys are present. + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + keys1, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + keys2 = self.make_device_keys(alice_id, device_id) + + # Subsequent request should require UIA as keys already exist even though session_timeout is set. + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + keys2, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # add UI auth + keys2["auth"] = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": alice_id}, + "password": password, + "session": session, + } + + # Request should complete + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + keys2, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index ff5baa9f0a..62acf4f44e 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py
@@ -89,7 +89,6 @@ ADDITIONAL_LOGIN_FLOWS = [ class LoginRestServletTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -737,7 +736,6 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): class CASTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, ] diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index 6aedc1a11c..b8187db982 100644 --- a/tests/rest/client/test_login_token_request.py +++ b/tests/rest/client/test_login_token_request.py
@@ -26,7 +26,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token" class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, admin.register_servlets, diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index b3738a0304..dcbb125a3b 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py
@@ -35,15 +35,14 @@ class PresenceTestCase(unittest.HomeserverTestCase): servlets = [presence.register_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - - presence_handler = Mock(spec=PresenceHandler) - presence_handler.set_state.return_value = make_awaitable(None) + self.presence_handler = Mock(spec=PresenceHandler) + self.presence_handler.set_state.return_value = make_awaitable(None) hs = self.setup_test_homeserver( "red", federation_http_client=None, federation_client=Mock(), - presence_handler=presence_handler, + presence_handler=self.presence_handler, ) return hs @@ -61,7 +60,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, HTTPStatus.OK) - self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) + self.assertEqual(self.presence_handler.set_state.call_count, 1) @unittest.override_config({"use_presence": False}) def test_put_presence_disabled(self) -> None: @@ -76,4 +75,4 @@ class PresenceTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, HTTPStatus.OK) - self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) + self.assertEqual(self.presence_handler.set_state.call_count, 0) diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 8de5a342ae..27c93ad761 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py
@@ -30,7 +30,6 @@ from tests import unittest class ProfileTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -324,7 +323,6 @@ class ProfileTestCase(unittest.HomeserverTestCase): class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -404,7 +402,6 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, login.register_servlets, diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 11cf3939d8..b228dba861 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py
@@ -40,7 +40,6 @@ from tests.unittest import override_config class RegisterRestServletTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, register.register_servlets, @@ -151,7 +150,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") def test_POST_guest_registration(self) -> None: - self.hs.config.key.macaroon_secret_key = "test" + self.hs.config.key.macaroon_secret_key = b"test" self.hs.config.registration.allow_guest_access = True channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") @@ -797,7 +796,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): class AccountValidityTestCase(unittest.HomeserverTestCase): - servlets = [ register.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -913,7 +911,6 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): - servlets = [ register.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -1132,7 +1129,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): - servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: @@ -1166,12 +1162,15 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): """ user_id = self.register_user("kermit_delta", "user") - self.hs.config.account_validity.startup_job_max_delta = self.max_delta + self.hs.config.account_validity.account_validity_startup_job_max_delta = ( + self.max_delta + ) now_ms = self.hs.get_clock().time_msec() self.get_success(self.store._set_expiration_date_when_missing()) res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + assert res is not None self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) self.assertLessEqual(res, now_ms + self.validity_period) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index c8a6911d5e..fbbbcb23f1 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py
@@ -30,7 +30,6 @@ from tests import unittest from tests.server import FakeChannel from tests.test_utils import make_awaitable from tests.test_utils.event_injection import inject_event -from tests.unittest import override_config class BaseRelationsTestCase(unittest.HomeserverTestCase): @@ -403,7 +402,7 @@ class RelationsTestCase(BaseRelationsTestCase): def test_edit(self) -> None: """Test that a simple edit works.""" - + orig_body = {"body": "Hi!", "msgtype": "m.text"} new_body = {"msgtype": "m.text", "body": "I've been edited!"} edit_event_content = { "msgtype": "m.text", @@ -424,9 +423,7 @@ class RelationsTestCase(BaseRelationsTestCase): access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual( - channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"} - ) + self.assertEqual(channel.json_body["content"], orig_body) self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content) # Request the room messages. @@ -443,7 +440,7 @@ class RelationsTestCase(BaseRelationsTestCase): ) # Request the room context. - # /context should return the edited event. + # /context should return the event. channel = self.make_request( "GET", f"/rooms/{self.room}/context/{self.parent_id}", @@ -453,7 +450,7 @@ class RelationsTestCase(BaseRelationsTestCase): self._assert_edit_bundle( channel.json_body["event"], edit_event_id, edit_event_content ) - self.assertEqual(channel.json_body["event"]["content"], new_body) + self.assertEqual(channel.json_body["event"]["content"], orig_body) # Request sync, but limit the timeline so it becomes limited (and includes # bundled aggregations). @@ -491,45 +488,11 @@ class RelationsTestCase(BaseRelationsTestCase): edit_event_content, ) - @override_config({"experimental_features": {"msc3925_inhibit_edit": True}}) - def test_edit_inhibit_replace(self) -> None: - """ - If msc3925_inhibit_edit is enabled, then the original event should not be - replaced. - """ - - new_body = {"msgtype": "m.text", "body": "I've been edited!"} - edit_event_content = { - "msgtype": "m.text", - "body": "foo", - "m.new_content": new_body, - } - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content=edit_event_content, - ) - edit_event_id = channel.json_body["event_id"] - - # /context should return the *original* event. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/context/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual( - channel.json_body["event"]["content"], {"body": "Hi!", "msgtype": "m.text"} - ) - self._assert_edit_bundle( - channel.json_body["event"], edit_event_id, edit_event_content - ) - def test_multi_edit(self) -> None: """Test that multiple edits, including attempts by people who shouldn't be allowed, are correctly handled. """ - + orig_body = orig_body = {"body": "Hi!", "msgtype": "m.text"} self._send_relation( RelationTypes.REPLACE, "m.room.message", @@ -570,7 +533,7 @@ class RelationsTestCase(BaseRelationsTestCase): ) self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(channel.json_body["event"]["content"], new_body) + self.assertEqual(channel.json_body["event"]["content"], orig_body) self._assert_edit_bundle( channel.json_body["event"], edit_event_id, edit_event_content ) @@ -642,6 +605,7 @@ class RelationsTestCase(BaseRelationsTestCase): def test_edit_edit(self) -> None: """Test that an edit cannot be edited.""" + orig_body = {"body": "Hi!", "msgtype": "m.text"} new_body = {"msgtype": "m.text", "body": "Initial edit"} edit_event_content = { "msgtype": "m.text", @@ -675,14 +639,12 @@ class RelationsTestCase(BaseRelationsTestCase): access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual( - channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"} - ) + self.assertEqual(channel.json_body["content"], orig_body) # The relations information should not include the edit to the edit. self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content) - # /context should return the event updated for the *first* edit + # /context should return the bundled edit for the *first* edit # (The edit to the edit should be ignored.) channel = self.make_request( "GET", @@ -690,7 +652,7 @@ class RelationsTestCase(BaseRelationsTestCase): access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(channel.json_body["event"]["content"], new_body) + self.assertEqual(channel.json_body["event"]["content"], orig_body) self._assert_edit_bundle( channel.json_body["event"], edit_event_id, edit_event_content ) @@ -1080,48 +1042,6 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): ] assert_bundle(self._find_event_in_chunk(chunk)) - def test_annotation(self) -> None: - """ - Test that annotations get correctly bundled. - """ - # Setup by sending a variety of relations. - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - - def assert_annotations(bundled_aggregations: JsonDict) -> None: - self.assertEqual( - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - bundled_aggregations, - ) - - self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7) - - def test_annotation_to_annotation(self) -> None: - """Any relation to an annotation should be ignored.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - event_id = channel.json_body["event_id"] - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=event_id - ) - - # Fetch the initial annotation event to see if it has bundled aggregations. - channel = self.make_request( - "GET", - f"/_matrix/client/v3/rooms/{self.room}/event/{event_id}", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - # The first annotationt should not have any bundled aggregations. - self.assertNotIn("m.relations", channel.json_body["unsigned"]) - def test_reference(self) -> None: """ Test that references get correctly bundled. @@ -1138,7 +1058,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations, ) - self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7) + self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6) def test_thread(self) -> None: """ @@ -1183,7 +1103,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # The "user" sent the root event and is making queries for the bundled # aggregations: they have participated. - self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7) + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 6) # The "user2" sent replies in the thread and is making queries for the # bundled aggregations: they have participated. # @@ -1208,9 +1128,10 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self._send_relation(RelationTypes.THREAD, "m.room.test") thread_2 = channel.json_body["event_id"] - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2 + channel = self._send_relation( + RelationTypes.REFERENCE, "org.matrix.test", parent_id=thread_2 ) + reference_event_id = channel.json_body["event_id"] def assert_thread(bundled_aggregations: JsonDict) -> None: self.assertEqual(2, bundled_aggregations.get("count")) @@ -1235,17 +1156,15 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): self.assert_dict( { "m.relations": { - RelationTypes.ANNOTATION: { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 1}, - ] + RelationTypes.REFERENCE: { + "chunk": [{"event_id": reference_event_id}] }, } }, bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 6) def test_nested_thread(self) -> None: """ @@ -1330,7 +1249,6 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): thread_summary = relations_dict[RelationTypes.THREAD] self.assertIn("latest_event", thread_summary) latest_event_in_thread = thread_summary["latest_event"] - self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!") # The latest event in the thread should have the edit appear under the # bundled aggregations. self.assertDictContainsSubset( @@ -1363,10 +1281,11 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self._send_relation(RelationTypes.THREAD, "m.room.test") thread_id = channel.json_body["event_id"] - # Annotate the thread. - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id + # Make a reference to the thread. + channel = self._send_relation( + RelationTypes.REFERENCE, "org.matrix.test", parent_id=thread_id ) + reference_event_id = channel.json_body["event_id"] channel = self.make_request( "GET", @@ -1377,9 +1296,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): self.assertEqual( channel.json_body["unsigned"].get("m.relations"), { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, + RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]}, }, ) @@ -1396,9 +1313,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): self.assertEqual( thread_message["unsigned"].get("m.relations"), { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, + RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]}, }, ) @@ -1410,7 +1325,8 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): Note that the spec allows for a server to return additional fields beyond what is specified. """ - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + channel = self._send_relation(RelationTypes.REFERENCE, "org.matrix.test") + reference_event_id = channel.json_body["event_id"] # Note that the sync filter does not include "unsigned" as a field. filter = urllib.parse.quote_plus( @@ -1428,7 +1344,12 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # Ensure there's bundled aggregations on it. self.assertIn("unsigned", parent_event) - self.assertIn("m.relations", parent_event["unsigned"]) + self.assertEqual( + parent_event["unsigned"].get("m.relations"), + { + RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]}, + }, + ) class RelationIgnoredUserTestCase(BaseRelationsTestCase): @@ -1475,53 +1396,8 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase): return before_aggregations[relation_type], after_aggregations[relation_type] - def test_annotation(self) -> None: - """Annotations should ignore""" - # Send 2 from us, 2 from the to be ignored user. - allowed_event_ids = [] - ignored_event_ids = [] - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - allowed_event_ids.append(channel.json_body["event_id"]) - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="b") - allowed_event_ids.append(channel.json_body["event_id"]) - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="a", - access_token=self.user2_token, - ) - ignored_event_ids.append(channel.json_body["event_id"]) - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="c", - access_token=self.user2_token, - ) - ignored_event_ids.append(channel.json_body["event_id"]) - - before_aggregations, after_aggregations = self._test_ignored_user( - RelationTypes.ANNOTATION, allowed_event_ids, ignored_event_ids - ) - - self.assertCountEqual( - before_aggregations["chunk"], - [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - {"type": "m.reaction", "key": "c", "count": 1}, - ], - ) - - self.assertCountEqual( - after_aggregations["chunk"], - [ - {"type": "m.reaction", "key": "a", "count": 1}, - {"type": "m.reaction", "key": "b", "count": 1}, - ], - ) - def test_reference(self) -> None: - """Annotations should ignore""" + """Aggregations should exclude reference relations from ignored users""" channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") allowed_event_ids = [channel.json_body["event_id"]] @@ -1544,7 +1420,7 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase): ) def test_thread(self) -> None: - """Annotations should ignore""" + """Aggregations should exclude thread releations from ignored users""" channel = self._send_relation(RelationTypes.THREAD, "m.room.test") allowed_event_ids = [channel.json_body["event_id"]] @@ -1618,43 +1494,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): for t in threads ] - def test_redact_relation_annotation(self) -> None: - """ - Test that annotations of an event are properly handled after the - annotation is redacted. - - The redacted relation should not be included in bundled aggregations or - the response to relations. - """ - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - to_redact_event_id = channel.json_body["event_id"] - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - unredacted_event_id = channel.json_body["event_id"] - - # Both relations should exist. - event_ids = self._get_related_events() - relations = self._get_bundled_aggregations() - self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id]) - self.assertEquals( - relations["m.annotation"], - {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]}, - ) - - # Redact one of the reactions. - self._redact(to_redact_event_id) - - # The unredacted relation should still exist. - event_ids = self._get_related_events() - relations = self._get_bundled_aggregations() - self.assertEquals(event_ids, [unredacted_event_id]) - self.assertEquals( - relations["m.annotation"], - {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, - ) - def test_redact_relation_thread(self) -> None: """ Test that thread replies are properly handled after the thread reply redacted. @@ -1775,14 +1614,14 @@ class RelationRedactionTestCase(BaseRelationsTestCase): is redacted. """ # Add a relation - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") + channel = self._send_relation(RelationTypes.REFERENCE, "org.matrix.test") related_event_id = channel.json_body["event_id"] # The relations should exist. event_ids = self._get_related_events() relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 1) - self.assertIn(RelationTypes.ANNOTATION, relations) + self.assertIn(RelationTypes.REFERENCE, relations) # Redact the original event. self._redact(self.parent_id) @@ -1792,8 +1631,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [related_event_id]) self.assertEquals( - relations["m.annotation"], - {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, + relations[RelationTypes.REFERENCE], + {"chunk": [{"event_id": related_event_id}]}, ) def test_redact_parent_thread(self) -> None: diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py
index c0eb5d01a6..8dbd64be55 100644 --- a/tests/rest/client/test_rendezvous.py +++ b/tests/rest/client/test_rendezvous.py
@@ -25,7 +25,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous" class RendezvousServletTestCase(unittest.HomeserverTestCase): - servlets = [ rendezvous.register_servlets, ] diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index 7cb1017a4a..1250685d39 100644 --- a/tests/rest/client/test_report_event.py +++ b/tests/rest/client/test_report_event.py
@@ -73,6 +73,18 @@ class ReportEventTestCase(unittest.HomeserverTestCase): data = {"reason": None, "score": None} self._assert_status(400, data) + def test_cannot_report_nonexistent_event(self) -> None: + """ + Tests that we don't accept event reports for events which do not exist. + """ + channel = self.make_request( + "POST", + f"rooms/{self.room_id}/report/$nonsenseeventid:test", + {"reason": "i am very sad"}, + access_token=self.other_user_tok, + ) + self.assertEqual(404, channel.code, msg=channel.result["body"]) + def _assert_status(self, response_status: int, data: JsonDict) -> None: channel = self.make_request( "POST", self.report_path, data, access_token=self.other_user_tok diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 9c8c1889d3..d3e06bf6b3 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py
@@ -136,6 +136,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Send a first event, which should be filtered out at the end of the test. resp = self.helper.send(room_id=room_id, body="1", tok=self.token) first_event_id = resp.get("event_id") + assert isinstance(first_event_id, str) # Advance the time by 2 days. We're using the default retention policy, therefore # after this the first event will still be valid. @@ -144,6 +145,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Send another event, which shouldn't get filtered out. resp = self.helper.send(room_id=room_id, body="2", tok=self.token) valid_event_id = resp.get("event_id") + assert isinstance(valid_event_id, str) # Advance the time by another 2 days. After this, the first event should be # outdated but not the second one. @@ -229,7 +231,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Check that we can still access state events that were sent before the event that # has been purged. - self.get_event(room_id, create_event.event_id) + self.get_event(room_id, bool(create_event)) def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict: event = self.get_success(self.store.get_event(event_id, allow_none=True)) @@ -238,7 +240,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.assertIsNone(event) return {} - self.assertIsNotNone(event) + assert event is not None time_now = self.clock.time_msec() serialized = self.serializer.serialize_event(event, time_now) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 9222cab198..a4900703c4 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py
@@ -65,7 +65,6 @@ class RoomBase(unittest.HomeserverTestCase): servlets = [room.register_servlets, room.register_deprecated_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.hs = self.setup_test_homeserver( "red", federation_http_client=None, @@ -92,7 +91,6 @@ class RoomPermissionsTestCase(RoomBase): rmcreator_id = "@notme:red" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.helper.auth_user_id = self.rmcreator_id # create some rooms under the name rmcreator_id self.uncreated_rmid = "!aa:test" @@ -715,7 +713,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(33, channel.resource_usage.db_txn_count) + self.assertEqual(30, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -728,7 +726,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(36, channel.resource_usage.db_txn_count) + self.assertEqual(32, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id @@ -1127,7 +1125,6 @@ class RoomInviteRatelimitTestCase(RoomBase): class RoomJoinTestCase(RoomBase): - servlets = [ admin.register_servlets, login.register_servlets, @@ -2102,7 +2099,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): hijack_auth = False def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # Register the user who does the searching self.user_id2 = self.register_user("user", "pass") self.access_token = self.login("user", "pass") @@ -2195,7 +2191,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -2203,7 +2198,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.url = b"/_matrix/client/r0/publicRooms" config = self.default_config() @@ -2225,7 +2219,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -2233,7 +2226,6 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["allow_public_rooms_without_auth"] = True self.hs = self.setup_test_homeserver(config=config) @@ -2414,7 +2406,6 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -2983,7 +2974,6 @@ class RelationsTestCase(PaginationTestCase): class ContextTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -3359,7 +3349,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): class ThreepidInviteTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets, login.register_servlets, @@ -3382,8 +3371,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) - self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock - self.hs.get_identity_handler().lookup_3pid = Mock( + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] + self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] return_value=make_awaitable(None), ) @@ -3438,13 +3427,14 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): """ Test allowing/blocking threepid invites with a spam-check module. - In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.""" + In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`. + """ # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) - self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock - self.hs.get_identity_handler().lookup_3pid = Mock( + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] + self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] return_value=make_awaitable(None), ) @@ -3563,8 +3553,10 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase): ) event.internal_metadata.outlier = True + persistence = self._storage_controllers.persistence + assert persistence is not None self.get_success( - self._storage_controllers.persistence.persist_event( + persistence.persist_event( event, EventContext.for_outlier(self._storage_controllers) ) ) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index c807a37bc2..8d2cdf8751 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py
@@ -84,7 +84,7 @@ class RoomTestCase(_ShadowBannedBase): def test_invite_3pid(self) -> None: """Ensure that a 3PID invite does not attempt to contact the identity server.""" identity_handler = self.hs.get_identity_handler() - identity_handler.lookup_3pid = Mock( + identity_handler.lookup_3pid = Mock( # type: ignore[assignment] side_effect=AssertionError("This should not get called") ) @@ -222,7 +222,7 @@ class RoomTestCase(_ShadowBannedBase): event_source.get_new_events( user=UserID.from_string(self.other_user_id), from_key=0, - limit=None, + limit=10, room_ids=[room_id], is_guest=False, ) @@ -286,6 +286,7 @@ class ProfileTestCase(_ShadowBannedBase): self.banned_user_id, ) ) + assert event is not None self.assertEqual( event.content, {"membership": "join", "displayname": original_display_name} ) @@ -321,6 +322,7 @@ class ProfileTestCase(_ShadowBannedBase): self.banned_user_id, ) ) + assert event is not None self.assertEqual( event.content, {"membership": "join", "displayname": original_display_name} ) diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index b9047194dd..9c876c7a32 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py
@@ -41,7 +41,6 @@ from tests.server import TimedOutException class FilterTestCase(unittest.HomeserverTestCase): - user_id = "@apple:test" servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -191,7 +190,6 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): class SyncTypingTests(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -892,7 +890,6 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase): class ExcludeRoomTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 3325d43a2f..3b99513707 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py
@@ -137,6 +137,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): """Tests that a forbidden event is forbidden from being sent, but an allowed one can be sent. """ + # patch the rules module with a Mock which will return False for some event # types async def check( @@ -243,6 +244,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): def test_modify_event(self) -> None: """The module can return a modified version of the event""" + # first patch the event checker so that it will modify the event async def check( ev: EventBase, state: StateMap[EventBase] @@ -275,6 +277,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): def test_message_edit(self) -> None: """Ensure that the module doesn't cause issues with edited messages.""" + # first patch the event checker so that it will modify the event async def check( ev: EventBase, state: StateMap[EventBase] @@ -425,7 +428,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): async def test_fn( event: EventBase, state_events: StateMap[EventBase] ) -> Tuple[bool, Optional[JsonDict]]: - if event.is_state and event.type == EventTypes.PowerLevels: + if event.is_state() and event.type == EventTypes.PowerLevels: await api.create_and_send_event_into_room( { "room_id": event.room_id, @@ -931,3 +934,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # Check that the mock was called with the right parameters self.assertEqual(args, (user_id, "email", "foo@example.com")) + + def test_on_add_and_remove_user_third_party_identifier(self) -> None: + """Tests that the on_add_user_third_party_identifier and + on_remove_user_third_party_identifier module callbacks are called + just before associating and removing a 3PID to/from an account. + """ + # Pretend to be a Synapse module and register both callbacks as mocks. + third_party_rules = self.hs.get_third_party_event_rules() + on_add_user_third_party_identifier_callback_mock = Mock( + return_value=make_awaitable(None) + ) + on_remove_user_third_party_identifier_callback_mock = Mock( + return_value=make_awaitable(None) + ) + third_party_rules._on_threepid_bind_callbacks.append( + on_add_user_third_party_identifier_callback_mock + ) + third_party_rules._on_threepid_bind_callbacks.append( + on_remove_user_third_party_identifier_callback_mock + ) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Also register a normal user we can modify. + user_id = self.register_user("user", "password") + + # Add a 3PID to the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + { + "threepids": [ + { + "medium": "email", + "address": "foo@example.com", + }, + ], + }, + access_token=admin_tok, + ) + + # Check that the mocked add callback was called with the appropriate + # 3PID details. + self.assertEqual(channel.code, 200, channel.json_body) + on_add_user_third_party_identifier_callback_mock.assert_called_once() + args = on_add_user_third_party_identifier_callback_mock.call_args[0] + self.assertEqual(args, (user_id, "email", "foo@example.com")) + + # Now remove the 3PID from the user + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + { + "threepids": [], + }, + access_token=admin_tok, + ) + + # Check that the mocked remove callback was called with the appropriate + # 3PID details. + self.assertEqual(channel.code, 200, channel.json_body) + on_remove_user_third_party_identifier_callback_mock.assert_called_once() + args = on_remove_user_third_party_identifier_callback_mock.call_args[0] + self.assertEqual(args, (user_id, "email", "foo@example.com")) + + def test_on_remove_user_third_party_identifier_is_called_on_deactivate( + self, + ) -> None: + """Tests that the on_remove_user_third_party_identifier module callback is called + when a user is deactivated and their third-party ID associations are deleted. + """ + # Pretend to be a Synapse module and register both callbacks as mocks. + third_party_rules = self.hs.get_third_party_event_rules() + on_remove_user_third_party_identifier_callback_mock = Mock( + return_value=make_awaitable(None) + ) + third_party_rules._on_threepid_bind_callbacks.append( + on_remove_user_third_party_identifier_callback_mock + ) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Also register a normal user we can modify. + user_id = self.register_user("user", "password") + + # Add a 3PID to the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + { + "threepids": [ + { + "medium": "email", + "address": "foo@example.com", + }, + ], + }, + access_token=admin_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Now deactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + { + "deactivated": True, + }, + access_token=admin_tok, + ) + + # Check that the mocked remove callback was called with the appropriate + # 3PID details. + self.assertEqual(channel.code, 200, channel.json_body) + on_remove_user_third_party_identifier_callback_mock.assert_called_once() + args = on_remove_user_third_party_identifier_callback_mock.call_args[0] + self.assertEqual(args, (user_id, "email", "foo@example.com")) diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 3086e1b565..d8dc56261a 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py
@@ -39,15 +39,23 @@ class HttpTransactionCacheTestCase(unittest.TestCase): self.cache = HttpTransactionCache(self.hs) self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"}) - self.mock_key = "foo" + + # Here we make sure that we're setting all the fields that HttpTransactionCache + # uses to build the transaction key. + self.mock_request = Mock() + self.mock_request.path = b"/foo/bar" + self.mock_requester = Mock() + self.mock_requester.app_service = None + self.mock_requester.is_guest = False + self.mock_requester.access_token_id = 1234 @defer.inlineCallbacks def test_executes_given_function( self, ) -> Generator["defer.Deferred[Any]", object, None]: cb = Mock(return_value=make_awaitable(self.mock_http_response)) - res = yield self.cache.fetch_or_execute( - self.mock_key, cb, "some_arg", keyword="arg" + res = yield self.cache.fetch_or_execute_request( + self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg" ) cb.assert_called_once_with("some_arg", keyword="arg") self.assertEqual(res, self.mock_http_response) @@ -58,8 +66,13 @@ class HttpTransactionCacheTestCase(unittest.TestCase): ) -> Generator["defer.Deferred[Any]", object, None]: cb = Mock(return_value=make_awaitable(self.mock_http_response)) for i in range(3): # invoke multiple times - res = yield self.cache.fetch_or_execute( - self.mock_key, cb, "some_arg", keyword="arg", changing_args=i + res = yield self.cache.fetch_or_execute_request( + self.mock_request, + self.mock_requester, + cb, + "some_arg", + keyword="arg", + changing_args=i, ) self.assertEqual(res, self.mock_http_response) # expect only a single call to do the work @@ -77,7 +90,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase): @defer.inlineCallbacks def test() -> Generator["defer.Deferred[Any]", object, None]: with LoggingContext("c") as c1: - res = yield self.cache.fetch_or_execute(self.mock_key, cb) + res = yield self.cache.fetch_or_execute_request( + self.mock_request, self.mock_requester, cb + ) self.assertIs(current_context(), c1) self.assertEqual(res, (1, {})) @@ -106,12 +121,16 @@ class HttpTransactionCacheTestCase(unittest.TestCase): with LoggingContext("test") as test_context: try: - yield self.cache.fetch_or_execute(self.mock_key, cb) + yield self.cache.fetch_or_execute_request( + self.mock_request, self.mock_requester, cb + ) except Exception as e: self.assertEqual(e.args[0], "boo") self.assertIs(current_context(), test_context) - res = yield self.cache.fetch_or_execute(self.mock_key, cb) + res = yield self.cache.fetch_or_execute_request( + self.mock_request, self.mock_requester, cb + ) self.assertEqual(res, self.mock_http_response) self.assertIs(current_context(), test_context) @@ -134,29 +153,39 @@ class HttpTransactionCacheTestCase(unittest.TestCase): with LoggingContext("test") as test_context: try: - yield self.cache.fetch_or_execute(self.mock_key, cb) + yield self.cache.fetch_or_execute_request( + self.mock_request, self.mock_requester, cb + ) except Exception as e: self.assertEqual(e.args[0], "boo") self.assertIs(current_context(), test_context) - res = yield self.cache.fetch_or_execute(self.mock_key, cb) + res = yield self.cache.fetch_or_execute_request( + self.mock_request, self.mock_requester, cb + ) self.assertEqual(res, self.mock_http_response) self.assertIs(current_context(), test_context) @defer.inlineCallbacks def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]: cb = Mock(return_value=make_awaitable(self.mock_http_response)) - yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") + yield self.cache.fetch_or_execute_request( + self.mock_request, self.mock_requester, cb, "an arg" + ) # should NOT have cleaned up yet self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2) - yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") + yield self.cache.fetch_or_execute_request( + self.mock_request, self.mock_requester, cb, "an arg" + ) # still using cache cb.assert_called_once_with("an arg") self.clock.advance_time_msec(CLEANUP_PERIOD_MS) - yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") + yield self.cache.fetch_or_execute_request( + self.mock_request, self.mock_requester, cb, "an arg" + ) # no longer using cache self.assertEqual(cb.call_count, 2) self.assertEqual(cb.call_args_list, [call("an arg"), call("an arg")]) diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 5ec343dd7f..0b4c691318 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py
@@ -84,7 +84,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): self.room_id, EventTypes.Tombstone, "" ) ) - self.assertIsNotNone(tombstone_event) + assert tombstone_event is not None self.assertEqual(new_room_id, tombstone_event.content["replacement_room"]) # Check that the new room exists. diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 8d6f2b6ff9..9532e5ddc1 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py
@@ -36,6 +36,7 @@ from urllib.parse import urlencode import attr from typing_extensions import Literal +from twisted.test.proto_helpers import MemoryReactorClock from twisted.web.resource import Resource from twisted.web.server import Site @@ -67,6 +68,7 @@ class RestHelper: """ hs: HomeServer + reactor: MemoryReactorClock site: Site auth_user_id: Optional[str] @@ -142,7 +144,7 @@ class RestHelper: path = path + "?access_token=%s" % tok channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "POST", path, @@ -216,7 +218,7 @@ class RestHelper: data["reason"] = reason channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "POST", path, @@ -313,7 +315,7 @@ class RestHelper: data.update(extra_data or {}) channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "PUT", path, @@ -394,7 +396,7 @@ class RestHelper: path = path + "?access_token=%s" % tok channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "PUT", path, @@ -433,7 +435,7 @@ class RestHelper: path = path + f"?access_token={tok}" channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", path, @@ -488,7 +490,7 @@ class RestHelper: if body is not None: content = json.dumps(body).encode("utf8") - channel = make_request(self.hs.get_reactor(), self.site, method, path, content) + channel = make_request(self.reactor, self.site, method, path, content) assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, @@ -573,8 +575,8 @@ class RestHelper: image_length = len(image_data) path = "/_matrix/media/r0/upload?filename=%s" % (filename,) channel = make_request( - self.hs.get_reactor(), - FakeSite(resource, self.hs.get_reactor()), + self.reactor, + FakeSite(resource, self.reactor), "POST", path, content=image_data, @@ -603,7 +605,7 @@ class RestHelper: expect_code: The return code to expect from attempting the whoami request """ channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", "account/whoami", @@ -642,7 +644,7 @@ class RestHelper: ) -> Tuple[JsonDict, FakeAuthorizationGrant]: """Log in (as a new user) via OIDC - Returns the result of the final token login. + Returns the result of the final token login and the fake authorization grant. Requires that "oidc_config" in the homeserver config be set appropriately (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a @@ -672,10 +674,28 @@ class RestHelper: assert m, channel.text_body login_token = m.group(1) - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token and device id. + return self.login_via_token(login_token, expected_status), grant + + def login_via_token( + self, + login_token: str, + expected_status: int = 200, + ) -> JsonDict: + """Submit the matrix login token to the login API, which gives us our + matrix access token and device id.Log in (as a new user) via OIDC + + Returns the result of the token login. + + Requires that "oidc_config" in the homeserver config be set appropriately + (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a + "public_base_url". + + Also requires the login servlet and the OIDC callback resource to be mounted at + the normal places. + """ + channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "POST", "/login", @@ -684,7 +704,7 @@ class RestHelper: assert ( channel.code == expected_status ), f"unexpected status in response: {channel.code}" - return channel.json_body, grant + return channel.json_body def auth_via_oidc( self, @@ -805,7 +825,7 @@ class RestHelper: with fake_serer.patch_homeserver(hs=self.hs): # now hit the callback URI with the right params and a made-up code channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", callback_uri, @@ -849,7 +869,7 @@ class RestHelper: # is the easiest way of figuring out what the Host header ought to be set to # to keep Synapse happy. channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", uri, @@ -867,7 +887,7 @@ class RestHelper: location = get_location(channel) parts = urllib.parse.urlsplit(location) channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", urllib.parse.urlunsplit(("", "") + parts[2:]), @@ -900,9 +920,7 @@ class RestHelper: + urllib.parse.urlencode({"session": ui_auth_session_id}) ) # hit the redirect url (which will issue a cookie and state) - channel = make_request( - self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint - ) + channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint) # that should serve a confirmation page assert channel.code == HTTPStatus.OK, channel.text_body channel.extract_cookies(cookies) diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py
index 23f227aed6..b59d9dfd4d 100644 --- a/tests/rest/media/test_media_retention.py +++ b/tests/rest/media/test_media_retention.py
@@ -31,7 +31,6 @@ from tests.utils import MockClock class MediaRetentionTestCase(unittest.HomeserverTestCase): - ONE_DAY_IN_MS = 24 * 60 * 60 * 1000 THIRTY_DAYS_IN_MS = 30 * ONE_DAY_IN_MS diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/test_url_preview.py
index 2c321f8d04..e91dc581c2 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/test_url_preview.py
@@ -26,8 +26,8 @@ from twisted.internet.interfaces import IAddress, IResolutionReceiver from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor from synapse.config.oembed import OEmbedEndpointConfig -from synapse.rest.media.v1.media_repository import MediaRepositoryResource -from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS +from synapse.rest.media.media_repository_resource import MediaRepositoryResource +from synapse.rest.media.preview_url_resource import IMAGE_CACHE_EXPIRY_MS from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -58,7 +58,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["url_preview_enabled"] = True config["max_spider_size"] = 9999999 @@ -83,7 +82,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): config["media_store_path"] = self.media_store_path provider_config = { - "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend", + "module": "synapse.media.storage_provider.FileStorageProviderBackend", "store_local": True, "store_synchronous": False, "store_remote": True, @@ -118,7 +117,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.media_repo = hs.get_media_repository_resource() self.preview_url = self.media_repo.children[b"preview_url"] @@ -133,7 +131,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): addressTypes: Optional[Sequence[Type[IAddress]]] = None, transportSemantics: str = "TCP", ) -> IResolutionReceiver: - resolution = HostResolution(hostName) resolutionReceiver.resolutionBegan(resolution) if hostName not in self.lookups: @@ -660,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """If the preview image doesn't exist, ensure some data is returned.""" self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - end_content = ( + result = ( b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>""" ) @@ -681,8 +678,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" b'Content-Type: text/html; charset="utf8"\r\n\r\n' ) - % (len(end_content),) - + end_content + % (len(result),) + + result ) self.pump() @@ -691,6 +688,44 @@ class URLPreviewTests(unittest.HomeserverTestCase): # The image should not be in the result. self.assertNotIn("og:image", channel.json_body) + def test_oembed_failure(self) -> None: + """If the autodiscovered oEmbed URL fails, ensure some data is returned.""" + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + + result = b""" + <title>oEmbed Autodiscovery Fail</title> + <link rel="alternate" type="application/json+oembed" + href="http://example.com/oembed?url=http%3A%2F%2Fmatrix.org&format=json" + title="matrixdotorg" /> + """ + + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(result),) + + result + ) + + self.pump() + self.assertEqual(channel.code, 200) + + # The image should not be in the result. + self.assertEqual(channel.json_body["og:title"], "oEmbed Autodiscovery Fail") + def test_data_url(self) -> None: """ Requesting to preview a data URL is not supported. diff --git a/tests/rest/media/v1/__init__.py b/tests/rest/media/v1/__init__.py deleted file mode 100644
index b1ee10cfcc..0000000000 --- a/tests/rest/media/v1/__init__.py +++ /dev/null
@@ -1,13 +0,0 @@ -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py deleted file mode 100644
index c73179151a..0000000000 --- a/tests/rest/media/v1/test_base.py +++ /dev/null
@@ -1,38 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.rest.media.v1._base import get_filename_from_headers - -from tests import unittest - - -class GetFileNameFromHeadersTests(unittest.TestCase): - # input -> expected result - TEST_CASES = { - b"inline; filename=abc.txt": "abc.txt", - b'inline; filename="azerty"': "azerty", - b'inline; filename="aze%20rty"': "aze%20rty", - b'inline; filename="aze"rty"': 'aze"rty', - b'inline; filename="azer;ty"': "azer;ty", - b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar", - } - - def tests(self) -> None: - for hdr, expected in self.TEST_CASES.items(): - res = get_filename_from_headers({b"Content-Disposition": [hdr]}) - self.assertEqual( - res, - expected, - f"expected output for {hdr!r} to be {expected} but was {res}", - ) diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py deleted file mode 100644
index 43e6f0f70a..0000000000 --- a/tests/rest/media/v1/test_filepath.py +++ /dev/null
@@ -1,595 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import inspect -import os -from typing import Iterable - -from synapse.rest.media.v1.filepath import MediaFilePaths, _wrap_with_jail_check - -from tests import unittest - - -class MediaFilePathsTestCase(unittest.TestCase): - def setUp(self) -> None: - super().setUp() - - self.filepaths = MediaFilePaths("/media_store") - - def test_local_media_filepath(self) -> None: - """Test local media paths""" - self.assertEqual( - self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"), - "local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg", - ) - self.assertEqual( - self.filepaths.local_media_filepath("GerZNDnDZVjsOtardLuwfIBg"), - "/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg", - ) - - def test_local_media_thumbnail(self) -> None: - """Test local media thumbnail paths""" - self.assertEqual( - self.filepaths.local_media_thumbnail_rel( - "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale" - ), - "local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", - ) - self.assertEqual( - self.filepaths.local_media_thumbnail( - "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale" - ), - "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", - ) - - def test_local_media_thumbnail_dir(self) -> None: - """Test local media thumbnail directory paths""" - self.assertEqual( - self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"), - "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", - ) - - def test_remote_media_filepath(self) -> None: - """Test remote media paths""" - self.assertEqual( - self.filepaths.remote_media_filepath_rel( - "example.com", "GerZNDnDZVjsOtardLuwfIBg" - ), - "remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", - ) - self.assertEqual( - self.filepaths.remote_media_filepath( - "example.com", "GerZNDnDZVjsOtardLuwfIBg" - ), - "/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", - ) - - def test_remote_media_thumbnail(self) -> None: - """Test remote media thumbnail paths""" - self.assertEqual( - self.filepaths.remote_media_thumbnail_rel( - "example.com", - "GerZNDnDZVjsOtardLuwfIBg", - 800, - 600, - "image/jpeg", - "scale", - ), - "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", - ) - self.assertEqual( - self.filepaths.remote_media_thumbnail( - "example.com", - "GerZNDnDZVjsOtardLuwfIBg", - 800, - 600, - "image/jpeg", - "scale", - ), - "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", - ) - - def test_remote_media_thumbnail_legacy(self) -> None: - """Test old-style remote media thumbnail paths""" - self.assertEqual( - self.filepaths.remote_media_thumbnail_rel_legacy( - "example.com", "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg" - ), - "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg", - ) - - def test_remote_media_thumbnail_dir(self) -> None: - """Test remote media thumbnail directory paths""" - self.assertEqual( - self.filepaths.remote_media_thumbnail_dir( - "example.com", "GerZNDnDZVjsOtardLuwfIBg" - ), - "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", - ) - - def test_url_cache_filepath(self) -> None: - """Test URL cache paths""" - self.assertEqual( - self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"), - "url_cache/2020-01-02/GerZNDnDZVjsOtar", - ) - self.assertEqual( - self.filepaths.url_cache_filepath("2020-01-02_GerZNDnDZVjsOtar"), - "/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar", - ) - - def test_url_cache_filepath_legacy(self) -> None: - """Test old-style URL cache paths""" - self.assertEqual( - self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"), - "url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg", - ) - self.assertEqual( - self.filepaths.url_cache_filepath("GerZNDnDZVjsOtardLuwfIBg"), - "/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg", - ) - - def test_url_cache_filepath_dirs_to_delete(self) -> None: - """Test URL cache cleanup paths""" - self.assertEqual( - self.filepaths.url_cache_filepath_dirs_to_delete( - "2020-01-02_GerZNDnDZVjsOtar" - ), - ["/media_store/url_cache/2020-01-02"], - ) - - def test_url_cache_filepath_dirs_to_delete_legacy(self) -> None: - """Test old-style URL cache cleanup paths""" - self.assertEqual( - self.filepaths.url_cache_filepath_dirs_to_delete( - "GerZNDnDZVjsOtardLuwfIBg" - ), - [ - "/media_store/url_cache/Ge/rZ", - "/media_store/url_cache/Ge", - ], - ) - - def test_url_cache_thumbnail(self) -> None: - """Test URL cache thumbnail paths""" - self.assertEqual( - self.filepaths.url_cache_thumbnail_rel( - "2020-01-02_GerZNDnDZVjsOtar", 800, 600, "image/jpeg", "scale" - ), - "url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale", - ) - self.assertEqual( - self.filepaths.url_cache_thumbnail( - "2020-01-02_GerZNDnDZVjsOtar", 800, 600, "image/jpeg", "scale" - ), - "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale", - ) - - def test_url_cache_thumbnail_legacy(self) -> None: - """Test old-style URL cache thumbnail paths""" - self.assertEqual( - self.filepaths.url_cache_thumbnail_rel( - "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale" - ), - "url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", - ) - self.assertEqual( - self.filepaths.url_cache_thumbnail( - "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale" - ), - "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", - ) - - def test_url_cache_thumbnail_directory(self) -> None: - """Test URL cache thumbnail directory paths""" - self.assertEqual( - self.filepaths.url_cache_thumbnail_directory_rel( - "2020-01-02_GerZNDnDZVjsOtar" - ), - "url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar", - ) - self.assertEqual( - self.filepaths.url_cache_thumbnail_directory("2020-01-02_GerZNDnDZVjsOtar"), - "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar", - ) - - def test_url_cache_thumbnail_directory_legacy(self) -> None: - """Test old-style URL cache thumbnail directory paths""" - self.assertEqual( - self.filepaths.url_cache_thumbnail_directory_rel( - "GerZNDnDZVjsOtardLuwfIBg" - ), - "url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", - ) - self.assertEqual( - self.filepaths.url_cache_thumbnail_directory("GerZNDnDZVjsOtardLuwfIBg"), - "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", - ) - - def test_url_cache_thumbnail_dirs_to_delete(self) -> None: - """Test URL cache thumbnail cleanup paths""" - self.assertEqual( - self.filepaths.url_cache_thumbnail_dirs_to_delete( - "2020-01-02_GerZNDnDZVjsOtar" - ), - [ - "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar", - "/media_store/url_cache_thumbnails/2020-01-02", - ], - ) - - def test_url_cache_thumbnail_dirs_to_delete_legacy(self) -> None: - """Test old-style URL cache thumbnail cleanup paths""" - self.assertEqual( - self.filepaths.url_cache_thumbnail_dirs_to_delete( - "GerZNDnDZVjsOtardLuwfIBg" - ), - [ - "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", - "/media_store/url_cache_thumbnails/Ge/rZ", - "/media_store/url_cache_thumbnails/Ge", - ], - ) - - def test_server_name_validation(self) -> None: - """Test validation of server names""" - self._test_path_validation( - [ - "remote_media_filepath_rel", - "remote_media_filepath", - "remote_media_thumbnail_rel", - "remote_media_thumbnail", - "remote_media_thumbnail_rel_legacy", - "remote_media_thumbnail_dir", - ], - parameter="server_name", - valid_values=[ - "matrix.org", - "matrix.org:8448", - "matrix-federation.matrix.org", - "matrix-federation.matrix.org:8448", - "10.1.12.123", - "10.1.12.123:8448", - "[fd00:abcd::ffff]", - "[fd00:abcd::ffff]:8448", - ], - invalid_values=[ - "/matrix.org", - "matrix.org/..", - "matrix.org\x00", - "", - ".", - "..", - "/", - ], - ) - - def test_file_id_validation(self) -> None: - """Test validation of local, remote and legacy URL cache file / media IDs""" - # File / media IDs get split into three parts to form paths, consisting of the - # first two characters, next two characters and rest of the ID. - valid_file_ids = [ - "GerZNDnDZVjsOtardLuwfIBg", - # Unexpected, but produces an acceptable path: - "GerZN", # "N" becomes the last directory - ] - invalid_file_ids = [ - "/erZNDnDZVjsOtardLuwfIBg", - "Ge/ZNDnDZVjsOtardLuwfIBg", - "GerZ/DnDZVjsOtardLuwfIBg", - "GerZ/..", - "G\x00rZNDnDZVjsOtardLuwfIBg", - "Ger\x00NDnDZVjsOtardLuwfIBg", - "GerZNDnDZVjsOtardLuwfIBg\x00", - "", - "Ge", - "GerZ", - "GerZ.", - "..rZNDnDZVjsOtardLuwfIBg", - "Ge..NDnDZVjsOtardLuwfIBg", - "GerZ..", - "GerZ/", - ] - - self._test_path_validation( - [ - "local_media_filepath_rel", - "local_media_filepath", - "local_media_thumbnail_rel", - "local_media_thumbnail", - "local_media_thumbnail_dir", - # Legacy URL cache media IDs - "url_cache_filepath_rel", - "url_cache_filepath", - # `url_cache_filepath_dirs_to_delete` is tested below. - "url_cache_thumbnail_rel", - "url_cache_thumbnail", - "url_cache_thumbnail_directory_rel", - "url_cache_thumbnail_directory", - "url_cache_thumbnail_dirs_to_delete", - ], - parameter="media_id", - valid_values=valid_file_ids, - invalid_values=invalid_file_ids, - ) - - # `url_cache_filepath_dirs_to_delete` ignores what would be the last path - # component, so only the first 4 characters matter. - self._test_path_validation( - [ - "url_cache_filepath_dirs_to_delete", - ], - parameter="media_id", - valid_values=valid_file_ids, - invalid_values=[ - "/erZNDnDZVjsOtardLuwfIBg", - "Ge/ZNDnDZVjsOtardLuwfIBg", - "G\x00rZNDnDZVjsOtardLuwfIBg", - "Ger\x00NDnDZVjsOtardLuwfIBg", - "", - "Ge", - "..rZNDnDZVjsOtardLuwfIBg", - "Ge..NDnDZVjsOtardLuwfIBg", - ], - ) - - self._test_path_validation( - [ - "remote_media_filepath_rel", - "remote_media_filepath", - "remote_media_thumbnail_rel", - "remote_media_thumbnail", - "remote_media_thumbnail_rel_legacy", - "remote_media_thumbnail_dir", - ], - parameter="file_id", - valid_values=valid_file_ids, - invalid_values=invalid_file_ids, - ) - - def test_url_cache_media_id_validation(self) -> None: - """Test validation of URL cache media IDs""" - self._test_path_validation( - [ - "url_cache_filepath_rel", - "url_cache_filepath", - # `url_cache_filepath_dirs_to_delete` only cares about the date prefix - "url_cache_thumbnail_rel", - "url_cache_thumbnail", - "url_cache_thumbnail_directory_rel", - "url_cache_thumbnail_directory", - "url_cache_thumbnail_dirs_to_delete", - ], - parameter="media_id", - valid_values=[ - "2020-01-02_GerZNDnDZVjsOtar", - "2020-01-02_G", # Unexpected, but produces an acceptable path - ], - invalid_values=[ - "2020-01-02", - "2020-01-02-", - "2020-01-02-.", - "2020-01-02-..", - "2020-01-02-/", - "2020-01-02-/GerZNDnDZVjsOtar", - "2020-01-02-GerZNDnDZVjsOtar/..", - "2020-01-02-GerZNDnDZVjsOtar\x00", - ], - ) - - def test_content_type_validation(self) -> None: - """Test validation of thumbnail content types""" - self._test_path_validation( - [ - "local_media_thumbnail_rel", - "local_media_thumbnail", - "remote_media_thumbnail_rel", - "remote_media_thumbnail", - "remote_media_thumbnail_rel_legacy", - "url_cache_thumbnail_rel", - "url_cache_thumbnail", - ], - parameter="content_type", - valid_values=[ - "image/jpeg", - ], - invalid_values=[ - "", # ValueError: not enough values to unpack - "image/jpeg/abc", # ValueError: too many values to unpack - "image/jpeg\x00", - ], - ) - - def test_thumbnail_method_validation(self) -> None: - """Test validation of thumbnail methods""" - self._test_path_validation( - [ - "local_media_thumbnail_rel", - "local_media_thumbnail", - "remote_media_thumbnail_rel", - "remote_media_thumbnail", - "url_cache_thumbnail_rel", - "url_cache_thumbnail", - ], - parameter="method", - valid_values=[ - "crop", - "scale", - ], - invalid_values=[ - "/scale", - "scale/..", - "scale\x00", - "/", - ], - ) - - def _test_path_validation( - self, - methods: Iterable[str], - parameter: str, - valid_values: Iterable[str], - invalid_values: Iterable[str], - ) -> None: - """Test that the specified methods validate the named parameter as expected - - Args: - methods: The names of `MediaFilePaths` methods to test - parameter: The name of the parameter to test - valid_values: A list of parameter values that are expected to be accepted - invalid_values: A list of parameter values that are expected to be rejected - - Raises: - AssertionError: If a value was accepted when it should have failed - validation. - ValueError: If a value failed validation when it should have been accepted. - """ - for method in methods: - get_path = getattr(self.filepaths, method) - - parameters = inspect.signature(get_path).parameters - kwargs = { - "server_name": "matrix.org", - "media_id": "GerZNDnDZVjsOtardLuwfIBg", - "file_id": "GerZNDnDZVjsOtardLuwfIBg", - "width": 800, - "height": 600, - "content_type": "image/jpeg", - "method": "scale", - } - - if get_path.__name__.startswith("url_"): - kwargs["media_id"] = "2020-01-02_GerZNDnDZVjsOtar" - - kwargs = {k: v for k, v in kwargs.items() if k in parameters} - kwargs.pop(parameter) - - for value in valid_values: - kwargs[parameter] = value - get_path(**kwargs) - # No exception should be raised - - for value in invalid_values: - with self.assertRaises(ValueError): - kwargs[parameter] = value - path_or_list = get_path(**kwargs) - self.fail( - f"{value!r} unexpectedly passed validation: " - f"{method} returned {path_or_list!r}" - ) - - -class MediaFilePathsJailTestCase(unittest.TestCase): - def _check_relative_path(self, filepaths: MediaFilePaths, path: str) -> None: - """Passes a relative path through the jail check. - - Args: - filepaths: The `MediaFilePaths` instance. - path: A path relative to the media store directory. - - Raises: - ValueError: If the jail check fails. - """ - - @_wrap_with_jail_check(relative=True) - def _make_relative_path(self: MediaFilePaths, path: str) -> str: - return path - - _make_relative_path(filepaths, path) - - def _check_absolute_path(self, filepaths: MediaFilePaths, path: str) -> None: - """Passes an absolute path through the jail check. - - Args: - filepaths: The `MediaFilePaths` instance. - path: A path relative to the media store directory. - - Raises: - ValueError: If the jail check fails. - """ - - @_wrap_with_jail_check(relative=False) - def _make_absolute_path(self: MediaFilePaths, path: str) -> str: - return os.path.join(self.base_path, path) - - _make_absolute_path(filepaths, path) - - def test_traversal_inside(self) -> None: - """Test the jail check for paths that stay within the media directory.""" - # Despite the `../`s, these paths still lie within the media directory and it's - # expected for the jail check to allow them through. - # These paths ought to trip the other checks in place and should never be - # returned. - filepaths = MediaFilePaths("/media_store") - path = "url_cache/2020-01-02/../../GerZNDnDZVjsOtar" - self._check_relative_path(filepaths, path) - self._check_absolute_path(filepaths, path) - - def test_traversal_outside(self) -> None: - """Test that the jail check fails for paths that escape the media directory.""" - filepaths = MediaFilePaths("/media_store") - path = "url_cache/2020-01-02/../../../GerZNDnDZVjsOtar" - with self.assertRaises(ValueError): - self._check_relative_path(filepaths, path) - with self.assertRaises(ValueError): - self._check_absolute_path(filepaths, path) - - def test_traversal_reentry(self) -> None: - """Test the jail check for paths that exit and re-enter the media directory.""" - # These paths lie outside the media directory if it is a symlink, and inside - # otherwise. Ideally the check should fail, but this proves difficult. - # This test documents the behaviour for this edge case. - # These paths ought to trip the other checks in place and should never be - # returned. - filepaths = MediaFilePaths("/media_store") - path = "url_cache/2020-01-02/../../../media_store/GerZNDnDZVjsOtar" - self._check_relative_path(filepaths, path) - self._check_absolute_path(filepaths, path) - - def test_symlink(self) -> None: - """Test that a symlink does not cause the jail check to fail.""" - media_store_path = self.mktemp() - - # symlink the media store directory - os.symlink("/mnt/synapse/media_store", media_store_path) - - # Test that relative and absolute paths don't trip the check - # NB: `media_store_path` is a relative path - filepaths = MediaFilePaths(media_store_path) - self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - - filepaths = MediaFilePaths(os.path.abspath(media_store_path)) - self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - - def test_symlink_subdirectory(self) -> None: - """Test that a symlinked subdirectory does not cause the jail check to fail.""" - media_store_path = self.mktemp() - os.mkdir(media_store_path) - - # symlink `url_cache/` - os.symlink( - "/mnt/synapse/media_store_url_cache", - os.path.join(media_store_path, "url_cache"), - ) - - # Test that relative and absolute paths don't trip the check - # NB: `media_store_path` is a relative path - filepaths = MediaFilePaths(media_store_path) - self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - - filepaths = MediaFilePaths(os.path.abspath(media_store_path)) - self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py deleted file mode 100644
index 1062081a06..0000000000 --- a/tests/rest/media/v1/test_html_preview.py +++ /dev/null
@@ -1,542 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.rest.media.v1.preview_html import ( - _get_html_media_encodings, - decode_body, - parse_html_to_open_graph, - summarize_paragraphs, -) - -from tests import unittest - -try: - import lxml -except ImportError: - lxml = None - - -class SummarizeTestCase(unittest.TestCase): - if not lxml: - skip = "url preview feature requires lxml" - - def test_long_summarize(self) -> None: - example_paras = [ - """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: - Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in - Troms county, Norway. The administrative centre of the municipality is - the city of Tromsø. Outside of Norway, Tromso and Tromsö are - alternative spellings of the city.Tromsø is considered the northernmost - city in the world with a population above 50,000. The most populous town - north of it is Alta, Norway, with a population of 14,272 (2013).""", - """Tromsø lies in Northern Norway. The municipality has a population of - (2015) 72,066, but with an annual influx of students it has over 75,000 - most of the year. It is the largest urban area in Northern Norway and the - third largest north of the Arctic Circle (following Murmansk and Norilsk). - Most of Tromsø, including the city centre, is located on the island of - Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012, - Tromsøya had a population of 36,088. Substantial parts of the urban area - are also situated on the mainland to the east, and on parts of Kvaløya—a - large island to the west. Tromsøya is connected to the mainland by the Tromsø - Bridge and the Tromsøysund Tunnel, and to the island of Kvaløya by the - Sandnessund Bridge. Tromsø Airport connects the city to many destinations - in Europe. The city is warmer than most other places located on the same - latitude, due to the warming effect of the Gulf Stream.""", - """The city centre of Tromsø contains the highest number of old wooden - houses in Northern Norway, the oldest house dating from 1789. The Arctic - Cathedral, a modern church from 1965, is probably the most famous landmark - in Tromsø. The city is a cultural centre for its region, with several - festivals taking place in the summer. Some of Norway's best-known - musicians, Torbjørn Brundtland and Svein Berge of the electronica duo - Röyksopp and Lene Marlin grew up and started their careers in Tromsø. - Noted electronic musician Geir Jenssen also hails from Tromsø.""", - ] - - desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - - self.assertEqual( - desc, - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway. The administrative centre of the municipality is" - " the city of Tromsø. Outside of Norway, Tromso and Tromsö are" - " alternative spellings of the city.Tromsø is considered the northernmost" - " city in the world with a population above 50,000. The most populous town" - " north of it is Alta, Norway, with a population of 14,272 (2013).", - ) - - desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500) - - self.assertEqual( - desc, - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year. It is the largest urban area in Northern Norway and the" - " third largest north of the Arctic Circle (following Murmansk and Norilsk)." - " Most of Tromsø, including the city centre, is located on the island of" - " Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012," - " Tromsøya had a population of 36,088. Substantial parts of the urban…", - ) - - def test_short_summarize(self) -> None: - example_paras = [ - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway.", - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year.", - "The city centre of Tromsø contains the highest number of old wooden" - " houses in Northern Norway, the oldest house dating from 1789. The Arctic" - " Cathedral, a modern church from 1965, is probably the most famous landmark" - " in Tromsø.", - ] - - desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - - self.assertEqual( - desc, - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway.\n" - "\n" - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year.", - ) - - def test_small_then_large_summarize(self) -> None: - example_paras = [ - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway.", - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year." - " The city centre of Tromsø contains the highest number of old wooden" - " houses in Northern Norway, the oldest house dating from 1789. The Arctic" - " Cathedral, a modern church from 1965, is probably the most famous landmark" - " in Tromsø.", - ] - - desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - self.assertEqual( - desc, - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway.\n" - "\n" - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year. The city centre of Tromsø contains the highest number" - " of old wooden houses in Northern Norway, the oldest house dating from" - " 1789. The Arctic Cathedral, a modern church from…", - ) - - -class OpenGraphFromHtmlTestCase(unittest.TestCase): - if not lxml: - skip = "url preview feature requires lxml" - - def test_simple(self) -> None: - html = b""" - <html> - <head><title>Foo</title></head> - <body> - Some text. - </body> - </html> - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree) - - self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - - def test_comment(self) -> None: - html = b""" - <html> - <head><title>Foo</title></head> - <body> - <!-- HTML comment --> - Some text. - </body> - </html> - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree) - - self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - - def test_comment2(self) -> None: - html = b""" - <html> - <head><title>Foo</title></head> - <body> - Some text. - <!-- HTML comment --> - Some more text. - <p>Text</p> - More text - </body> - </html> - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree) - - self.assertEqual( - og, - { - "og:title": "Foo", - "og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text", - }, - ) - - def test_script(self) -> None: - html = b""" - <html> - <head><title>Foo</title></head> - <body> - <script> (function() {})() </script> - Some text. - </body> - </html> - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree) - - self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - - def test_missing_title(self) -> None: - html = b""" - <html> - <body> - Some text. - </body> - </html> - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree) - - self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) - - # Another variant is a title with no content. - html = b""" - <html> - <head><title></title></head> - <body> - <h1>Title</h1> - </body> - </html> - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree) - - self.assertEqual(og, {"og:title": "Title", "og:description": "Title"}) - - def test_h1_as_title(self) -> None: - html = b""" - <html> - <meta property="og:description" content="Some text."/> - <body> - <h1>Title</h1> - </body> - </html> - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree) - - self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) - - def test_empty_description(self) -> None: - """Description tags with empty content should be ignored.""" - html = b""" - <html> - <meta property="og:description" content=""/> - <meta property="og:description"/> - <meta name="description" content=""/> - <meta name="description"/> - <meta name="description" content="Finally!"/> - <body> - <h1>Title</h1> - </body> - </html> - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree) - - self.assertEqual(og, {"og:title": "Title", "og:description": "Finally!"}) - - def test_missing_title_and_broken_h1(self) -> None: - html = b""" - <html> - <body> - <h1><a href="foo"/></h1> - Some text. - </body> - </html> - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree) - - self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) - - def test_empty(self) -> None: - """Test a body with no data in it.""" - html = b"" - tree = decode_body(html, "http://example.com/test.html") - self.assertIsNone(tree) - - def test_no_tree(self) -> None: - """A valid body with no tree in it.""" - html = b"\x00" - tree = decode_body(html, "http://example.com/test.html") - self.assertIsNone(tree) - - def test_xml(self) -> None: - """Test decoding XML and ensure it works properly.""" - # Note that the strip() call is important to ensure the xml tag starts - # at the initial byte. - html = b""" - <?xml version="1.0" encoding="UTF-8"?> - - <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"> - <html xmlns="http://www.w3.org/1999/xhtml" xml:lang="en" lang="en"> - <head><title>Foo</title></head><body>Some text.</body></html> - """.strip() - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree) - self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - - def test_invalid_encoding(self) -> None: - """An invalid character encoding should be ignored and treated as UTF-8, if possible.""" - html = b""" - <html> - <head><title>Foo</title></head> - <body> - Some text. - </body> - </html> - """ - tree = decode_body(html, "http://example.com/test.html", "invalid-encoding") - og = parse_html_to_open_graph(tree) - self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - - def test_invalid_encoding2(self) -> None: - """A body which doesn't match the sent character encoding.""" - # Note that this contains an invalid UTF-8 sequence in the title. - html = b""" - <html> - <head><title>\xff\xff Foo</title></head> - <body> - Some text. - </body> - </html> - """ - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree) - self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) - - def test_windows_1252(self) -> None: - """A body which uses cp1252, but doesn't declare that.""" - html = b""" - <html> - <head><title>\xf3</title></head> - <body> - Some text. - </body> - </html> - """ - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree) - self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."}) - - def test_twitter_tag(self) -> None: - """Twitter card tags should be used if nothing else is available.""" - html = b""" - <html> - <meta name="twitter:card" content="summary"> - <meta name="twitter:description" content="Description"> - <meta name="twitter:site" content="@matrixdotorg"> - </html> - """ - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree) - self.assertEqual( - og, - { - "og:title": None, - "og:description": "Description", - "og:site_name": "@matrixdotorg", - }, - ) - - # But they shouldn't override Open Graph values. - html = b""" - <html> - <meta name="twitter:card" content="summary"> - <meta name="twitter:description" content="Description"> - <meta property="og:description" content="Real Description"> - <meta name="twitter:site" content="@matrixdotorg"> - <meta property="og:site_name" content="matrix.org"> - </html> - """ - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree) - self.assertEqual( - og, - { - "og:title": None, - "og:description": "Real Description", - "og:site_name": "matrix.org", - }, - ) - - def test_nested_nodes(self) -> None: - """A body with some nested nodes. Tests that we iterate over children - in the right order (and don't reverse the order of the text).""" - html = b""" - <a href="somewhere">Welcome <b>the bold <u>and underlined text <svg> - with a cheeky SVG</svg></u> and <strong>some</strong> tail text</b></a> - """ - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree) - self.assertEqual( - og, - { - "og:title": None, - "og:description": "Welcome\n\nthe bold\n\nand underlined text\n\nand\n\nsome\n\ntail text", - }, - ) - - -class MediaEncodingTestCase(unittest.TestCase): - def test_meta_charset(self) -> None: - """A character encoding is found via the meta tag.""" - encodings = _get_html_media_encodings( - b""" - <html> - <head><meta charset="ascii"> - </head> - </html> - """, - "text/html", - ) - self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - - # A less well-formed version. - encodings = _get_html_media_encodings( - b""" - <html> - <head>< meta charset = ascii> - </head> - </html> - """, - "text/html", - ) - self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - - def test_meta_charset_underscores(self) -> None: - """A character encoding contains underscore.""" - encodings = _get_html_media_encodings( - b""" - <html> - <head><meta charset="Shift_JIS"> - </head> - </html> - """, - "text/html", - ) - self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"]) - - def test_xml_encoding(self) -> None: - """A character encoding is found via the meta tag.""" - encodings = _get_html_media_encodings( - b""" - <?xml version="1.0" encoding="ascii"?> - <html> - </html> - """, - "text/html", - ) - self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - - def test_meta_xml_encoding(self) -> None: - """Meta tags take precedence over XML encoding.""" - encodings = _get_html_media_encodings( - b""" - <?xml version="1.0" encoding="ascii"?> - <html> - <head><meta charset="UTF-16"> - </head> - </html> - """, - "text/html", - ) - self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"]) - - def test_content_type(self) -> None: - """A character encoding is found via the Content-Type header.""" - # Test a few variations of the header. - headers = ( - 'text/html; charset="ascii";', - "text/html;charset=ascii;", - 'text/html; charset="ascii"', - "text/html; charset=ascii", - 'text/html; charset="ascii;', - 'text/html; charset=ascii";', - ) - for header in headers: - encodings = _get_html_media_encodings(b"", header) - self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - - def test_fallback(self) -> None: - """A character encoding cannot be found in the body or header.""" - encodings = _get_html_media_encodings(b"", "text/html") - self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - - def test_duplicates(self) -> None: - """Ensure each encoding is only attempted once.""" - encodings = _get_html_media_encodings( - b""" - <?xml version="1.0" encoding="utf8"?> - <html> - <head><meta charset="UTF-8"> - </head> - </html> - """, - 'text/html; charset="UTF_8"', - ) - self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - - def test_unknown_invalid(self) -> None: - """A character encoding should be ignored if it is unknown or invalid.""" - encodings = _get_html_media_encodings( - b""" - <html> - <head><meta charset="invalid"> - </head> - </html> - """, - 'text/html; charset="invalid"', - ) - self.assertEqual(list(encodings), ["utf-8", "cp1252"]) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py deleted file mode 100644
index d18fc13c21..0000000000 --- a/tests/rest/media/v1/test_media_storage.py +++ /dev/null
@@ -1,782 +0,0 @@ -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import shutil -import tempfile -from binascii import unhexlify -from io import BytesIO -from typing import Any, BinaryIO, Dict, List, Optional, Union -from unittest.mock import Mock -from urllib import parse - -import attr -from parameterized import parameterized, parameterized_class -from PIL import Image as Image -from typing_extensions import Literal - -from twisted.internet import defer -from twisted.internet.defer import Deferred -from twisted.test.proto_helpers import MemoryReactor - -from synapse.api.errors import Codes -from synapse.events import EventBase -from synapse.events.spamcheck import load_legacy_spam_checkers -from synapse.logging.context import make_deferred_yieldable -from synapse.module_api import ModuleApi -from synapse.rest import admin -from synapse.rest.client import login -from synapse.rest.media.v1._base import FileInfo -from synapse.rest.media.v1.filepath import MediaFilePaths -from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper -from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend -from synapse.server import HomeServer -from synapse.types import RoomAlias -from synapse.util import Clock - -from tests import unittest -from tests.server import FakeChannel, FakeSite, make_request -from tests.test_utils import SMALL_PNG -from tests.utils import default_config - - -class MediaStorageTests(unittest.HomeserverTestCase): - - needs_threadpool = True - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") - self.addCleanup(shutil.rmtree, self.test_dir) - - self.primary_base_path = os.path.join(self.test_dir, "primary") - self.secondary_base_path = os.path.join(self.test_dir, "secondary") - - hs.config.media.media_store_path = self.primary_base_path - - storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)] - - self.filepaths = MediaFilePaths(self.primary_base_path) - self.media_storage = MediaStorage( - hs, self.primary_base_path, self.filepaths, storage_providers - ) - - def test_ensure_media_is_in_local_cache(self) -> None: - media_id = "some_media_id" - test_body = "Test\n" - - # First we create a file that is in a storage provider but not in the - # local primary media store - rel_path = self.filepaths.local_media_filepath_rel(media_id) - secondary_path = os.path.join(self.secondary_base_path, rel_path) - - os.makedirs(os.path.dirname(secondary_path)) - - with open(secondary_path, "w") as f: - f.write(test_body) - - # Now we run ensure_media_is_in_local_cache, which should copy the file - # to the local cache. - file_info = FileInfo(None, media_id) - - # This uses a real blocking threadpool so we have to wait for it to be - # actually done :/ - x = defer.ensureDeferred( - self.media_storage.ensure_media_is_in_local_cache(file_info) - ) - - # Hotloop until the threadpool does its job... - self.wait_on_thread(x) - - local_path = self.get_success(x) - - self.assertTrue(os.path.exists(local_path)) - - # Asserts the file is under the expected local cache directory - self.assertEqual( - os.path.commonprefix([self.primary_base_path, local_path]), - self.primary_base_path, - ) - - with open(local_path) as f: - body = f.read() - - self.assertEqual(test_body, body) - - -@attr.s(auto_attribs=True, slots=True, frozen=True) -class _TestImage: - """An image for testing thumbnailing with the expected results - - Attributes: - data: The raw image to thumbnail - content_type: The type of the image as a content type, e.g. "image/png" - extension: The extension associated with the format, e.g. ".png" - expected_cropped: The expected bytes from cropped thumbnailing, or None if - test should just check for success. - expected_scaled: The expected bytes from scaled thumbnailing, or None if - test should just check for a valid image returned. - expected_found: True if the file should exist on the server, or False if - a 404/400 is expected. - unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or - False if the thumbnailing should succeed or a normal 404 is expected. - """ - - data: bytes - content_type: bytes - extension: bytes - expected_cropped: Optional[bytes] = None - expected_scaled: Optional[bytes] = None - expected_found: bool = True - unable_to_thumbnail: bool = False - - -@parameterized_class( - ("test_image",), - [ - # small png - ( - _TestImage( - SMALL_PNG, - b"image/png", - b".png", - unhexlify( - b"89504e470d0a1a0a0000000d4948445200000020000000200806" - b"000000737a7af40000001a49444154789cedc101010000008220" - b"ffaf6e484001000000ef0610200001194334ee0000000049454e" - b"44ae426082" - ), - unhexlify( - b"89504e470d0a1a0a0000000d4948445200000001000000010806" - b"0000001f15c4890000000d49444154789c636060606000000005" - b"0001a5f645400000000049454e44ae426082" - ), - ), - ), - # small png with transparency. - ( - _TestImage( - unhexlify( - b"89504e470d0a1a0a0000000d49484452000000010000000101000" - b"00000376ef9240000000274524e5300010194fdae0000000a4944" - b"4154789c636800000082008177cd72b60000000049454e44ae426" - b"082" - ), - b"image/png", - b".png", - # Note that we don't check the output since it varies across - # different versions of Pillow. - ), - ), - # small lossless webp - ( - _TestImage( - unhexlify( - b"524946461a000000574542505650384c0d0000002f0000001007" - b"1011118888fe0700" - ), - b"image/webp", - b".webp", - ), - ), - # an empty file - ( - _TestImage( - b"", - b"image/gif", - b".gif", - expected_found=False, - unable_to_thumbnail=True, - ), - ), - ], -) -class MediaRepoTests(unittest.HomeserverTestCase): - - hijack_auth = True - user_id = "@test:user" - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - - self.fetches = [] - - def get_file( - destination: str, - path: str, - output_stream: BinaryIO, - args: Optional[Dict[str, Union[str, List[str]]]] = None, - max_size: Optional[int] = None, - ) -> Deferred: - """ - Returns tuple[int,dict,str,int] of file length, response headers, - absolute URI, and response code. - """ - - def write_to(r): - data, response = r - output_stream.write(data) - return response - - d = Deferred() - d.addCallback(write_to) - self.fetches.append((d, destination, path, args)) - return make_deferred_yieldable(d) - - client = Mock() - client.get_file = get_file - - self.storage_path = self.mktemp() - self.media_store_path = self.mktemp() - os.mkdir(self.storage_path) - os.mkdir(self.media_store_path) - - config = self.default_config() - config["media_store_path"] = self.media_store_path - config["max_image_pixels"] = 2000000 - - provider_config = { - "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend", - "store_local": True, - "store_synchronous": False, - "store_remote": True, - "config": {"directory": self.storage_path}, - } - config["media_storage_providers"] = [provider_config] - - hs = self.setup_test_homeserver(config=config, federation_http_client=client) - - return hs - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - - media_resource = hs.get_media_repository_resource() - self.download_resource = media_resource.children[b"download"] - self.thumbnail_resource = media_resource.children[b"thumbnail"] - self.store = hs.get_datastores().main - self.media_repo = hs.get_media_repository() - - self.media_id = "example.com/12345" - - def _req( - self, content_disposition: Optional[bytes], include_content_type: bool = True - ) -> FakeChannel: - channel = make_request( - self.reactor, - FakeSite(self.download_resource, self.reactor), - "GET", - self.media_id, - shorthand=False, - await_result=False, - ) - self.pump() - - # We've made one fetch, to example.com, using the media URL, and asking - # the other server not to do a remote fetch - self.assertEqual(len(self.fetches), 1) - self.assertEqual(self.fetches[0][1], "example.com") - self.assertEqual( - self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id - ) - self.assertEqual(self.fetches[0][3], {"allow_remote": "false"}) - - headers = { - b"Content-Length": [b"%d" % (len(self.test_image.data))], - } - - if include_content_type: - headers[b"Content-Type"] = [self.test_image.content_type] - - if content_disposition: - headers[b"Content-Disposition"] = [content_disposition] - - self.fetches[0][0].callback( - (self.test_image.data, (len(self.test_image.data), headers)) - ) - - self.pump() - self.assertEqual(channel.code, 200) - - return channel - - def test_handle_missing_content_type(self) -> None: - channel = self._req( - b"inline; filename=out" + self.test_image.extension, - include_content_type=False, - ) - headers = channel.headers - self.assertEqual(channel.code, 200) - self.assertEqual( - headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"] - ) - - def test_disposition_filename_ascii(self) -> None: - """ - If the filename is filename=<ascii> then Synapse will decode it as an - ASCII string, and use filename= in the response. - """ - channel = self._req(b"inline; filename=out" + self.test_image.extension) - - headers = channel.headers - self.assertEqual( - headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type] - ) - self.assertEqual( - headers.getRawHeaders(b"Content-Disposition"), - [b"inline; filename=out" + self.test_image.extension], - ) - - def test_disposition_filenamestar_utf8escaped(self) -> None: - """ - If the filename is filename=*utf8''<utf8 escaped> then Synapse will - correctly decode it as the UTF-8 string, and use filename* in the - response. - """ - filename = parse.quote("\u2603".encode()).encode("ascii") - channel = self._req( - b"inline; filename*=utf-8''" + filename + self.test_image.extension - ) - - headers = channel.headers - self.assertEqual( - headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type] - ) - self.assertEqual( - headers.getRawHeaders(b"Content-Disposition"), - [b"inline; filename*=utf-8''" + filename + self.test_image.extension], - ) - - def test_disposition_none(self) -> None: - """ - If there is no filename, one isn't passed on in the Content-Disposition - of the request. - """ - channel = self._req(None) - - headers = channel.headers - self.assertEqual( - headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type] - ) - self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) - - def test_thumbnail_crop(self) -> None: - """Test that a cropped remote thumbnail is available.""" - self._test_thumbnail( - "crop", - self.test_image.expected_cropped, - expected_found=self.test_image.expected_found, - unable_to_thumbnail=self.test_image.unable_to_thumbnail, - ) - - def test_thumbnail_scale(self) -> None: - """Test that a scaled remote thumbnail is available.""" - self._test_thumbnail( - "scale", - self.test_image.expected_scaled, - expected_found=self.test_image.expected_found, - unable_to_thumbnail=self.test_image.unable_to_thumbnail, - ) - - def test_invalid_type(self) -> None: - """An invalid thumbnail type is never available.""" - self._test_thumbnail( - "invalid", - None, - expected_found=False, - unable_to_thumbnail=self.test_image.unable_to_thumbnail, - ) - - @unittest.override_config( - {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]} - ) - def test_no_thumbnail_crop(self) -> None: - """ - Override the config to generate only scaled thumbnails, but request a cropped one. - """ - self._test_thumbnail( - "crop", - None, - expected_found=False, - unable_to_thumbnail=self.test_image.unable_to_thumbnail, - ) - - @unittest.override_config( - {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]} - ) - def test_no_thumbnail_scale(self) -> None: - """ - Override the config to generate only cropped thumbnails, but request a scaled one. - """ - self._test_thumbnail( - "scale", - None, - expected_found=False, - unable_to_thumbnail=self.test_image.unable_to_thumbnail, - ) - - def test_thumbnail_repeated_thumbnail(self) -> None: - """Test that fetching the same thumbnail works, and deleting the on disk - thumbnail regenerates it. - """ - self._test_thumbnail( - "scale", - self.test_image.expected_scaled, - expected_found=self.test_image.expected_found, - unable_to_thumbnail=self.test_image.unable_to_thumbnail, - ) - - if not self.test_image.expected_found: - return - - # Fetching again should work, without re-requesting the image from the - # remote. - params = "?width=32&height=32&method=scale" - channel = make_request( - self.reactor, - FakeSite(self.thumbnail_resource, self.reactor), - "GET", - self.media_id + params, - shorthand=False, - await_result=False, - ) - self.pump() - - self.assertEqual(channel.code, 200) - if self.test_image.expected_scaled: - self.assertEqual( - channel.result["body"], - self.test_image.expected_scaled, - channel.result["body"], - ) - - # Deleting the thumbnail on disk then re-requesting it should work as - # Synapse should regenerate missing thumbnails. - origin, media_id = self.media_id.split("/") - info = self.get_success(self.store.get_cached_remote_media(origin, media_id)) - file_id = info["filesystem_id"] - - thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir( - origin, file_id - ) - shutil.rmtree(thumbnail_dir, ignore_errors=True) - - channel = make_request( - self.reactor, - FakeSite(self.thumbnail_resource, self.reactor), - "GET", - self.media_id + params, - shorthand=False, - await_result=False, - ) - self.pump() - - self.assertEqual(channel.code, 200) - if self.test_image.expected_scaled: - self.assertEqual( - channel.result["body"], - self.test_image.expected_scaled, - channel.result["body"], - ) - - def _test_thumbnail( - self, - method: str, - expected_body: Optional[bytes], - expected_found: bool, - unable_to_thumbnail: bool = False, - ) -> None: - """Test the given thumbnailing method works as expected. - - Args: - method: The thumbnailing method to use (crop, scale). - expected_body: The expected bytes from thumbnailing, or None if - test should just check for a valid image. - expected_found: True if the file should exist on the server, or False if - a 404/400 is expected. - unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or - False if the thumbnailing should succeed or a normal 404 is expected. - """ - - params = "?width=32&height=32&method=" + method - channel = make_request( - self.reactor, - FakeSite(self.thumbnail_resource, self.reactor), - "GET", - self.media_id + params, - shorthand=False, - await_result=False, - ) - self.pump() - - headers = { - b"Content-Length": [b"%d" % (len(self.test_image.data))], - b"Content-Type": [self.test_image.content_type], - } - self.fetches[0][0].callback( - (self.test_image.data, (len(self.test_image.data), headers)) - ) - self.pump() - - if expected_found: - self.assertEqual(channel.code, 200) - - self.assertEqual( - channel.headers.getRawHeaders(b"Cross-Origin-Resource-Policy"), - [b"cross-origin"], - ) - - if expected_body is not None: - self.assertEqual( - channel.result["body"], expected_body, channel.result["body"] - ) - else: - # ensure that the result is at least some valid image - Image.open(BytesIO(channel.result["body"])) - elif unable_to_thumbnail: - # A 400 with a JSON body. - self.assertEqual(channel.code, 400) - self.assertEqual( - channel.json_body, - { - "errcode": "M_UNKNOWN", - "error": "Cannot find any thumbnails for the requested media ([b'example.com', b'12345']). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)", - }, - ) - else: - # A 404 with a JSON body. - self.assertEqual(channel.code, 404) - self.assertEqual( - channel.json_body, - { - "errcode": "M_NOT_FOUND", - "error": "Not found [b'example.com', b'12345']", - }, - ) - - @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)]) - def test_same_quality(self, method: str, desired_size: int) -> None: - """Test that choosing between thumbnails with the same quality rating succeeds. - - We are not particular about which thumbnail is chosen.""" - self.assertIsNotNone( - self.thumbnail_resource._select_thumbnail( - desired_width=desired_size, - desired_height=desired_size, - desired_method=method, - desired_type=self.test_image.content_type, - # Provide two identical thumbnails which are guaranteed to have the same - # quality rating. - thumbnail_infos=[ - { - "thumbnail_width": 32, - "thumbnail_height": 32, - "thumbnail_method": method, - "thumbnail_type": self.test_image.content_type, - "thumbnail_length": 256, - "filesystem_id": f"thumbnail1{self.test_image.extension}", - }, - { - "thumbnail_width": 32, - "thumbnail_height": 32, - "thumbnail_method": method, - "thumbnail_type": self.test_image.content_type, - "thumbnail_length": 256, - "filesystem_id": f"thumbnail2{self.test_image.extension}", - }, - ], - file_id=f"image{self.test_image.extension}", - url_cache=None, - server_name=None, - ) - ) - - def test_x_robots_tag_header(self) -> None: - """ - Tests that the `X-Robots-Tag` header is present, which informs web crawlers - to not index, archive, or follow links in media. - """ - channel = self._req(b"inline; filename=out" + self.test_image.extension) - - headers = channel.headers - self.assertEqual( - headers.getRawHeaders(b"X-Robots-Tag"), - [b"noindex, nofollow, noarchive, noimageindex"], - ) - - def test_cross_origin_resource_policy_header(self) -> None: - """ - Test that the Cross-Origin-Resource-Policy header is set to "cross-origin" - allowing web clients to embed media from the downloads API. - """ - channel = self._req(b"inline; filename=out" + self.test_image.extension) - - headers = channel.headers - - self.assertEqual( - headers.getRawHeaders(b"Cross-Origin-Resource-Policy"), - [b"cross-origin"], - ) - - -class TestSpamCheckerLegacy: - """A spam checker module that rejects all media that includes the bytes - `evil`. - - Uses the legacy Spam-Checker API. - """ - - def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None: - self.config = config - self.api = api - - def parse_config(config: Dict[str, Any]) -> Dict[str, Any]: - return config - - async def check_event_for_spam(self, event: EventBase) -> Union[bool, str]: - return False # allow all events - - async def user_may_invite( - self, - inviter_userid: str, - invitee_userid: str, - room_id: str, - ) -> bool: - return True # allow all invites - - async def user_may_create_room(self, userid: str) -> bool: - return True # allow all room creations - - async def user_may_create_room_alias( - self, userid: str, room_alias: RoomAlias - ) -> bool: - return True # allow all room aliases - - async def user_may_publish_room(self, userid: str, room_id: str) -> bool: - return True # allow publishing of all rooms - - async def check_media_file_for_spam( - self, file_wrapper: ReadableFileWrapper, file_info: FileInfo - ) -> bool: - buf = BytesIO() - await file_wrapper.write_chunks_to(buf.write) - - return b"evil" in buf.getvalue() - - -class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase): - servlets = [ - login.register_servlets, - admin.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.user = self.register_user("user", "pass") - self.tok = self.login("user", "pass") - - # Allow for uploading and downloading to/from the media repo - self.media_repo = hs.get_media_repository_resource() - self.download_resource = self.media_repo.children[b"download"] - self.upload_resource = self.media_repo.children[b"upload"] - - load_legacy_spam_checkers(hs) - - def default_config(self) -> Dict[str, Any]: - config = default_config("test") - - config.update( - { - "spam_checker": [ - { - "module": TestSpamCheckerLegacy.__module__ - + ".TestSpamCheckerLegacy", - "config": {}, - } - ] - } - ) - - return config - - def test_upload_innocent(self) -> None: - """Attempt to upload some innocent data that should be allowed.""" - self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200 - ) - - def test_upload_ban(self) -> None: - """Attempt to upload some data that includes bytes "evil", which should - get rejected by the spam checker. - """ - - data = b"Some evil data" - - self.helper.upload_media( - self.upload_resource, data, tok=self.tok, expect_code=400 - ) - - -EVIL_DATA = b"Some evil data" -EVIL_DATA_EXPERIMENT = b"Some evil data to trigger the experimental tuple API" - - -class SpamCheckerTestCase(unittest.HomeserverTestCase): - servlets = [ - login.register_servlets, - admin.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.user = self.register_user("user", "pass") - self.tok = self.login("user", "pass") - - # Allow for uploading and downloading to/from the media repo - self.media_repo = hs.get_media_repository_resource() - self.download_resource = self.media_repo.children[b"download"] - self.upload_resource = self.media_repo.children[b"upload"] - - hs.get_module_api().register_spam_checker_callbacks( - check_media_file_for_spam=self.check_media_file_for_spam - ) - - async def check_media_file_for_spam( - self, file_wrapper: ReadableFileWrapper, file_info: FileInfo - ) -> Union[Codes, Literal["NOT_SPAM"]]: - buf = BytesIO() - await file_wrapper.write_chunks_to(buf.write) - - if buf.getvalue() == EVIL_DATA: - return Codes.FORBIDDEN - elif buf.getvalue() == EVIL_DATA_EXPERIMENT: - return (Codes.FORBIDDEN, {}) - else: - return "NOT_SPAM" - - def test_upload_innocent(self) -> None: - """Attempt to upload some innocent data that should be allowed.""" - self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200 - ) - - def test_upload_ban(self) -> None: - """Attempt to upload some data that includes bytes "evil", which should - get rejected by the spam checker. - """ - - self.helper.upload_media( - self.upload_resource, EVIL_DATA, tok=self.tok, expect_code=400 - ) - - self.helper.upload_media( - self.upload_resource, - EVIL_DATA_EXPERIMENT, - tok=self.tok, - expect_code=400, - ) diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py deleted file mode 100644
index 3f7f1dbab9..0000000000 --- a/tests/rest/media/v1/test_oembed.py +++ /dev/null
@@ -1,162 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -from parameterized import parameterized - -from twisted.test.proto_helpers import MemoryReactor - -from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult -from synapse.server import HomeServer -from synapse.types import JsonDict -from synapse.util import Clock - -from tests.unittest import HomeserverTestCase - -try: - import lxml -except ImportError: - lxml = None - - -class OEmbedTests(HomeserverTestCase): - if not lxml: - skip = "url preview feature requires lxml" - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.oembed = OEmbedProvider(hs) - - def parse_response(self, response: JsonDict) -> OEmbedResult: - return self.oembed.parse_oembed_response( - "https://test", json.dumps(response).encode("utf-8") - ) - - def test_version(self) -> None: - """Accept versions that are similar to 1.0 as a string or int (or missing).""" - for version in ("1.0", 1.0, 1): - result = self.parse_response({"version": version}) - # An empty Open Graph response is an error, ensure the URL is included. - self.assertIn("og:url", result.open_graph_result) - - # A missing version should be treated as 1.0. - result = self.parse_response({"type": "link"}) - self.assertIn("og:url", result.open_graph_result) - - # Invalid versions should be rejected. - for version in ("2.0", "1", 1.1, 0, None, {}, []): - result = self.parse_response({"version": version, "type": "link"}) - # An empty Open Graph response is an error, ensure the URL is included. - self.assertEqual({}, result.open_graph_result) - - def test_cache_age(self) -> None: - """Ensure a cache-age is parsed properly.""" - # Correct-ish cache ages are allowed. - for cache_age in ("1", 1.0, 1): - result = self.parse_response({"cache_age": cache_age}) - self.assertEqual(result.cache_age, 1000) - - # Invalid cache ages are ignored. - for cache_age in ("invalid", {}): - result = self.parse_response({"cache_age": cache_age}) - self.assertIsNone(result.cache_age) - - # Cache age is optional. - result = self.parse_response({}) - self.assertIsNone(result.cache_age) - - @parameterized.expand( - [ - ("title", "title"), - ("provider_name", "site_name"), - ("thumbnail_url", "image"), - ], - name_func=lambda func, num, p: f"{func.__name__}_{p.args[0]}", - ) - def test_property(self, oembed_property: str, open_graph_property: str) -> None: - """Test properties which must be strings.""" - result = self.parse_response({oembed_property: "test"}) - self.assertIn(f"og:{open_graph_property}", result.open_graph_result) - self.assertEqual(result.open_graph_result[f"og:{open_graph_property}"], "test") - - result = self.parse_response({oembed_property: 1}) - self.assertNotIn(f"og:{open_graph_property}", result.open_graph_result) - - def test_author_name(self) -> None: - """Test the author_name property.""" - result = self.parse_response({"author_name": "test"}) - self.assertEqual(result.author_name, "test") - - result = self.parse_response({"author_name": 1}) - self.assertIsNone(result.author_name) - - def test_rich(self) -> None: - """Test a type of rich.""" - result = self.parse_response({"html": "test<img src='foo'>", "type": "rich"}) - self.assertIn("og:description", result.open_graph_result) - self.assertIn("og:image", result.open_graph_result) - self.assertEqual(result.open_graph_result["og:description"], "test") - self.assertEqual(result.open_graph_result["og:image"], "foo") - - result = self.parse_response({"type": "rich"}) - self.assertNotIn("og:description", result.open_graph_result) - - result = self.parse_response({"html": 1, "type": "rich"}) - self.assertNotIn("og:description", result.open_graph_result) - - def test_photo(self) -> None: - """Test a type of photo.""" - result = self.parse_response({"url": "test", "type": "photo"}) - self.assertIn("og:image", result.open_graph_result) - self.assertEqual(result.open_graph_result["og:image"], "test") - - result = self.parse_response({"type": "photo"}) - self.assertNotIn("og:image", result.open_graph_result) - - result = self.parse_response({"url": 1, "type": "photo"}) - self.assertNotIn("og:image", result.open_graph_result) - - def test_video(self) -> None: - """Test a type of video.""" - result = self.parse_response({"html": "test", "type": "video"}) - self.assertIn("og:type", result.open_graph_result) - self.assertEqual(result.open_graph_result["og:type"], "video.other") - self.assertIn("og:description", result.open_graph_result) - self.assertEqual(result.open_graph_result["og:description"], "test") - - result = self.parse_response({"type": "video"}) - self.assertIn("og:type", result.open_graph_result) - self.assertEqual(result.open_graph_result["og:type"], "video.other") - self.assertNotIn("og:description", result.open_graph_result) - - result = self.parse_response({"url": 1, "type": "video"}) - self.assertIn("og:type", result.open_graph_result) - self.assertEqual(result.open_graph_result["og:type"], "video.other") - self.assertNotIn("og:description", result.open_graph_result) - - def test_link(self) -> None: - """Test type of link.""" - result = self.parse_response({"type": "link"}) - self.assertIn("og:type", result.open_graph_result) - self.assertEqual(result.open_graph_result["og:type"], "website") - - def test_title_html_entities(self) -> None: - """Test HTML entities in title""" - result = self.parse_response( - {"title": "Why JSON isn&#8217;t a Good Configuration Language"} - ) - self.assertEqual( - result.open_graph_result["og:title"], - "Why JSON isn’t a Good Configuration Language", - )