summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorPatrick Cloke <patrickc@matrix.org>2023-10-17 11:56:14 -0400
committerPatrick Cloke <patrickc@matrix.org>2023-10-17 11:56:14 -0400
commit07b3b9a95e979c0125d069b8841fd7ce4917e559 (patch)
treef4c9e521e2cd0ab54b03282032a44e6d007b1cb0 /tests
parentRevert "TEMPORARY Measure and log test cases" (diff)
parent1.95.0rc1 (diff)
downloadsynapse-07b3b9a95e979c0125d069b8841fd7ce4917e559.tar.xz
Merge branch 'release-v1.95' into matrix-org-hotfixes
Diffstat (limited to 'tests')
-rw-r--r--tests/federation/transport/test_knocking.py2
-rw-r--r--tests/handlers/test_appservice.py16
-rw-r--r--tests/handlers/test_typing.py28
-rw-r--r--tests/media/test_base.py19
-rw-r--r--tests/media/test_media_storage.py122
-rw-r--r--tests/media/test_url_previewer.py6
-rw-r--r--tests/module_api/test_api.py8
-rw-r--r--tests/replication/test_multi_media_repo.py19
-rw-r--r--tests/rest/admin/test_admin.py58
-rw-r--r--tests/rest/admin/test_media.py71
-rw-r--r--tests/rest/admin/test_statistics.py15
-rw-r--r--tests/rest/admin/test_user.py21
-rw-r--r--tests/rest/client/test_rooms.py24
-rw-r--r--tests/rest/client/utils.py6
-rw-r--r--tests/rest/media/test_url_preview.py124
-rw-r--r--tests/storage/test_client_ips.py137
-rw-r--r--tests/storage/test_event_chain.py64
-rw-r--r--tests/unittest.py4
-rw-r--r--tests/util/test_retryutils.py75
19 files changed, 468 insertions, 351 deletions
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py

index 3f42f79f26..b63ef3d4ed 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py
@@ -308,7 +308,7 @@ class FederationKnockingTestCase( self.assertEqual(200, channel.code, channel.result) # Check that we got the stripped room state in return - room_state_events = channel.json_body["knock_state_events"] + room_state_events = channel.json_body["knock_room_state"] # Validate the stripped room state events self.check_knock_room_state_against_room_state( diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index a7e6cdd66a..867dbd6001 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py
@@ -31,7 +31,7 @@ from synapse.appservice import ( from synapse.handlers.appservice import ApplicationServicesHandler from synapse.rest.client import login, receipts, register, room, sendtodevice from synapse.server import HomeServer -from synapse.types import JsonDict, RoomStreamToken +from synapse.types import JsonDict, RoomStreamToken, StreamKeyType from synapse.util import Clock from synapse.util.stringutils import random_string @@ -86,7 +86,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): [event], ] ) - self.handler.notify_interested_services(RoomStreamToken(None, 1)) + self.handler.notify_interested_services(RoomStreamToken(stream=1)) self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( interested_service, events=[event] @@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): ] ) self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]]) - self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.handler.notify_interested_services(RoomStreamToken(stream=0)) self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) @@ -126,7 +126,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): ] ) - self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.handler.notify_interested_services(RoomStreamToken(stream=0)) self.assertFalse( self.mock_as_api.query_user.called, @@ -304,7 +304,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): ) self.handler.notify_interested_services_ephemeral( - "receipt_key", 580, ["@fakerecipient:example.com"] + StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"] ) self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( interested_service, ephemeral=[event] @@ -332,7 +332,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): ) self.handler.notify_interested_services_ephemeral( - "receipt_key", 580, ["@fakerecipient:example.com"] + StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"] ) # This method will be called, but with an empty list of events self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( @@ -441,7 +441,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): self.get_success( self.hs.get_application_service_handler()._notify_interested_services( RoomStreamToken( - None, self.hs.get_application_service_handler().current_max + stream=self.hs.get_application_service_handler().current_max ) ) ) @@ -634,7 +634,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): self.get_success( self.hs.get_application_service_handler()._notify_interested_services_ephemeral( services=[interested_appservice], - stream_key="receipt_key", + stream_key=StreamKeyType.RECEIPT, new_token=stream_token, users=[self.exclusive_as_user], ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 95106ec8f3..d7025c6f2c 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py
@@ -28,7 +28,7 @@ from synapse.federation.transport.server import TransportLayerServer from synapse.handlers.typing import TypingWriterHandler from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.server import HomeServer -from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.types import JsonDict, Requester, StreamKeyType, UserID, create_requester from synapse.util import Clock from tests import unittest @@ -174,7 +174,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): return_value=1 ) - self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[method-assign] + self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, [])) # type: ignore[method-assign] self.datastore.get_to_device_stream_token = Mock( # type: ignore[method-assign] return_value=0 @@ -203,7 +203,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls( + [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])] + ) self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( @@ -273,7 +275,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 200) - self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls( + [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])] + ) self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( @@ -349,7 +353,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls( + [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])] + ) self.mock_federation_client.put_json.assert_called_once_with( "farm", @@ -399,7 +405,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls( + [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])] + ) self.on_new_event.reset_mock() self.assertEqual(self.event_source.get_current_key(), 1) @@ -425,7 +433,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.reactor.pump([16]) - self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls( + [call(StreamKeyType.TYPING, 2, rooms=[ROOM_ID])] + ) self.assertEqual(self.event_source.get_current_key(), 2) events = self.get_success( @@ -459,7 +469,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls( + [call(StreamKeyType.TYPING, 3, rooms=[ROOM_ID])] + ) self.on_new_event.reset_mock() self.assertEqual(self.event_source.get_current_key(), 3) diff --git a/tests/media/test_base.py b/tests/media/test_base.py
index 119d7ba66f..144948f23c 100644 --- a/tests/media/test_base.py +++ b/tests/media/test_base.py
@@ -42,18 +42,35 @@ class GetFileNameFromHeadersTests(unittest.TestCase): class AddFileHeadersTests(unittest.TestCase): TEST_CASES = { + # Safe values use inline. "text/plain": b"inline; filename=file.name", "text/csv": b"inline; filename=file.name", "image/png": b"inline; filename=file.name", + # Unlisted values are set to attachment. "text/html": b"attachment; filename=file.name", "any/thing": b"attachment; filename=file.name", + # Parameters get ignored. + "text/plain; charset=utf-8": b"inline; filename=file.name", + "text/markdown; charset=utf-8; variant=CommonMark": b"attachment; filename=file.name", + # Parsed as lowercase. + "Text/Plain": b"inline; filename=file.name", + # Bad values don't choke. + "": b"attachment; filename=file.name", + ";": b"attachment; filename=file.name", } def test_content_disposition(self) -> None: for media_type, expected in self.TEST_CASES.items(): request = Mock() add_file_headers(request, media_type, 0, "file.name") - request.setHeader.assert_any_call(b"Content-Disposition", expected) + # There should be a single call to set Content-Disposition. + for call in request.setHeader.call_args_list: + args, _ = call + if args[0] == b"Content-Disposition": + break + else: + self.fail(f"No Content-Disposition header found for {media_type}") + self.assertEqual(args[1], expected, media_type) def test_no_filename(self) -> None: request = Mock() diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index 04fc7bdcef..15f5d644e4 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py
@@ -28,12 +28,13 @@ from typing_extensions import Literal from twisted.internet import defer from twisted.internet.defer import Deferred from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource from synapse.api.errors import Codes from synapse.events import EventBase from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable -from synapse.media._base import FileInfo +from synapse.media._base import FileInfo, ThumbnailInfo from synapse.media.filepath import MediaFilePaths from synapse.media.media_storage import MediaStorage, ReadableFileWrapper from synapse.media.storage_provider import FileStorageProviderBackend @@ -41,12 +42,13 @@ from synapse.module_api import ModuleApi from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers from synapse.rest import admin from synapse.rest.client import login +from synapse.rest.media.thumbnail_resource import ThumbnailResource from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias from synapse.util import Clock from tests import unittest -from tests.server import FakeChannel, FakeSite, make_request +from tests.server import FakeChannel from tests.test_utils import SMALL_PNG from tests.utils import default_config @@ -288,22 +290,22 @@ class MediaRepoTests(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - media_resource = hs.get_media_repository_resource() - self.download_resource = media_resource.children[b"download"] - self.thumbnail_resource = media_resource.children[b"thumbnail"] self.store = hs.get_datastores().main self.media_repo = hs.get_media_repository() self.media_id = "example.com/12345" + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + def _req( self, content_disposition: Optional[bytes], include_content_type: bool = True ) -> FakeChannel: - channel = make_request( - self.reactor, - FakeSite(self.download_resource, self.reactor), + channel = self.make_request( "GET", - self.media_id, + f"/_matrix/media/v3/download/{self.media_id}", shorthand=False, await_result=False, ) @@ -481,11 +483,9 @@ class MediaRepoTests(unittest.HomeserverTestCase): # Fetching again should work, without re-requesting the image from the # remote. params = "?width=32&height=32&method=scale" - channel = make_request( - self.reactor, - FakeSite(self.thumbnail_resource, self.reactor), + channel = self.make_request( "GET", - self.media_id + params, + f"/_matrix/media/v3/thumbnail/{self.media_id}{params}", shorthand=False, await_result=False, ) @@ -511,11 +511,9 @@ class MediaRepoTests(unittest.HomeserverTestCase): ) shutil.rmtree(thumbnail_dir, ignore_errors=True) - channel = make_request( - self.reactor, - FakeSite(self.thumbnail_resource, self.reactor), + channel = self.make_request( "GET", - self.media_id + params, + f"/_matrix/media/v3/thumbnail/{self.media_id}{params}", shorthand=False, await_result=False, ) @@ -549,11 +547,9 @@ class MediaRepoTests(unittest.HomeserverTestCase): """ params = "?width=32&height=32&method=" + method - channel = make_request( - self.reactor, - FakeSite(self.thumbnail_resource, self.reactor), + channel = self.make_request( "GET", - self.media_id + params, + f"/_matrix/media/r0/thumbnail/{self.media_id}{params}", shorthand=False, await_result=False, ) @@ -590,7 +586,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): channel.json_body, { "errcode": "M_UNKNOWN", - "error": "Cannot find any thumbnails for the requested media ([b'example.com', b'12345']). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)", + "error": "Cannot find any thumbnails for the requested media ('/_matrix/media/r0/thumbnail/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)", }, ) else: @@ -600,7 +596,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): channel.json_body, { "errcode": "M_NOT_FOUND", - "error": "Not found [b'example.com', b'12345']", + "error": "Not found '/_matrix/media/r0/thumbnail/example.com/12345'", }, ) @@ -609,34 +605,39 @@ class MediaRepoTests(unittest.HomeserverTestCase): """Test that choosing between thumbnails with the same quality rating succeeds. We are not particular about which thumbnail is chosen.""" + + content_type = self.test_image.content_type.decode() + media_repo = self.hs.get_media_repository() + thumbnail_resouce = ThumbnailResource( + self.hs, media_repo, media_repo.media_storage + ) + self.assertIsNotNone( - self.thumbnail_resource._select_thumbnail( + thumbnail_resouce._select_thumbnail( desired_width=desired_size, desired_height=desired_size, desired_method=method, - desired_type=self.test_image.content_type, + desired_type=content_type, # Provide two identical thumbnails which are guaranteed to have the same # quality rating. thumbnail_infos=[ - { - "thumbnail_width": 32, - "thumbnail_height": 32, - "thumbnail_method": method, - "thumbnail_type": self.test_image.content_type, - "thumbnail_length": 256, - "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}", - }, - { - "thumbnail_width": 32, - "thumbnail_height": 32, - "thumbnail_method": method, - "thumbnail_type": self.test_image.content_type, - "thumbnail_length": 256, - "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}", - }, + ThumbnailInfo( + width=32, + height=32, + method=method, + type=content_type, + length=256, + ), + ThumbnailInfo( + width=32, + height=32, + method=method, + type=content_type, + length=256, + ), ], file_id=f"image{self.test_image.extension.decode()}", - url_cache=None, + url_cache=False, server_name=None, ) ) @@ -725,13 +726,13 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase): self.user = self.register_user("user", "pass") self.tok = self.login("user", "pass") - # Allow for uploading and downloading to/from the media repo - self.media_repo = hs.get_media_repository_resource() - self.download_resource = self.media_repo.children[b"download"] - self.upload_resource = self.media_repo.children[b"upload"] - load_legacy_spam_checkers(hs) + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + def default_config(self) -> Dict[str, Any]: config = default_config("test") @@ -751,9 +752,7 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase): def test_upload_innocent(self) -> None: """Attempt to upload some innocent data that should be allowed.""" - self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200 - ) + self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200) def test_upload_ban(self) -> None: """Attempt to upload some data that includes bytes "evil", which should @@ -762,9 +761,7 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase): data = b"Some evil data" - self.helper.upload_media( - self.upload_resource, data, tok=self.tok, expect_code=400 - ) + self.helper.upload_media(data, tok=self.tok, expect_code=400) EVIL_DATA = b"Some evil data" @@ -781,15 +778,15 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): self.user = self.register_user("user", "pass") self.tok = self.login("user", "pass") - # Allow for uploading and downloading to/from the media repo - self.media_repo = hs.get_media_repository_resource() - self.download_resource = self.media_repo.children[b"download"] - self.upload_resource = self.media_repo.children[b"upload"] - hs.get_module_api().register_spam_checker_callbacks( check_media_file_for_spam=self.check_media_file_for_spam ) + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + async def check_media_file_for_spam( self, file_wrapper: ReadableFileWrapper, file_info: FileInfo ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]: @@ -805,21 +802,16 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): def test_upload_innocent(self) -> None: """Attempt to upload some innocent data that should be allowed.""" - self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200 - ) + self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200) def test_upload_ban(self) -> None: """Attempt to upload some data that includes bytes "evil", which should get rejected by the spam checker. """ - self.helper.upload_media( - self.upload_resource, EVIL_DATA, tok=self.tok, expect_code=400 - ) + self.helper.upload_media(EVIL_DATA, tok=self.tok, expect_code=400) self.helper.upload_media( - self.upload_resource, EVIL_DATA_EXPERIMENT, tok=self.tok, expect_code=400, diff --git a/tests/media/test_url_previewer.py b/tests/media/test_url_previewer.py
index 46ecde5344..04b69f378a 100644 --- a/tests/media/test_url_previewer.py +++ b/tests/media/test_url_previewer.py
@@ -61,9 +61,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): return self.setup_test_homeserver(config=config) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - media_repo_resource = hs.get_media_repository_resource() - preview_url = media_repo_resource.children[b"preview_url"] - self.url_previewer = preview_url._url_previewer + media_repo = hs.get_media_repository() + assert media_repo.url_previewer is not None + self.url_previewer = media_repo.url_previewer def test_all_urls_allowed(self) -> None: self.assertFalse(self.url_previewer._is_url_blocked("http://matrix.org")) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 172fc3a736..1dabf52156 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py
@@ -94,12 +94,12 @@ class ModuleApiTestCase(BaseModuleApiTestCase): self.assertEqual(len(emails), 1) email = emails[0] - self.assertEqual(email["medium"], "email") - self.assertEqual(email["address"], "bob@bobinator.bob") + self.assertEqual(email.medium, "email") + self.assertEqual(email.address, "bob@bobinator.bob") # Should these be 0? - self.assertEqual(email["validated_at"], 0) - self.assertEqual(email["added_at"], 0) + self.assertEqual(email.validated_at, 0) + self.assertEqual(email.added_at, 0) # Check that the displayname was assigned displayname = self.get_success( diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 6e78daa830..b230a6c361 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py
@@ -13,7 +13,7 @@ # limitations under the License. import logging import os -from typing import Optional, Tuple +from typing import Any, Optional, Tuple from twisted.internet.interfaces import IOpenSSLServerConnectionCreator from twisted.internet.protocol import Factory @@ -29,7 +29,7 @@ from synapse.util import Clock from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.replication._base import BaseMultiWorkerStreamTestCase -from tests.server import FakeChannel, FakeSite, FakeTransport, make_request +from tests.server import FakeChannel, FakeTransport, make_request from tests.test_utils import SMALL_PNG logger = logging.getLogger(__name__) @@ -56,6 +56,16 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): conf["federation_custom_ca_list"] = [get_test_ca_cert_file()] return conf + def make_worker_hs( + self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any + ) -> HomeServer: + worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs) + # Force the media paths onto the replication resource. + worker_hs.get_media_repository_resource().register_servlets( + self._hs_to_site[worker_hs].resource, worker_hs + ) + return worker_hs + def _get_media_req( self, hs: HomeServer, target: str, media_id: str ) -> Tuple[FakeChannel, Request]: @@ -68,12 +78,11 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): The channel for the *client* request and the *outbound* request for the media which the caller should respond to. """ - resource = hs.get_media_repository_resource().children[b"download"] channel = make_request( self.reactor, - FakeSite(resource, self.reactor), + self._hs_to_site[hs], "GET", - f"/{target}/{media_id}", + f"/_matrix/media/r0/download/{target}/{media_id}", shorthand=False, access_token=self.access_token, await_result=False, diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 359d131b37..8646b2f0fd 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py
@@ -13,10 +13,12 @@ # limitations under the License. import urllib.parse +from typing import Dict from parameterized import parameterized from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource import synapse.rest.admin from synapse.http.server import JsonResource @@ -26,7 +28,6 @@ from synapse.server import HomeServer from synapse.util import Clock from tests import unittest -from tests.server import FakeSite, make_request from tests.test_utils import SMALL_PNG @@ -55,21 +56,18 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # Allow for uploading and downloading to/from the media repo - self.media_repo = hs.get_media_repository_resource() - self.download_resource = self.media_repo.children[b"download"] - self.upload_resource = self.media_repo.children[b"upload"] + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources def _ensure_quarantined( self, admin_user_tok: str, server_and_media_id: str ) -> None: """Ensure a piece of media is quarantined when trying to access it.""" - channel = make_request( - self.reactor, - FakeSite(self.download_resource, self.reactor), + channel = self.make_request( "GET", - server_and_media_id, + f"/_matrix/media/v3/download/{server_and_media_id}", shorthand=False, access_token=admin_user_tok, ) @@ -117,20 +115,16 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): non_admin_user_tok = self.login("id_nonadmin", "pass") # Upload some media into the room - response = self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=admin_user_tok - ) + response = self.helper.upload_media(SMALL_PNG, tok=admin_user_tok) # Extract media ID from the response server_name_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' server_name, media_id = server_name_and_media_id.split("/") # Attempt to access the media - channel = make_request( - self.reactor, - FakeSite(self.download_resource, self.reactor), + channel = self.make_request( "GET", - server_name_and_media_id, + f"/_matrix/media/v3/download/{server_name_and_media_id}", shorthand=False, access_token=non_admin_user_tok, ) @@ -173,12 +167,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self.helper.join(room_id, non_admin_user, tok=non_admin_user_tok) # Upload some media - response_1 = self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=non_admin_user_tok - ) - response_2 = self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=non_admin_user_tok - ) + response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok) + response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok) # Extract mxcs mxc_1 = response_1["content_uri"] @@ -227,12 +217,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): non_admin_user_tok = self.login("user_nonadmin", "pass") # Upload some media - response_1 = self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=non_admin_user_tok - ) - response_2 = self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=non_admin_user_tok - ) + response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok) + response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok) # Extract media IDs server_and_media_id_1 = response_1["content_uri"][6:] @@ -265,12 +251,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): non_admin_user_tok = self.login("user_nonadmin", "pass") # Upload some media - response_1 = self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=non_admin_user_tok - ) - response_2 = self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=non_admin_user_tok - ) + response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok) + response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok) # Extract media IDs server_and_media_id_1 = response_1["content_uri"][6:] @@ -304,11 +286,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self._ensure_quarantined(admin_user_tok, server_and_media_id_1) # Attempt to access each piece of media - channel = make_request( - self.reactor, - FakeSite(self.download_resource, self.reactor), + channel = self.make_request( "GET", - server_and_media_id_2, + f"/_matrix/media/v3/download/{server_and_media_id_2}", shorthand=False, access_token=non_admin_user_tok, ) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 6d04911d67..278808abb5 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py
@@ -13,10 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from typing import Dict from parameterized import parameterized from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource import synapse.rest.admin from synapse.api.errors import Codes @@ -26,22 +28,27 @@ from synapse.server import HomeServer from synapse.util import Clock from tests import unittest -from tests.server import FakeSite, make_request from tests.test_utils import SMALL_PNG VALID_TIMESTAMP = 1609459200000 # 2021-01-01 in milliseconds INVALID_TIMESTAMP_IN_S = 1893456000 # 2030-01-01 in seconds -class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): +class _AdminMediaTests(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, login.register_servlets, ] + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + +class DeleteMediaByIDTestCase(_AdminMediaTests): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname self.admin_user = self.register_user("admin", "pass", admin=True) @@ -117,12 +124,8 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): Tests that delete a media is successfully """ - download_resource = self.media_repo.children[b"download"] - upload_resource = self.media_repo.children[b"upload"] - # Upload some media into the room response = self.helper.upload_media( - upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200, @@ -134,11 +137,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.assertEqual(server_name, self.server_name) # Attempt to access media - channel = make_request( - self.reactor, - FakeSite(download_resource, self.reactor), + channel = self.make_request( "GET", - server_and_media_id, + f"/_matrix/media/v3/download/{server_and_media_id}", shorthand=False, access_token=self.admin_user_tok, ) @@ -173,11 +174,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): ) # Attempt to access media - channel = make_request( - self.reactor, - FakeSite(download_resource, self.reactor), + channel = self.make_request( "GET", - server_and_media_id, + f"/_matrix/media/v3/download/{server_and_media_id}", shorthand=False, access_token=self.admin_user_tok, ) @@ -194,7 +193,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(os.path.exists(local_path)) -class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): +class DeleteMediaByDateSizeTestCase(_AdminMediaTests): servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -529,11 +528,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): """ Create a media and return media_id and server_and_media_id """ - upload_resource = self.media_repo.children[b"upload"] - # Upload some media into the room response = self.helper.upload_media( - upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200, @@ -553,16 +549,12 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): """ Try to access a media and check the result """ - download_resource = self.media_repo.children[b"download"] - media_id = server_and_media_id.split("/")[1] local_path = self.filepaths.local_media_filepath(media_id) - channel = make_request( - self.reactor, - FakeSite(download_resource, self.reactor), + channel = self.make_request( "GET", - server_and_media_id, + f"/_matrix/media/v3/download/{server_and_media_id}", shorthand=False, access_token=self.admin_user_tok, ) @@ -591,27 +583,16 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertFalse(os.path.exists(local_path)) -class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - synapse.rest.admin.register_servlets_for_media_repo, - login.register_servlets, - ] - +class QuarantineMediaByIDTestCase(_AdminMediaTests): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - media_repo = hs.get_media_repository_resource() self.store = hs.get_datastores().main self.server_name = hs.hostname self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - # Create media - upload_resource = media_repo.children[b"upload"] - # Upload some media into the room response = self.helper.upload_media( - upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200, @@ -720,26 +701,16 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(media_info["quarantined_by"]) -class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - synapse.rest.admin.register_servlets_for_media_repo, - login.register_servlets, - ] - +class ProtectMediaByIDTestCase(_AdminMediaTests): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - media_repo = hs.get_media_repository_resource() + hs.get_media_repository_resource() self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - # Create media - upload_resource = media_repo.children[b"upload"] - # Upload some media into the room response = self.helper.upload_media( - upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200, @@ -816,7 +787,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(media_info["safe_from_quarantine"]) -class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): +class PurgeMediaCacheTestCase(_AdminMediaTests): servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index b60f16b914..cd8ee274d8 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py
@@ -12,9 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import Dict, List, Optional from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource import synapse.rest.admin from synapse.api.errors import Codes @@ -34,8 +35,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.media_repo = hs.get_media_repository_resource() - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -44,6 +43,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/statistics/users/media" + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + def test_no_auth(self) -> None: """ Try to list users without authentication. @@ -470,12 +474,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): user_token: Access token of the user number_media: Number of media to be created for the user """ - upload_resource = self.media_repo.children[b"upload"] for _ in range(number_media): # Upload some media into the room - self.helper.upload_media( - upload_resource, SMALL_PNG, tok=user_token, expect_code=200 - ) + self.helper.upload_media(SMALL_PNG, tok=user_token, expect_code=200) def _check_fields(self, content: List[JsonDict]) -> None: """Checks that all attributes are present in content diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index b326ad2c90..37f37a09d8 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -17,12 +17,13 @@ import hmac import os import urllib.parse from binascii import unhexlify -from typing import List, Optional +from typing import Dict, List, Optional from unittest.mock import AsyncMock, Mock, patch from parameterized import parameterized, parameterized_class from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource import synapse.rest.admin from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes @@ -45,7 +46,6 @@ from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from tests import unittest -from tests.server import FakeSite, make_request from tests.test_utils import SMALL_PNG from tests.unittest import override_config @@ -3421,7 +3421,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - self.media_repo = hs.get_media_repository_resource() self.filepaths = MediaFilePaths(hs.config.media.media_store_path) self.admin_user = self.register_user("admin", "pass", admin=True) @@ -3432,6 +3431,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.other_user ) + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + @parameterized.expand(["GET", "DELETE"]) def test_no_auth(self, method: str) -> None: """Try to list media of an user without authentication.""" @@ -3907,12 +3911,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): Returns: The ID of the newly created media. """ - upload_resource = self.media_repo.children[b"upload"] - download_resource = self.media_repo.children[b"download"] - # Upload some media into the room response = self.helper.upload_media( - upload_resource, image_data, user_token, filename, expect_code=200 + image_data, user_token, filename, expect_code=200 ) # Extract media ID from the response @@ -3920,11 +3921,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): media_id = server_and_media_id.split("/")[1] # Try to access a media and to create `last_access_ts` - channel = make_request( - self.reactor, - FakeSite(download_resource, self.reactor), + channel = self.make_request( "GET", - server_and_media_id, + f"/_matrix/media/v3/download/{server_and_media_id}", shorthand=False, access_token=user_token, ) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 7627823d3f..aaa4f3bba0 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py
@@ -1447,6 +1447,30 @@ class RoomJoinRatelimitTestCase(RoomBase): @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} ) + def test_join_attempts_local_ratelimit(self) -> None: + """Tests that unsuccessful joins that end up being denied are rate-limited.""" + # Create 4 rooms + room_ids = [ + self.helper.create_room_as(self.user_id, is_public=True) for _ in range(4) + ] + # Pre-emptively ban the user who will attempt to join. + joiner_user_id = self.register_user("joiner", "secret") + for room_id in room_ids: + self.helper.ban(room_id, self.user_id, joiner_user_id) + + # Now make a new user try to join some of them. + # The user can make 3 requests, each of which should be denied. + for room_id in room_ids[0:3]: + self.helper.join(room_id, joiner_user_id, expect_code=HTTPStatus.FORBIDDEN) + + # The fourth attempt should be rate limited. + self.helper.join( + room_ids[3], joiner_user_id, expect_code=HTTPStatus.TOO_MANY_REQUESTS + ) + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) def test_join_local_ratelimit_profile_change(self) -> None: """Tests that sending a profile update into all of the user's joined rooms isn't rate-limited by the rate-limiter on joins.""" diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 9532e5ddc1..465b696c0b 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py
@@ -37,7 +37,6 @@ import attr from typing_extensions import Literal from twisted.test.proto_helpers import MemoryReactorClock -from twisted.web.resource import Resource from twisted.web.server import Site from synapse.api.constants import Membership @@ -45,7 +44,7 @@ from synapse.api.errors import Codes from synapse.server import HomeServer from synapse.types import JsonDict -from tests.server import FakeChannel, FakeSite, make_request +from tests.server import FakeChannel, make_request from tests.test_utils.html_parsers import TestHtmlParser from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer @@ -558,7 +557,6 @@ class RestHelper: def upload_media( self, - resource: Resource, image_data: bytes, tok: str, filename: str = "test.png", @@ -576,7 +574,7 @@ class RestHelper: path = "/_matrix/media/r0/upload?filename=%s" % (filename,) channel = make_request( self.reactor, - FakeSite(resource, self.reactor), + self.site, "POST", path, content=image_data, diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py
index 05d5e39cab..24459c6af4 100644 --- a/tests/rest/media/test_url_preview.py +++ b/tests/rest/media/test_url_preview.py
@@ -24,10 +24,10 @@ from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import IAddress, IResolutionReceiver from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor +from twisted.web.resource import Resource from synapse.config.oembed import OEmbedEndpointConfig from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS -from synapse.rest.media.media_repository_resource import MediaRepositoryResource from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -117,8 +117,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository() - media_repo_resource = hs.get_media_repository_resource() - self.preview_url = media_repo_resource.children[b"preview_url"] + assert self.media_repo.url_previewer is not None + self.url_previewer = self.media_repo.url_previewer self.lookups: Dict[str, Any] = {} @@ -143,8 +143,15 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.reactor.nameResolver = Resolver() # type: ignore[assignment] - def create_test_resource(self) -> MediaRepositoryResource: - return self.hs.get_media_repository_resource() + def create_resource_dict(self) -> Dict[str, Resource]: + """Create a resource tree for the test server + + A resource tree is a mapping from path to twisted.web.resource. + + The default implementation creates a JsonResource and calls each function in + `servlets` to register servlets against it. + """ + return {"/_matrix/media": self.hs.get_media_repository_resource()} def _assert_small_png(self, json_body: JsonDict) -> None: """Assert properties from the SMALL_PNG test image.""" @@ -159,7 +166,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -183,7 +190,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Check the cache returns the correct response channel = self.make_request( - "GET", "preview_url?url=http://matrix.org", shorthand=False + "GET", + "/_matrix/media/v3/preview_url?url=http://matrix.org", + shorthand=False, ) # Check the cache response has the same content @@ -193,13 +202,15 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) # Clear the in-memory cache - self.assertIn("http://matrix.org", self.preview_url._url_previewer._cache) - self.preview_url._url_previewer._cache.pop("http://matrix.org") - self.assertNotIn("http://matrix.org", self.preview_url._url_previewer._cache) + self.assertIn("http://matrix.org", self.url_previewer._cache) + self.url_previewer._cache.pop("http://matrix.org") + self.assertNotIn("http://matrix.org", self.url_previewer._cache) # Check the database cache returns the correct response channel = self.make_request( - "GET", "preview_url?url=http://matrix.org", shorthand=False + "GET", + "/_matrix/media/v3/preview_url?url=http://matrix.org", + shorthand=False, ) # Check the cache response has the same content @@ -221,7 +232,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -251,7 +262,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -287,7 +298,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -328,7 +339,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -363,7 +374,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -396,7 +407,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://example.com", + "/_matrix/media/v3/preview_url?url=http://example.com", shorthand=False, await_result=False, ) @@ -425,7 +436,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")] channel = self.make_request( - "GET", "preview_url?url=http://example.com", shorthand=False + "GET", + "/_matrix/media/v3/preview_url?url=http://example.com", + shorthand=False, ) # No requests made. @@ -446,7 +459,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")] channel = self.make_request( - "GET", "preview_url?url=http://example.com", shorthand=False + "GET", + "/_matrix/media/v3/preview_url?url=http://example.com", + shorthand=False, ) self.assertEqual(channel.code, 502) @@ -463,7 +478,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): Blocked IP addresses, accessed directly, are not spidered. """ channel = self.make_request( - "GET", "preview_url?url=http://192.168.1.1", shorthand=False + "GET", + "/_matrix/media/v3/preview_url?url=http://192.168.1.1", + shorthand=False, ) # No requests made. @@ -479,7 +496,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): Blocked IP ranges, accessed directly, are not spidered. """ channel = self.make_request( - "GET", "preview_url?url=http://1.1.1.2", shorthand=False + "GET", "/_matrix/media/v3/preview_url?url=http://1.1.1.2", shorthand=False ) self.assertEqual(channel.code, 403) @@ -497,7 +514,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://example.com", + "/_matrix/media/v3/preview_url?url=http://example.com", shorthand=False, await_result=False, ) @@ -533,7 +550,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): ] channel = self.make_request( - "GET", "preview_url?url=http://example.com", shorthand=False + "GET", + "/_matrix/media/v3/preview_url?url=http://example.com", + shorthand=False, ) self.assertEqual(channel.code, 502) self.assertEqual( @@ -553,7 +572,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): ] channel = self.make_request( - "GET", "preview_url?url=http://example.com", shorthand=False + "GET", + "/_matrix/media/v3/preview_url?url=http://example.com", + shorthand=False, ) # No requests made. @@ -574,7 +595,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["example.com"] = [(IPv6Address, "2001:800::1")] channel = self.make_request( - "GET", "preview_url?url=http://example.com", shorthand=False + "GET", + "/_matrix/media/v3/preview_url?url=http://example.com", + shorthand=False, ) self.assertEqual(channel.code, 502) @@ -591,10 +614,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): OPTIONS returns the OPTIONS. """ channel = self.make_request( - "OPTIONS", "preview_url?url=http://example.com", shorthand=False + "OPTIONS", + "/_matrix/media/v3/preview_url?url=http://example.com", + shorthand=False, ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body, {}) + self.assertEqual(channel.code, 204) def test_accept_language_config_option(self) -> None: """ @@ -605,7 +629,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Build and make a request to the server channel = self.make_request( "GET", - "preview_url?url=http://example.com", + "/_matrix/media/v3/preview_url?url=http://example.com", shorthand=False, await_result=False, ) @@ -658,7 +682,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -708,7 +732,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -750,7 +774,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -790,7 +814,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -831,7 +855,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - f"preview_url?{query_params}", + f"/_matrix/media/v3/preview_url?{query_params}", shorthand=False, ) self.pump() @@ -852,7 +876,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -889,7 +913,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://twitter.com/matrixdotorg/status/12345", + "/_matrix/media/v3/preview_url?url=http://twitter.com/matrixdotorg/status/12345", shorthand=False, await_result=False, ) @@ -949,7 +973,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://twitter.com/matrixdotorg/status/12345", + "/_matrix/media/v3/preview_url?url=http://twitter.com/matrixdotorg/status/12345", shorthand=False, await_result=False, ) @@ -998,7 +1022,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://www.hulu.com/watch/12345", + "/_matrix/media/v3/preview_url?url=http://www.hulu.com/watch/12345", shorthand=False, await_result=False, ) @@ -1043,7 +1067,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://twitter.com/matrixdotorg/status/12345", + "/_matrix/media/v3/preview_url?url=http://twitter.com/matrixdotorg/status/12345", shorthand=False, await_result=False, ) @@ -1072,7 +1096,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", + "/_matrix/media/v3/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", shorthand=False, await_result=False, ) @@ -1164,7 +1188,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", + "/_matrix/media/v3/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", shorthand=False, await_result=False, ) @@ -1205,7 +1229,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://cdn.twitter.com/matrixdotorg", + "/_matrix/media/v3/preview_url?url=http://cdn.twitter.com/matrixdotorg", shorthand=False, await_result=False, ) @@ -1247,7 +1271,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Check fetching channel = self.make_request( "GET", - f"download/{host}/{media_id}", + f"/_matrix/media/v3/download/{host}/{media_id}", shorthand=False, await_result=False, ) @@ -1260,7 +1284,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - f"download/{host}/{media_id}", + f"/_matrix/media/v3/download/{host}/{media_id}", shorthand=False, await_result=False, ) @@ -1295,7 +1319,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Check fetching channel = self.make_request( "GET", - f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale", + f"/_matrix/media/v3/thumbnail/{host}/{media_id}?width=32&height=32&method=scale", shorthand=False, await_result=False, ) @@ -1313,7 +1337,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale", + f"/_matrix/media/v3/thumbnail/{host}/{media_id}?width=32&height=32&method=scale", shorthand=False, await_result=False, ) @@ -1343,7 +1367,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertTrue(os.path.isdir(thumbnail_dir)) self.reactor.advance(IMAGE_CACHE_EXPIRY_MS * 1000 + 1) - self.get_success(self.preview_url._url_previewer._expire_url_cache_data()) + self.get_success(self.url_previewer._expire_url_cache_data()) for path in [file_path] + file_dirs + [thumbnail_dir] + thumbnail_dirs: self.assertFalse( @@ -1363,7 +1387,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=" + bad_url, + "/_matrix/media/v3/preview_url?url=" + bad_url, shorthand=False, await_result=False, ) @@ -1372,7 +1396,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=" + good_url, + "/_matrix/media/v3/preview_url?url=" + good_url, shorthand=False, await_result=False, ) @@ -1404,7 +1428,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=" + bad_url, + "/_matrix/media/v3/preview_url?url=" + bad_url, shorthand=False, await_result=False, ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 6b9692c486..0c054a598f 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py
@@ -24,7 +24,10 @@ import synapse.rest.admin from synapse.http.site import XForwardedForRequest from synapse.rest.client import login from synapse.server import HomeServer -from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY +from synapse.storage.databases.main.client_ips import ( + LAST_SEEN_GRANULARITY, + DeviceLastConnectionInfo, +) from synapse.types import UserID from synapse.util import Clock @@ -65,15 +68,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) r = result[(user_id, device_id)] - self.assertLessEqual( - { - "user_id": user_id, - "device_id": device_id, - "ip": "ip", - "user_agent": "user_agent", - "last_seen": 12345678000, - }.items(), - r.items(), + self.assertEqual( + DeviceLastConnectionInfo( + user_id=user_id, + device_id=device_id, + ip="ip", + user_agent="user_agent", + last_seen=12345678000, + ), + r, ) def test_insert_new_client_ip_none_device_id(self) -> None: @@ -201,13 +204,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.assertEqual( result, { - (user_id, device_id): { - "user_id": user_id, - "device_id": device_id, - "ip": "ip", - "user_agent": "user_agent", - "last_seen": 12345678000, - }, + (user_id, device_id): DeviceLastConnectionInfo( + user_id=user_id, + device_id=device_id, + ip="ip", + user_agent="user_agent", + last_seen=12345678000, + ), }, ) @@ -292,20 +295,20 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.assertEqual( result, { - (user_id, device_id_1): { - "user_id": user_id, - "device_id": device_id_1, - "ip": "ip_1", - "user_agent": "user_agent_1", - "last_seen": 12345678000, - }, - (user_id, device_id_2): { - "user_id": user_id, - "device_id": device_id_2, - "ip": "ip_2", - "user_agent": "user_agent_3", - "last_seen": 12345688000 + LAST_SEEN_GRANULARITY, - }, + (user_id, device_id_1): DeviceLastConnectionInfo( + user_id=user_id, + device_id=device_id_1, + ip="ip_1", + user_agent="user_agent_1", + last_seen=12345678000, + ), + (user_id, device_id_2): DeviceLastConnectionInfo( + user_id=user_id, + device_id=device_id_2, + ip="ip_2", + user_agent="user_agent_3", + last_seen=12345688000 + LAST_SEEN_GRANULARITY, + ), }, ) @@ -526,15 +529,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) r = result[(user_id, device_id)] - self.assertLessEqual( - { - "user_id": user_id, - "device_id": device_id, - "ip": None, - "user_agent": None, - "last_seen": None, - }.items(), - r.items(), + self.assertEqual( + DeviceLastConnectionInfo( + user_id=user_id, + device_id=device_id, + ip=None, + user_agent=None, + last_seen=None, + ), + r, ) # Register the background update to run again. @@ -561,15 +564,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) r = result[(user_id, device_id)] - self.assertLessEqual( - { - "user_id": user_id, - "device_id": device_id, - "ip": "ip", - "user_agent": "user_agent", - "last_seen": 0, - }.items(), - r.items(), + self.assertEqual( + DeviceLastConnectionInfo( + user_id=user_id, + device_id=device_id, + ip="ip", + user_agent="user_agent", + last_seen=0, + ), + r, ) def test_old_user_ips_pruned(self) -> None: @@ -640,15 +643,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) r = result2[(user_id, device_id)] - self.assertLessEqual( - { - "user_id": user_id, - "device_id": device_id, - "ip": "ip", - "user_agent": "user_agent", - "last_seen": 0, - }.items(), - r.items(), + self.assertEqual( + DeviceLastConnectionInfo( + user_id=user_id, + device_id=device_id, + ip="ip", + user_agent="user_agent", + last_seen=0, + ), + r, ) def test_invalid_user_agents_are_ignored(self) -> None: @@ -777,13 +780,13 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): self.store.get_last_client_ip_by_device(self.user_id, device_id) ) r = result[(self.user_id, device_id)] - self.assertLessEqual( - { - "user_id": self.user_id, - "device_id": device_id, - "ip": expected_ip, - "user_agent": "Mozzila pizza", - "last_seen": 123456100, - }.items(), - r.items(), + self.assertEqual( + DeviceLastConnectionInfo( + user_id=self.user_id, + device_id=device_id, + ip=expected_ip, + user_agent="Mozzila pizza", + last_seen=123456100, + ), + r, ) diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index b55dd07f14..2f6499966c 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py
@@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Set, Tuple +from typing import Dict, List, Set, Tuple, cast from twisted.test.proto_helpers import MemoryReactor from twisted.trial import unittest @@ -421,41 +421,53 @@ class EventChainStoreTestCase(HomeserverTestCase): self, events: List[EventBase] ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]: # Fetch the map from event ID -> (chain ID, sequence number) - rows = self.get_success( - self.store.db_pool.simple_select_many_batch( - table="event_auth_chains", - column="event_id", - iterable=[e.event_id for e in events], - retcols=("event_id", "chain_id", "sequence_number"), - keyvalues={}, - ) + rows = cast( + List[Tuple[str, int, int]], + self.get_success( + self.store.db_pool.simple_select_many_batch( + table="event_auth_chains", + column="event_id", + iterable=[e.event_id for e in events], + retcols=("event_id", "chain_id", "sequence_number"), + keyvalues={}, + ) + ), ) chain_map = { - row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows + event_id: (chain_id, sequence_number) + for event_id, chain_id, sequence_number in rows } # Fetch all the links and pass them to the _LinkMap. - rows = self.get_success( - self.store.db_pool.simple_select_many_batch( - table="event_auth_chain_links", - column="origin_chain_id", - iterable=[chain_id for chain_id, _ in chain_map.values()], - retcols=( - "origin_chain_id", - "origin_sequence_number", - "target_chain_id", - "target_sequence_number", - ), - keyvalues={}, - ) + auth_chain_rows = cast( + List[Tuple[int, int, int, int]], + self.get_success( + self.store.db_pool.simple_select_many_batch( + table="event_auth_chain_links", + column="origin_chain_id", + iterable=[chain_id for chain_id, _ in chain_map.values()], + retcols=( + "origin_chain_id", + "origin_sequence_number", + "target_chain_id", + "target_sequence_number", + ), + keyvalues={}, + ) + ), ) link_map = _LinkMap() - for row in rows: + for ( + origin_chain_id, + origin_sequence_number, + target_chain_id, + target_sequence_number, + ) in auth_chain_rows: added = link_map.add_link( - (row["origin_chain_id"], row["origin_sequence_number"]), - (row["target_chain_id"], row["target_sequence_number"]), + (origin_chain_id, origin_sequence_number), + (target_chain_id, target_sequence_number), ) # We shouldn't have persisted any redundant links diff --git a/tests/unittest.py b/tests/unittest.py
index dbaff361b4..99ad02eb06 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -60,7 +60,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.federation.transport.server import TransportLayerServer -from synapse.http.server import JsonResource +from synapse.http.server import JsonResource, OptionsResource from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import ( SENTINEL_CONTEXT, @@ -459,7 +459,7 @@ class HomeserverTestCase(TestCase): The default calls `self.create_resource_dict` and builds the resultant dict into a tree. """ - root_resource = Resource() + root_resource = OptionsResource() create_resource_tree(self.create_resource_dict(), root_resource) return root_resource diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 4bcd17a6fc..ad88b24566 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py
@@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock + +from synapse.notifier import Notifier +from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from tests.unittest import HomeserverTestCase @@ -109,6 +113,77 @@ class RetryLimiterTestCase(HomeserverTestCase): new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertIsNone(new_timings) + def test_notifier_replication(self) -> None: + """Ensure the notifier/replication client is called only when expected.""" + store = self.hs.get_datastores().main + + notifier = mock.Mock(spec=Notifier) + replication_client = mock.Mock(spec=ReplicationCommandHandler) + + limiter = self.get_success( + get_retry_limiter( + "test_dest", + self.clock, + store, + notifier=notifier, + replication_client=replication_client, + ) + ) + + # The server is already up, nothing should occur. + self.pump(1) + with limiter: + pass + self.pump() + + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) + self.assertIsNone(new_timings) + notifier.notify_remote_server_up.assert_not_called() + replication_client.send_remote_server_up.assert_not_called() + + # Attempt again, but return an error. This will cause new retry timings, but + # should not trigger server up notifications. + self.pump(1) + try: + with limiter: + raise AssertionError("argh") + except AssertionError: + pass + self.pump() + + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) + # The exact retry timings are tested separately. + self.assertIsNotNone(new_timings) + notifier.notify_remote_server_up.assert_not_called() + replication_client.send_remote_server_up.assert_not_called() + + # A second failing request should be treated as the above. + self.pump(1) + try: + with limiter: + raise AssertionError("argh") + except AssertionError: + pass + self.pump() + + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) + # The exact retry timings are tested separately. + self.assertIsNotNone(new_timings) + notifier.notify_remote_server_up.assert_not_called() + replication_client.send_remote_server_up.assert_not_called() + + # A final successful attempt should generate a server up notification. + self.pump(1) + with limiter: + pass + self.pump() + + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) + # The exact retry timings are tested separately. + self.assertIsNone(new_timings) + notifier.notify_remote_server_up.assert_called_once_with("test_dest") + replication_client.send_remote_server_up.assert_called_once_with("test_dest") + def test_max_retry_interval(self) -> None: """Test that `destination_max_retry_interval` setting works as expected""" store = self.hs.get_datastores().main