diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index 3f42f79f26..b63ef3d4ed 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -308,7 +308,7 @@ class FederationKnockingTestCase(
self.assertEqual(200, channel.code, channel.result)
# Check that we got the stripped room state in return
- room_state_events = channel.json_body["knock_state_events"]
+ room_state_events = channel.json_body["knock_room_state"]
# Validate the stripped room state events
self.check_knock_room_state_against_room_state(
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index a7e6cdd66a..867dbd6001 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -31,7 +31,7 @@ from synapse.appservice import (
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.rest.client import login, receipts, register, room, sendtodevice
from synapse.server import HomeServer
-from synapse.types import JsonDict, RoomStreamToken
+from synapse.types import JsonDict, RoomStreamToken, StreamKeyType
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -86,7 +86,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
[event],
]
)
- self.handler.notify_interested_services(RoomStreamToken(None, 1))
+ self.handler.notify_interested_services(RoomStreamToken(stream=1))
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
interested_service, events=[event]
@@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
]
)
self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
- self.handler.notify_interested_services(RoomStreamToken(None, 0))
+ self.handler.notify_interested_services(RoomStreamToken(stream=0))
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@@ -126,7 +126,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
]
)
- self.handler.notify_interested_services(RoomStreamToken(None, 0))
+ self.handler.notify_interested_services(RoomStreamToken(stream=0))
self.assertFalse(
self.mock_as_api.query_user.called,
@@ -304,7 +304,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.handler.notify_interested_services_ephemeral(
- "receipt_key", 580, ["@fakerecipient:example.com"]
+ StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
)
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
interested_service, ephemeral=[event]
@@ -332,7 +332,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.handler.notify_interested_services_ephemeral(
- "receipt_key", 580, ["@fakerecipient:example.com"]
+ StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
)
# This method will be called, but with an empty list of events
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
@@ -441,7 +441,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.hs.get_application_service_handler()._notify_interested_services(
RoomStreamToken(
- None, self.hs.get_application_service_handler().current_max
+ stream=self.hs.get_application_service_handler().current_max
)
)
)
@@ -634,7 +634,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.hs.get_application_service_handler()._notify_interested_services_ephemeral(
services=[interested_appservice],
- stream_key="receipt_key",
+ stream_key=StreamKeyType.RECEIPT,
new_token=stream_token,
users=[self.exclusive_as_user],
)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 95106ec8f3..d7025c6f2c 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -28,7 +28,7 @@ from synapse.federation.transport.server import TransportLayerServer
from synapse.handlers.typing import TypingWriterHandler
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.server import HomeServer
-from synapse.types import JsonDict, Requester, UserID, create_requester
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -174,7 +174,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
return_value=1
)
- self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[method-assign]
+ self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, [])) # type: ignore[method-assign]
self.datastore.get_to_device_stream_token = Mock( # type: ignore[method-assign]
return_value=0
@@ -203,7 +203,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls(
+ [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
+ )
self.assertEqual(self.event_source.get_current_key(), 1)
events = self.get_success(
@@ -273,7 +275,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200)
- self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls(
+ [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
+ )
self.assertEqual(self.event_source.get_current_key(), 1)
events = self.get_success(
@@ -349,7 +353,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls(
+ [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
+ )
self.mock_federation_client.put_json.assert_called_once_with(
"farm",
@@ -399,7 +405,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls(
+ [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
+ )
self.on_new_event.reset_mock()
self.assertEqual(self.event_source.get_current_key(), 1)
@@ -425,7 +433,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.reactor.pump([16])
- self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls(
+ [call(StreamKeyType.TYPING, 2, rooms=[ROOM_ID])]
+ )
self.assertEqual(self.event_source.get_current_key(), 2)
events = self.get_success(
@@ -459,7 +469,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls(
+ [call(StreamKeyType.TYPING, 3, rooms=[ROOM_ID])]
+ )
self.on_new_event.reset_mock()
self.assertEqual(self.event_source.get_current_key(), 3)
diff --git a/tests/media/test_base.py b/tests/media/test_base.py
index 119d7ba66f..144948f23c 100644
--- a/tests/media/test_base.py
+++ b/tests/media/test_base.py
@@ -42,18 +42,35 @@ class GetFileNameFromHeadersTests(unittest.TestCase):
class AddFileHeadersTests(unittest.TestCase):
TEST_CASES = {
+ # Safe values use inline.
"text/plain": b"inline; filename=file.name",
"text/csv": b"inline; filename=file.name",
"image/png": b"inline; filename=file.name",
+ # Unlisted values are set to attachment.
"text/html": b"attachment; filename=file.name",
"any/thing": b"attachment; filename=file.name",
+ # Parameters get ignored.
+ "text/plain; charset=utf-8": b"inline; filename=file.name",
+ "text/markdown; charset=utf-8; variant=CommonMark": b"attachment; filename=file.name",
+ # Parsed as lowercase.
+ "Text/Plain": b"inline; filename=file.name",
+ # Bad values don't choke.
+ "": b"attachment; filename=file.name",
+ ";": b"attachment; filename=file.name",
}
def test_content_disposition(self) -> None:
for media_type, expected in self.TEST_CASES.items():
request = Mock()
add_file_headers(request, media_type, 0, "file.name")
- request.setHeader.assert_any_call(b"Content-Disposition", expected)
+ # There should be a single call to set Content-Disposition.
+ for call in request.setHeader.call_args_list:
+ args, _ = call
+ if args[0] == b"Content-Disposition":
+ break
+ else:
+ self.fail(f"No Content-Disposition header found for {media_type}")
+ self.assertEqual(args[1], expected, media_type)
def test_no_filename(self) -> None:
request = Mock()
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index 04fc7bdcef..15f5d644e4 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -28,12 +28,13 @@ from typing_extensions import Literal
from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import Resource
from synapse.api.errors import Codes
from synapse.events import EventBase
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
-from synapse.media._base import FileInfo
+from synapse.media._base import FileInfo, ThumbnailInfo
from synapse.media.filepath import MediaFilePaths
from synapse.media.media_storage import MediaStorage, ReadableFileWrapper
from synapse.media.storage_provider import FileStorageProviderBackend
@@ -41,12 +42,13 @@ from synapse.module_api import ModuleApi
from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
from synapse.rest import admin
from synapse.rest.client import login
+from synapse.rest.media.thumbnail_resource import ThumbnailResource
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias
from synapse.util import Clock
from tests import unittest
-from tests.server import FakeChannel, FakeSite, make_request
+from tests.server import FakeChannel
from tests.test_utils import SMALL_PNG
from tests.utils import default_config
@@ -288,22 +290,22 @@ class MediaRepoTests(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- media_resource = hs.get_media_repository_resource()
- self.download_resource = media_resource.children[b"download"]
- self.thumbnail_resource = media_resource.children[b"thumbnail"]
self.store = hs.get_datastores().main
self.media_repo = hs.get_media_repository()
self.media_id = "example.com/12345"
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ resources = super().create_resource_dict()
+ resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+ return resources
+
def _req(
self, content_disposition: Optional[bytes], include_content_type: bool = True
) -> FakeChannel:
- channel = make_request(
- self.reactor,
- FakeSite(self.download_resource, self.reactor),
+ channel = self.make_request(
"GET",
- self.media_id,
+ f"/_matrix/media/v3/download/{self.media_id}",
shorthand=False,
await_result=False,
)
@@ -481,11 +483,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
# Fetching again should work, without re-requesting the image from the
# remote.
params = "?width=32&height=32&method=scale"
- channel = make_request(
- self.reactor,
- FakeSite(self.thumbnail_resource, self.reactor),
+ channel = self.make_request(
"GET",
- self.media_id + params,
+ f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
shorthand=False,
await_result=False,
)
@@ -511,11 +511,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
shutil.rmtree(thumbnail_dir, ignore_errors=True)
- channel = make_request(
- self.reactor,
- FakeSite(self.thumbnail_resource, self.reactor),
+ channel = self.make_request(
"GET",
- self.media_id + params,
+ f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
shorthand=False,
await_result=False,
)
@@ -549,11 +547,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"""
params = "?width=32&height=32&method=" + method
- channel = make_request(
- self.reactor,
- FakeSite(self.thumbnail_resource, self.reactor),
+ channel = self.make_request(
"GET",
- self.media_id + params,
+ f"/_matrix/media/r0/thumbnail/{self.media_id}{params}",
shorthand=False,
await_result=False,
)
@@ -590,7 +586,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.json_body,
{
"errcode": "M_UNKNOWN",
- "error": "Cannot find any thumbnails for the requested media ([b'example.com', b'12345']). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
+ "error": "Cannot find any thumbnails for the requested media ('/_matrix/media/r0/thumbnail/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
},
)
else:
@@ -600,7 +596,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.json_body,
{
"errcode": "M_NOT_FOUND",
- "error": "Not found [b'example.com', b'12345']",
+ "error": "Not found '/_matrix/media/r0/thumbnail/example.com/12345'",
},
)
@@ -609,34 +605,39 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"""Test that choosing between thumbnails with the same quality rating succeeds.
We are not particular about which thumbnail is chosen."""
+
+ content_type = self.test_image.content_type.decode()
+ media_repo = self.hs.get_media_repository()
+ thumbnail_resouce = ThumbnailResource(
+ self.hs, media_repo, media_repo.media_storage
+ )
+
self.assertIsNotNone(
- self.thumbnail_resource._select_thumbnail(
+ thumbnail_resouce._select_thumbnail(
desired_width=desired_size,
desired_height=desired_size,
desired_method=method,
- desired_type=self.test_image.content_type,
+ desired_type=content_type,
# Provide two identical thumbnails which are guaranteed to have the same
# quality rating.
thumbnail_infos=[
- {
- "thumbnail_width": 32,
- "thumbnail_height": 32,
- "thumbnail_method": method,
- "thumbnail_type": self.test_image.content_type,
- "thumbnail_length": 256,
- "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}",
- },
- {
- "thumbnail_width": 32,
- "thumbnail_height": 32,
- "thumbnail_method": method,
- "thumbnail_type": self.test_image.content_type,
- "thumbnail_length": 256,
- "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}",
- },
+ ThumbnailInfo(
+ width=32,
+ height=32,
+ method=method,
+ type=content_type,
+ length=256,
+ ),
+ ThumbnailInfo(
+ width=32,
+ height=32,
+ method=method,
+ type=content_type,
+ length=256,
+ ),
],
file_id=f"image{self.test_image.extension.decode()}",
- url_cache=None,
+ url_cache=False,
server_name=None,
)
)
@@ -725,13 +726,13 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase):
self.user = self.register_user("user", "pass")
self.tok = self.login("user", "pass")
- # Allow for uploading and downloading to/from the media repo
- self.media_repo = hs.get_media_repository_resource()
- self.download_resource = self.media_repo.children[b"download"]
- self.upload_resource = self.media_repo.children[b"upload"]
-
load_legacy_spam_checkers(hs)
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ resources = super().create_resource_dict()
+ resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+ return resources
+
def default_config(self) -> Dict[str, Any]:
config = default_config("test")
@@ -751,9 +752,7 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase):
def test_upload_innocent(self) -> None:
"""Attempt to upload some innocent data that should be allowed."""
- self.helper.upload_media(
- self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
- )
+ self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
def test_upload_ban(self) -> None:
"""Attempt to upload some data that includes bytes "evil", which should
@@ -762,9 +761,7 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase):
data = b"Some evil data"
- self.helper.upload_media(
- self.upload_resource, data, tok=self.tok, expect_code=400
- )
+ self.helper.upload_media(data, tok=self.tok, expect_code=400)
EVIL_DATA = b"Some evil data"
@@ -781,15 +778,15 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
self.user = self.register_user("user", "pass")
self.tok = self.login("user", "pass")
- # Allow for uploading and downloading to/from the media repo
- self.media_repo = hs.get_media_repository_resource()
- self.download_resource = self.media_repo.children[b"download"]
- self.upload_resource = self.media_repo.children[b"upload"]
-
hs.get_module_api().register_spam_checker_callbacks(
check_media_file_for_spam=self.check_media_file_for_spam
)
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ resources = super().create_resource_dict()
+ resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+ return resources
+
async def check_media_file_for_spam(
self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]:
@@ -805,21 +802,16 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
def test_upload_innocent(self) -> None:
"""Attempt to upload some innocent data that should be allowed."""
- self.helper.upload_media(
- self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
- )
+ self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
def test_upload_ban(self) -> None:
"""Attempt to upload some data that includes bytes "evil", which should
get rejected by the spam checker.
"""
- self.helper.upload_media(
- self.upload_resource, EVIL_DATA, tok=self.tok, expect_code=400
- )
+ self.helper.upload_media(EVIL_DATA, tok=self.tok, expect_code=400)
self.helper.upload_media(
- self.upload_resource,
EVIL_DATA_EXPERIMENT,
tok=self.tok,
expect_code=400,
diff --git a/tests/media/test_url_previewer.py b/tests/media/test_url_previewer.py
index 46ecde5344..04b69f378a 100644
--- a/tests/media/test_url_previewer.py
+++ b/tests/media/test_url_previewer.py
@@ -61,9 +61,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
return self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- media_repo_resource = hs.get_media_repository_resource()
- preview_url = media_repo_resource.children[b"preview_url"]
- self.url_previewer = preview_url._url_previewer
+ media_repo = hs.get_media_repository()
+ assert media_repo.url_previewer is not None
+ self.url_previewer = media_repo.url_previewer
def test_all_urls_allowed(self) -> None:
self.assertFalse(self.url_previewer._is_url_blocked("http://matrix.org"))
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 172fc3a736..1dabf52156 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -94,12 +94,12 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
self.assertEqual(len(emails), 1)
email = emails[0]
- self.assertEqual(email["medium"], "email")
- self.assertEqual(email["address"], "bob@bobinator.bob")
+ self.assertEqual(email.medium, "email")
+ self.assertEqual(email.address, "bob@bobinator.bob")
# Should these be 0?
- self.assertEqual(email["validated_at"], 0)
- self.assertEqual(email["added_at"], 0)
+ self.assertEqual(email.validated_at, 0)
+ self.assertEqual(email.added_at, 0)
# Check that the displayname was assigned
displayname = self.get_success(
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 6e78daa830..b230a6c361 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
import os
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.protocol import Factory
@@ -29,7 +29,7 @@ from synapse.util import Clock
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
+from tests.server import FakeChannel, FakeTransport, make_request
from tests.test_utils import SMALL_PNG
logger = logging.getLogger(__name__)
@@ -56,6 +56,16 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
return conf
+ def make_worker_hs(
+ self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
+ ) -> HomeServer:
+ worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs)
+ # Force the media paths onto the replication resource.
+ worker_hs.get_media_repository_resource().register_servlets(
+ self._hs_to_site[worker_hs].resource, worker_hs
+ )
+ return worker_hs
+
def _get_media_req(
self, hs: HomeServer, target: str, media_id: str
) -> Tuple[FakeChannel, Request]:
@@ -68,12 +78,11 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
The channel for the *client* request and the *outbound* request for
the media which the caller should respond to.
"""
- resource = hs.get_media_repository_resource().children[b"download"]
channel = make_request(
self.reactor,
- FakeSite(resource, self.reactor),
+ self._hs_to_site[hs],
"GET",
- f"/{target}/{media_id}",
+ f"/_matrix/media/r0/download/{target}/{media_id}",
shorthand=False,
access_token=self.access_token,
await_result=False,
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 359d131b37..8646b2f0fd 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -13,10 +13,12 @@
# limitations under the License.
import urllib.parse
+from typing import Dict
from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import Resource
import synapse.rest.admin
from synapse.http.server import JsonResource
@@ -26,7 +28,6 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
-from tests.server import FakeSite, make_request
from tests.test_utils import SMALL_PNG
@@ -55,21 +56,18 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- # Allow for uploading and downloading to/from the media repo
- self.media_repo = hs.get_media_repository_resource()
- self.download_resource = self.media_repo.children[b"download"]
- self.upload_resource = self.media_repo.children[b"upload"]
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ resources = super().create_resource_dict()
+ resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+ return resources
def _ensure_quarantined(
self, admin_user_tok: str, server_and_media_id: str
) -> None:
"""Ensure a piece of media is quarantined when trying to access it."""
- channel = make_request(
- self.reactor,
- FakeSite(self.download_resource, self.reactor),
+ channel = self.make_request(
"GET",
- server_and_media_id,
+ f"/_matrix/media/v3/download/{server_and_media_id}",
shorthand=False,
access_token=admin_user_tok,
)
@@ -117,20 +115,16 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
non_admin_user_tok = self.login("id_nonadmin", "pass")
# Upload some media into the room
- response = self.helper.upload_media(
- self.upload_resource, SMALL_PNG, tok=admin_user_tok
- )
+ response = self.helper.upload_media(SMALL_PNG, tok=admin_user_tok)
# Extract media ID from the response
server_name_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
server_name, media_id = server_name_and_media_id.split("/")
# Attempt to access the media
- channel = make_request(
- self.reactor,
- FakeSite(self.download_resource, self.reactor),
+ channel = self.make_request(
"GET",
- server_name_and_media_id,
+ f"/_matrix/media/v3/download/{server_name_and_media_id}",
shorthand=False,
access_token=non_admin_user_tok,
)
@@ -173,12 +167,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.helper.join(room_id, non_admin_user, tok=non_admin_user_tok)
# Upload some media
- response_1 = self.helper.upload_media(
- self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
- )
- response_2 = self.helper.upload_media(
- self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
- )
+ response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
+ response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
# Extract mxcs
mxc_1 = response_1["content_uri"]
@@ -227,12 +217,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
non_admin_user_tok = self.login("user_nonadmin", "pass")
# Upload some media
- response_1 = self.helper.upload_media(
- self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
- )
- response_2 = self.helper.upload_media(
- self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
- )
+ response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
+ response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
# Extract media IDs
server_and_media_id_1 = response_1["content_uri"][6:]
@@ -265,12 +251,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
non_admin_user_tok = self.login("user_nonadmin", "pass")
# Upload some media
- response_1 = self.helper.upload_media(
- self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
- )
- response_2 = self.helper.upload_media(
- self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
- )
+ response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
+ response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
# Extract media IDs
server_and_media_id_1 = response_1["content_uri"][6:]
@@ -304,11 +286,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
# Attempt to access each piece of media
- channel = make_request(
- self.reactor,
- FakeSite(self.download_resource, self.reactor),
+ channel = self.make_request(
"GET",
- server_and_media_id_2,
+ f"/_matrix/media/v3/download/{server_and_media_id_2}",
shorthand=False,
access_token=non_admin_user_tok,
)
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 6d04911d67..278808abb5 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+from typing import Dict
from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import Resource
import synapse.rest.admin
from synapse.api.errors import Codes
@@ -26,22 +28,27 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
-from tests.server import FakeSite, make_request
from tests.test_utils import SMALL_PNG
VALID_TIMESTAMP = 1609459200000 # 2021-01-01 in milliseconds
INVALID_TIMESTAMP_IN_S = 1893456000 # 2030-01-01 in seconds
-class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
+class _AdminMediaTests(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
login.register_servlets,
]
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ resources = super().create_resource_dict()
+ resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+ return resources
+
+
+class DeleteMediaByIDTestCase(_AdminMediaTests):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.media_repo = hs.get_media_repository_resource()
self.server_name = hs.hostname
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -117,12 +124,8 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
Tests that delete a media is successfully
"""
- download_resource = self.media_repo.children[b"download"]
- upload_resource = self.media_repo.children[b"upload"]
-
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource,
SMALL_PNG,
tok=self.admin_user_tok,
expect_code=200,
@@ -134,11 +137,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertEqual(server_name, self.server_name)
# Attempt to access media
- channel = make_request(
- self.reactor,
- FakeSite(download_resource, self.reactor),
+ channel = self.make_request(
"GET",
- server_and_media_id,
+ f"/_matrix/media/v3/download/{server_and_media_id}",
shorthand=False,
access_token=self.admin_user_tok,
)
@@ -173,11 +174,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
)
# Attempt to access media
- channel = make_request(
- self.reactor,
- FakeSite(download_resource, self.reactor),
+ channel = self.make_request(
"GET",
- server_and_media_id,
+ f"/_matrix/media/v3/download/{server_and_media_id}",
shorthand=False,
access_token=self.admin_user_tok,
)
@@ -194,7 +193,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(os.path.exists(local_path))
-class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
+class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
@@ -529,11 +528,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
"""
Create a media and return media_id and server_and_media_id
"""
- upload_resource = self.media_repo.children[b"upload"]
-
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource,
SMALL_PNG,
tok=self.admin_user_tok,
expect_code=200,
@@ -553,16 +549,12 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
"""
Try to access a media and check the result
"""
- download_resource = self.media_repo.children[b"download"]
-
media_id = server_and_media_id.split("/")[1]
local_path = self.filepaths.local_media_filepath(media_id)
- channel = make_request(
- self.reactor,
- FakeSite(download_resource, self.reactor),
+ channel = self.make_request(
"GET",
- server_and_media_id,
+ f"/_matrix/media/v3/download/{server_and_media_id}",
shorthand=False,
access_token=self.admin_user_tok,
)
@@ -591,27 +583,16 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertFalse(os.path.exists(local_path))
-class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets,
- synapse.rest.admin.register_servlets_for_media_repo,
- login.register_servlets,
- ]
-
+class QuarantineMediaByIDTestCase(_AdminMediaTests):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- media_repo = hs.get_media_repository_resource()
self.store = hs.get_datastores().main
self.server_name = hs.hostname
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- # Create media
- upload_resource = media_repo.children[b"upload"]
-
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource,
SMALL_PNG,
tok=self.admin_user_tok,
expect_code=200,
@@ -720,26 +701,16 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(media_info["quarantined_by"])
-class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets,
- synapse.rest.admin.register_servlets_for_media_repo,
- login.register_servlets,
- ]
-
+class ProtectMediaByIDTestCase(_AdminMediaTests):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- media_repo = hs.get_media_repository_resource()
+ hs.get_media_repository_resource()
self.store = hs.get_datastores().main
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- # Create media
- upload_resource = media_repo.children[b"upload"]
-
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource,
SMALL_PNG,
tok=self.admin_user_tok,
expect_code=200,
@@ -816,7 +787,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(media_info["safe_from_quarantine"])
-class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
+class PurgeMediaCacheTestCase(_AdminMediaTests):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index b60f16b914..cd8ee274d8 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional
+from typing import Dict, List, Optional
from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import Resource
import synapse.rest.admin
from synapse.api.errors import Codes
@@ -34,8 +35,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.media_repo = hs.get_media_repository_resource()
-
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -44,6 +43,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/statistics/users/media"
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ resources = super().create_resource_dict()
+ resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+ return resources
+
def test_no_auth(self) -> None:
"""
Try to list users without authentication.
@@ -470,12 +474,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
user_token: Access token of the user
number_media: Number of media to be created for the user
"""
- upload_resource = self.media_repo.children[b"upload"]
for _ in range(number_media):
# Upload some media into the room
- self.helper.upload_media(
- upload_resource, SMALL_PNG, tok=user_token, expect_code=200
- )
+ self.helper.upload_media(SMALL_PNG, tok=user_token, expect_code=200)
def _check_fields(self, content: List[JsonDict]) -> None:
"""Checks that all attributes are present in content
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index b326ad2c90..37f37a09d8 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -17,12 +17,13 @@ import hmac
import os
import urllib.parse
from binascii import unhexlify
-from typing import List, Optional
+from typing import Dict, List, Optional
from unittest.mock import AsyncMock, Mock, patch
from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import Resource
import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
@@ -45,7 +46,6 @@ from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from tests import unittest
-from tests.server import FakeSite, make_request
from tests.test_utils import SMALL_PNG
from tests.unittest import override_config
@@ -3421,7 +3421,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- self.media_repo = hs.get_media_repository_resource()
self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -3432,6 +3431,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.other_user
)
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ resources = super().create_resource_dict()
+ resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+ return resources
+
@parameterized.expand(["GET", "DELETE"])
def test_no_auth(self, method: str) -> None:
"""Try to list media of an user without authentication."""
@@ -3907,12 +3911,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
Returns:
The ID of the newly created media.
"""
- upload_resource = self.media_repo.children[b"upload"]
- download_resource = self.media_repo.children[b"download"]
-
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, image_data, user_token, filename, expect_code=200
+ image_data, user_token, filename, expect_code=200
)
# Extract media ID from the response
@@ -3920,11 +3921,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
media_id = server_and_media_id.split("/")[1]
# Try to access a media and to create `last_access_ts`
- channel = make_request(
- self.reactor,
- FakeSite(download_resource, self.reactor),
+ channel = self.make_request(
"GET",
- server_and_media_id,
+ f"/_matrix/media/v3/download/{server_and_media_id}",
shorthand=False,
access_token=user_token,
)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 7627823d3f..aaa4f3bba0 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -1447,6 +1447,30 @@ class RoomJoinRatelimitTestCase(RoomBase):
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
+ def test_join_attempts_local_ratelimit(self) -> None:
+ """Tests that unsuccessful joins that end up being denied are rate-limited."""
+ # Create 4 rooms
+ room_ids = [
+ self.helper.create_room_as(self.user_id, is_public=True) for _ in range(4)
+ ]
+ # Pre-emptively ban the user who will attempt to join.
+ joiner_user_id = self.register_user("joiner", "secret")
+ for room_id in room_ids:
+ self.helper.ban(room_id, self.user_id, joiner_user_id)
+
+ # Now make a new user try to join some of them.
+ # The user can make 3 requests, each of which should be denied.
+ for room_id in room_ids[0:3]:
+ self.helper.join(room_id, joiner_user_id, expect_code=HTTPStatus.FORBIDDEN)
+
+ # The fourth attempt should be rate limited.
+ self.helper.join(
+ room_ids[3], joiner_user_id, expect_code=HTTPStatus.TOO_MANY_REQUESTS
+ )
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+ )
def test_join_local_ratelimit_profile_change(self) -> None:
"""Tests that sending a profile update into all of the user's joined rooms isn't
rate-limited by the rate-limiter on joins."""
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 9532e5ddc1..465b696c0b 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -37,7 +37,6 @@ import attr
from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactorClock
-from twisted.web.resource import Resource
from twisted.web.server import Site
from synapse.api.constants import Membership
@@ -45,7 +44,7 @@ from synapse.api.errors import Codes
from synapse.server import HomeServer
from synapse.types import JsonDict
-from tests.server import FakeChannel, FakeSite, make_request
+from tests.server import FakeChannel, make_request
from tests.test_utils.html_parsers import TestHtmlParser
from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
@@ -558,7 +557,6 @@ class RestHelper:
def upload_media(
self,
- resource: Resource,
image_data: bytes,
tok: str,
filename: str = "test.png",
@@ -576,7 +574,7 @@ class RestHelper:
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request(
self.reactor,
- FakeSite(resource, self.reactor),
+ self.site,
"POST",
path,
content=image_data,
diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py
index 05d5e39cab..24459c6af4 100644
--- a/tests/rest/media/test_url_preview.py
+++ b/tests/rest/media/test_url_preview.py
@@ -24,10 +24,10 @@ from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IAddress, IResolutionReceiver
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
+from twisted.web.resource import Resource
from synapse.config.oembed import OEmbedEndpointConfig
from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS
-from synapse.rest.media.media_repository_resource import MediaRepositoryResource
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -117,8 +117,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.media_repo = hs.get_media_repository()
- media_repo_resource = hs.get_media_repository_resource()
- self.preview_url = media_repo_resource.children[b"preview_url"]
+ assert self.media_repo.url_previewer is not None
+ self.url_previewer = self.media_repo.url_previewer
self.lookups: Dict[str, Any] = {}
@@ -143,8 +143,15 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.reactor.nameResolver = Resolver() # type: ignore[assignment]
- def create_test_resource(self) -> MediaRepositoryResource:
- return self.hs.get_media_repository_resource()
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ """Create a resource tree for the test server
+
+ A resource tree is a mapping from path to twisted.web.resource.
+
+ The default implementation creates a JsonResource and calls each function in
+ `servlets` to register servlets against it.
+ """
+ return {"/_matrix/media": self.hs.get_media_repository_resource()}
def _assert_small_png(self, json_body: JsonDict) -> None:
"""Assert properties from the SMALL_PNG test image."""
@@ -159,7 +166,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://matrix.org",
+ "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False,
await_result=False,
)
@@ -183,7 +190,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Check the cache returns the correct response
channel = self.make_request(
- "GET", "preview_url?url=http://matrix.org", shorthand=False
+ "GET",
+ "/_matrix/media/v3/preview_url?url=http://matrix.org",
+ shorthand=False,
)
# Check the cache response has the same content
@@ -193,13 +202,15 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
# Clear the in-memory cache
- self.assertIn("http://matrix.org", self.preview_url._url_previewer._cache)
- self.preview_url._url_previewer._cache.pop("http://matrix.org")
- self.assertNotIn("http://matrix.org", self.preview_url._url_previewer._cache)
+ self.assertIn("http://matrix.org", self.url_previewer._cache)
+ self.url_previewer._cache.pop("http://matrix.org")
+ self.assertNotIn("http://matrix.org", self.url_previewer._cache)
# Check the database cache returns the correct response
channel = self.make_request(
- "GET", "preview_url?url=http://matrix.org", shorthand=False
+ "GET",
+ "/_matrix/media/v3/preview_url?url=http://matrix.org",
+ shorthand=False,
)
# Check the cache response has the same content
@@ -221,7 +232,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://matrix.org",
+ "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False,
await_result=False,
)
@@ -251,7 +262,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://matrix.org",
+ "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False,
await_result=False,
)
@@ -287,7 +298,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://matrix.org",
+ "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False,
await_result=False,
)
@@ -328,7 +339,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://matrix.org",
+ "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False,
await_result=False,
)
@@ -363,7 +374,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://matrix.org",
+ "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False,
await_result=False,
)
@@ -396,7 +407,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://example.com",
+ "/_matrix/media/v3/preview_url?url=http://example.com",
shorthand=False,
await_result=False,
)
@@ -425,7 +436,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
channel = self.make_request(
- "GET", "preview_url?url=http://example.com", shorthand=False
+ "GET",
+ "/_matrix/media/v3/preview_url?url=http://example.com",
+ shorthand=False,
)
# No requests made.
@@ -446,7 +459,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
channel = self.make_request(
- "GET", "preview_url?url=http://example.com", shorthand=False
+ "GET",
+ "/_matrix/media/v3/preview_url?url=http://example.com",
+ shorthand=False,
)
self.assertEqual(channel.code, 502)
@@ -463,7 +478,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
Blocked IP addresses, accessed directly, are not spidered.
"""
channel = self.make_request(
- "GET", "preview_url?url=http://192.168.1.1", shorthand=False
+ "GET",
+ "/_matrix/media/v3/preview_url?url=http://192.168.1.1",
+ shorthand=False,
)
# No requests made.
@@ -479,7 +496,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
Blocked IP ranges, accessed directly, are not spidered.
"""
channel = self.make_request(
- "GET", "preview_url?url=http://1.1.1.2", shorthand=False
+ "GET", "/_matrix/media/v3/preview_url?url=http://1.1.1.2", shorthand=False
)
self.assertEqual(channel.code, 403)
@@ -497,7 +514,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://example.com",
+ "/_matrix/media/v3/preview_url?url=http://example.com",
shorthand=False,
await_result=False,
)
@@ -533,7 +550,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
]
channel = self.make_request(
- "GET", "preview_url?url=http://example.com", shorthand=False
+ "GET",
+ "/_matrix/media/v3/preview_url?url=http://example.com",
+ shorthand=False,
)
self.assertEqual(channel.code, 502)
self.assertEqual(
@@ -553,7 +572,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
]
channel = self.make_request(
- "GET", "preview_url?url=http://example.com", shorthand=False
+ "GET",
+ "/_matrix/media/v3/preview_url?url=http://example.com",
+ shorthand=False,
)
# No requests made.
@@ -574,7 +595,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
channel = self.make_request(
- "GET", "preview_url?url=http://example.com", shorthand=False
+ "GET",
+ "/_matrix/media/v3/preview_url?url=http://example.com",
+ shorthand=False,
)
self.assertEqual(channel.code, 502)
@@ -591,10 +614,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
OPTIONS returns the OPTIONS.
"""
channel = self.make_request(
- "OPTIONS", "preview_url?url=http://example.com", shorthand=False
+ "OPTIONS",
+ "/_matrix/media/v3/preview_url?url=http://example.com",
+ shorthand=False,
)
- self.assertEqual(channel.code, 200)
- self.assertEqual(channel.json_body, {})
+ self.assertEqual(channel.code, 204)
def test_accept_language_config_option(self) -> None:
"""
@@ -605,7 +629,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Build and make a request to the server
channel = self.make_request(
"GET",
- "preview_url?url=http://example.com",
+ "/_matrix/media/v3/preview_url?url=http://example.com",
shorthand=False,
await_result=False,
)
@@ -658,7 +682,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://matrix.org",
+ "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False,
await_result=False,
)
@@ -708,7 +732,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://matrix.org",
+ "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False,
await_result=False,
)
@@ -750,7 +774,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://matrix.org",
+ "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False,
await_result=False,
)
@@ -790,7 +814,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://matrix.org",
+ "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False,
await_result=False,
)
@@ -831,7 +855,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- f"preview_url?{query_params}",
+ f"/_matrix/media/v3/preview_url?{query_params}",
shorthand=False,
)
self.pump()
@@ -852,7 +876,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://matrix.org",
+ "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False,
await_result=False,
)
@@ -889,7 +913,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+ "/_matrix/media/v3/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False,
await_result=False,
)
@@ -949,7 +973,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+ "/_matrix/media/v3/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False,
await_result=False,
)
@@ -998,7 +1022,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://www.hulu.com/watch/12345",
+ "/_matrix/media/v3/preview_url?url=http://www.hulu.com/watch/12345",
shorthand=False,
await_result=False,
)
@@ -1043,7 +1067,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+ "/_matrix/media/v3/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False,
await_result=False,
)
@@ -1072,7 +1096,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
+ "/_matrix/media/v3/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
shorthand=False,
await_result=False,
)
@@ -1164,7 +1188,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
+ "/_matrix/media/v3/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
shorthand=False,
await_result=False,
)
@@ -1205,7 +1229,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=http://cdn.twitter.com/matrixdotorg",
+ "/_matrix/media/v3/preview_url?url=http://cdn.twitter.com/matrixdotorg",
shorthand=False,
await_result=False,
)
@@ -1247,7 +1271,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Check fetching
channel = self.make_request(
"GET",
- f"download/{host}/{media_id}",
+ f"/_matrix/media/v3/download/{host}/{media_id}",
shorthand=False,
await_result=False,
)
@@ -1260,7 +1284,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- f"download/{host}/{media_id}",
+ f"/_matrix/media/v3/download/{host}/{media_id}",
shorthand=False,
await_result=False,
)
@@ -1295,7 +1319,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Check fetching
channel = self.make_request(
"GET",
- f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
+ f"/_matrix/media/v3/thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
shorthand=False,
await_result=False,
)
@@ -1313,7 +1337,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
+ f"/_matrix/media/v3/thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
shorthand=False,
await_result=False,
)
@@ -1343,7 +1367,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertTrue(os.path.isdir(thumbnail_dir))
self.reactor.advance(IMAGE_CACHE_EXPIRY_MS * 1000 + 1)
- self.get_success(self.preview_url._url_previewer._expire_url_cache_data())
+ self.get_success(self.url_previewer._expire_url_cache_data())
for path in [file_path] + file_dirs + [thumbnail_dir] + thumbnail_dirs:
self.assertFalse(
@@ -1363,7 +1387,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=" + bad_url,
+ "/_matrix/media/v3/preview_url?url=" + bad_url,
shorthand=False,
await_result=False,
)
@@ -1372,7 +1396,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=" + good_url,
+ "/_matrix/media/v3/preview_url?url=" + good_url,
shorthand=False,
await_result=False,
)
@@ -1404,7 +1428,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "preview_url?url=" + bad_url,
+ "/_matrix/media/v3/preview_url?url=" + bad_url,
shorthand=False,
await_result=False,
)
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 6b9692c486..0c054a598f 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -24,7 +24,10 @@ import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
from synapse.rest.client import login
from synapse.server import HomeServer
-from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
+from synapse.storage.databases.main.client_ips import (
+ LAST_SEEN_GRANULARITY,
+ DeviceLastConnectionInfo,
+)
from synapse.types import UserID
from synapse.util import Clock
@@ -65,15 +68,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
r = result[(user_id, device_id)]
- self.assertLessEqual(
- {
- "user_id": user_id,
- "device_id": device_id,
- "ip": "ip",
- "user_agent": "user_agent",
- "last_seen": 12345678000,
- }.items(),
- r.items(),
+ self.assertEqual(
+ DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id,
+ ip="ip",
+ user_agent="user_agent",
+ last_seen=12345678000,
+ ),
+ r,
)
def test_insert_new_client_ip_none_device_id(self) -> None:
@@ -201,13 +204,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertEqual(
result,
{
- (user_id, device_id): {
- "user_id": user_id,
- "device_id": device_id,
- "ip": "ip",
- "user_agent": "user_agent",
- "last_seen": 12345678000,
- },
+ (user_id, device_id): DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id,
+ ip="ip",
+ user_agent="user_agent",
+ last_seen=12345678000,
+ ),
},
)
@@ -292,20 +295,20 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertEqual(
result,
{
- (user_id, device_id_1): {
- "user_id": user_id,
- "device_id": device_id_1,
- "ip": "ip_1",
- "user_agent": "user_agent_1",
- "last_seen": 12345678000,
- },
- (user_id, device_id_2): {
- "user_id": user_id,
- "device_id": device_id_2,
- "ip": "ip_2",
- "user_agent": "user_agent_3",
- "last_seen": 12345688000 + LAST_SEEN_GRANULARITY,
- },
+ (user_id, device_id_1): DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id_1,
+ ip="ip_1",
+ user_agent="user_agent_1",
+ last_seen=12345678000,
+ ),
+ (user_id, device_id_2): DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id_2,
+ ip="ip_2",
+ user_agent="user_agent_3",
+ last_seen=12345688000 + LAST_SEEN_GRANULARITY,
+ ),
},
)
@@ -526,15 +529,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
r = result[(user_id, device_id)]
- self.assertLessEqual(
- {
- "user_id": user_id,
- "device_id": device_id,
- "ip": None,
- "user_agent": None,
- "last_seen": None,
- }.items(),
- r.items(),
+ self.assertEqual(
+ DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id,
+ ip=None,
+ user_agent=None,
+ last_seen=None,
+ ),
+ r,
)
# Register the background update to run again.
@@ -561,15 +564,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
r = result[(user_id, device_id)]
- self.assertLessEqual(
- {
- "user_id": user_id,
- "device_id": device_id,
- "ip": "ip",
- "user_agent": "user_agent",
- "last_seen": 0,
- }.items(),
- r.items(),
+ self.assertEqual(
+ DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id,
+ ip="ip",
+ user_agent="user_agent",
+ last_seen=0,
+ ),
+ r,
)
def test_old_user_ips_pruned(self) -> None:
@@ -640,15 +643,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
r = result2[(user_id, device_id)]
- self.assertLessEqual(
- {
- "user_id": user_id,
- "device_id": device_id,
- "ip": "ip",
- "user_agent": "user_agent",
- "last_seen": 0,
- }.items(),
- r.items(),
+ self.assertEqual(
+ DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id,
+ ip="ip",
+ user_agent="user_agent",
+ last_seen=0,
+ ),
+ r,
)
def test_invalid_user_agents_are_ignored(self) -> None:
@@ -777,13 +780,13 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
self.store.get_last_client_ip_by_device(self.user_id, device_id)
)
r = result[(self.user_id, device_id)]
- self.assertLessEqual(
- {
- "user_id": self.user_id,
- "device_id": device_id,
- "ip": expected_ip,
- "user_agent": "Mozzila pizza",
- "last_seen": 123456100,
- }.items(),
- r.items(),
+ self.assertEqual(
+ DeviceLastConnectionInfo(
+ user_id=self.user_id,
+ device_id=device_id,
+ ip=expected_ip,
+ user_agent="Mozzila pizza",
+ last_seen=123456100,
+ ),
+ r,
)
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index b55dd07f14..2f6499966c 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List, Set, Tuple
+from typing import Dict, List, Set, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
@@ -421,41 +421,53 @@ class EventChainStoreTestCase(HomeserverTestCase):
self, events: List[EventBase]
) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
# Fetch the map from event ID -> (chain ID, sequence number)
- rows = self.get_success(
- self.store.db_pool.simple_select_many_batch(
- table="event_auth_chains",
- column="event_id",
- iterable=[e.event_id for e in events],
- retcols=("event_id", "chain_id", "sequence_number"),
- keyvalues={},
- )
+ rows = cast(
+ List[Tuple[str, int, int]],
+ self.get_success(
+ self.store.db_pool.simple_select_many_batch(
+ table="event_auth_chains",
+ column="event_id",
+ iterable=[e.event_id for e in events],
+ retcols=("event_id", "chain_id", "sequence_number"),
+ keyvalues={},
+ )
+ ),
)
chain_map = {
- row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
+ event_id: (chain_id, sequence_number)
+ for event_id, chain_id, sequence_number in rows
}
# Fetch all the links and pass them to the _LinkMap.
- rows = self.get_success(
- self.store.db_pool.simple_select_many_batch(
- table="event_auth_chain_links",
- column="origin_chain_id",
- iterable=[chain_id for chain_id, _ in chain_map.values()],
- retcols=(
- "origin_chain_id",
- "origin_sequence_number",
- "target_chain_id",
- "target_sequence_number",
- ),
- keyvalues={},
- )
+ auth_chain_rows = cast(
+ List[Tuple[int, int, int, int]],
+ self.get_success(
+ self.store.db_pool.simple_select_many_batch(
+ table="event_auth_chain_links",
+ column="origin_chain_id",
+ iterable=[chain_id for chain_id, _ in chain_map.values()],
+ retcols=(
+ "origin_chain_id",
+ "origin_sequence_number",
+ "target_chain_id",
+ "target_sequence_number",
+ ),
+ keyvalues={},
+ )
+ ),
)
link_map = _LinkMap()
- for row in rows:
+ for (
+ origin_chain_id,
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in auth_chain_rows:
added = link_map.add_link(
- (row["origin_chain_id"], row["origin_sequence_number"]),
- (row["target_chain_id"], row["target_sequence_number"]),
+ (origin_chain_id, origin_sequence_number),
+ (target_chain_id, target_sequence_number),
)
# We shouldn't have persisted any redundant links
diff --git a/tests/unittest.py b/tests/unittest.py
index dbaff361b4..99ad02eb06 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -60,7 +60,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.federation.transport.server import TransportLayerServer
-from synapse.http.server import JsonResource
+from synapse.http.server import JsonResource, OptionsResource
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import (
SENTINEL_CONTEXT,
@@ -459,7 +459,7 @@ class HomeserverTestCase(TestCase):
The default calls `self.create_resource_dict` and builds the resultant dict
into a tree.
"""
- root_resource = Resource()
+ root_resource = OptionsResource()
create_resource_tree(self.create_resource_dict(), root_resource)
return root_resource
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 4bcd17a6fc..ad88b24566 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from unittest import mock
+
+from synapse.notifier import Notifier
+from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from tests.unittest import HomeserverTestCase
@@ -109,6 +113,77 @@ class RetryLimiterTestCase(HomeserverTestCase):
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)
+ def test_notifier_replication(self) -> None:
+ """Ensure the notifier/replication client is called only when expected."""
+ store = self.hs.get_datastores().main
+
+ notifier = mock.Mock(spec=Notifier)
+ replication_client = mock.Mock(spec=ReplicationCommandHandler)
+
+ limiter = self.get_success(
+ get_retry_limiter(
+ "test_dest",
+ self.clock,
+ store,
+ notifier=notifier,
+ replication_client=replication_client,
+ )
+ )
+
+ # The server is already up, nothing should occur.
+ self.pump(1)
+ with limiter:
+ pass
+ self.pump()
+
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+ self.assertIsNone(new_timings)
+ notifier.notify_remote_server_up.assert_not_called()
+ replication_client.send_remote_server_up.assert_not_called()
+
+ # Attempt again, but return an error. This will cause new retry timings, but
+ # should not trigger server up notifications.
+ self.pump(1)
+ try:
+ with limiter:
+ raise AssertionError("argh")
+ except AssertionError:
+ pass
+ self.pump()
+
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+ # The exact retry timings are tested separately.
+ self.assertIsNotNone(new_timings)
+ notifier.notify_remote_server_up.assert_not_called()
+ replication_client.send_remote_server_up.assert_not_called()
+
+ # A second failing request should be treated as the above.
+ self.pump(1)
+ try:
+ with limiter:
+ raise AssertionError("argh")
+ except AssertionError:
+ pass
+ self.pump()
+
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+ # The exact retry timings are tested separately.
+ self.assertIsNotNone(new_timings)
+ notifier.notify_remote_server_up.assert_not_called()
+ replication_client.send_remote_server_up.assert_not_called()
+
+ # A final successful attempt should generate a server up notification.
+ self.pump(1)
+ with limiter:
+ pass
+ self.pump()
+
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+ # The exact retry timings are tested separately.
+ self.assertIsNone(new_timings)
+ notifier.notify_remote_server_up.assert_called_once_with("test_dest")
+ replication_client.send_remote_server_up.assert_called_once_with("test_dest")
+
def test_max_retry_interval(self) -> None:
"""Test that `destination_max_retry_interval` setting works as expected"""
store = self.hs.get_datastores().main
|