diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/config/test_cache.py | 8 | ||||
-rw-r--r-- | tests/federation/transport/server/__init__.py | 13 | ||||
-rw-r--r-- | tests/federation/transport/server/test__base.py | 114 | ||||
-rw-r--r-- | tests/handlers/test_federation_event.py | 4 | ||||
-rw-r--r-- | tests/http/server/__init__.py | 13 | ||||
-rw-r--r-- | tests/http/server/_base.py | 100 | ||||
-rw-r--r-- | tests/http/test_servlet.py | 60 | ||||
-rw-r--r-- | tests/replication/http/__init__.py | 13 | ||||
-rw-r--r-- | tests/replication/http/test__base.py | 106 | ||||
-rw-r--r-- | tests/replication/test_sharded_event_persister.py | 14 | ||||
-rw-r--r-- | tests/rest/client/test_rooms.py | 264 | ||||
-rw-r--r-- | tests/rest/client/test_sync.py | 38 | ||||
-rw-r--r-- | tests/server.py | 14 | ||||
-rw-r--r-- | tests/server_notices/test_resource_limits_server_notices.py | 11 | ||||
-rw-r--r-- | tests/storage/test_event_chain.py | 2 | ||||
-rw-r--r-- | tests/storage/test_event_federation.py | 9 | ||||
-rw-r--r-- | tests/test_server.py | 111 | ||||
-rw-r--r-- | tests/test_state.py | 3 | ||||
-rw-r--r-- | tests/test_visibility.py | 4 | ||||
-rw-r--r-- | tests/unittest.py | 2 | ||||
-rw-r--r-- | tests/util/test_lrucache.py | 58 |
21 files changed, 900 insertions, 61 deletions
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py index 4bb82e810e..d2b3c299e3 100644 --- a/tests/config/test_cache.py +++ b/tests/config/test_cache.py @@ -38,6 +38,7 @@ class CacheConfigTests(TestCase): "SYNAPSE_NOT_CACHE": "BLAH", } self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0}) @@ -52,6 +53,7 @@ class CacheConfigTests(TestCase): "SYNAPSE_CACHE_FACTOR_FOO": 1, } self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() self.assertEqual( dict(self.config.cache_factors), @@ -71,6 +73,7 @@ class CacheConfigTests(TestCase): config = {"caches": {"per_cache_factors": {"foo": 3}}} self.config.read_config(config) + self.config.resize_all_caches() self.assertEqual(cache.max_size, 300) @@ -82,6 +85,7 @@ class CacheConfigTests(TestCase): """ config = {"caches": {"per_cache_factors": {"foo": 2}}} self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() cache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) @@ -99,6 +103,7 @@ class CacheConfigTests(TestCase): config = {"caches": {"global_factor": 4}} self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() self.assertEqual(cache.max_size, 400) @@ -110,6 +115,7 @@ class CacheConfigTests(TestCase): """ config = {"caches": {"global_factor": 1.5}} self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() cache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) @@ -128,6 +134,7 @@ class CacheConfigTests(TestCase): "SYNAPSE_CACHE_FACTOR_CACHE_B": 3, } self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() cache_a = LruCache(100) add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor) @@ -148,6 +155,7 @@ class CacheConfigTests(TestCase): config = {"caches": {"event_cache_size": "10k"}} self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() cache = LruCache( max_size=self.config.event_cache_size, diff --git a/tests/federation/transport/server/__init__.py b/tests/federation/transport/server/__init__.py new file mode 100644 index 0000000000..3a5f22c022 --- /dev/null +++ b/tests/federation/transport/server/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py new file mode 100644 index 0000000000..ac3695a8cc --- /dev/null +++ b/tests/federation/transport/server/test__base.py @@ -0,0 +1,114 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http import HTTPStatus +from typing import Dict, List, Tuple + +from synapse.api.errors import Codes +from synapse.federation.transport.server import BaseFederationServlet +from synapse.federation.transport.server._base import Authenticator +from synapse.http.server import JsonResource, cancellable +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util.ratelimitutils import FederationRateLimiter + +from tests import unittest +from tests.http.server._base import EndpointCancellationTestHelperMixin + + +class CancellableFederationServlet(BaseFederationServlet): + PATH = "/sleep" + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.clock = hs.get_clock() + + @cancellable + async def on_GET( + self, origin: str, content: None, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + +class BaseFederationServletCancellationTests( + unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin +): + """Tests for `BaseFederationServlet` cancellation.""" + + skip = "`BaseFederationServlet` does not support cancellation yet." + + path = f"{CancellableFederationServlet.PREFIX}{CancellableFederationServlet.PATH}" + + def create_test_resource(self): + """Overrides `HomeserverTestCase.create_test_resource`.""" + resource = JsonResource(self.hs) + + CancellableFederationServlet( + hs=self.hs, + authenticator=Authenticator(self.hs), + ratelimiter=self.hs.get_federation_ratelimiter(), + server_name=self.hs.hostname, + ).register(resource) + + return resource + + def test_cancellable_disconnect(self) -> None: + """Test that handlers with the `@cancellable` flag can be cancelled.""" + channel = self.make_signed_federation_request( + "GET", self.path, await_result=False + ) + + # Advance past all the rate limiting logic. If we disconnect too early, the + # request won't be processed. + self.pump() + + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN}, + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + channel = self.make_signed_federation_request( + "POST", + self.path, + content={}, + await_result=False, + ) + + # Advance past all the rate limiting logic. If we disconnect too early, the + # request won't be processed. + self.pump() + + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=False, + expected_body={"result": True}, + ) diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 489ba57736..e64b28f28b 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -148,7 +148,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): prev_event.internal_metadata.outlier = True persistence = self.hs.get_storage().persistence self.get_success( - persistence.persist_event(prev_event, EventContext.for_outlier()) + persistence.persist_event( + prev_event, EventContext.for_outlier(self.hs.get_storage()) + ) ) else: diff --git a/tests/http/server/__init__.py b/tests/http/server/__init__.py new file mode 100644 index 0000000000..3a5f22c022 --- /dev/null +++ b/tests/http/server/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py new file mode 100644 index 0000000000..b9f1a381aa --- /dev/null +++ b/tests/http/server/_base.py @@ -0,0 +1,100 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unles4s required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http import HTTPStatus +from typing import Any, Callable, Optional, Union +from unittest import mock + +from twisted.internet.error import ConnectionDone + +from synapse.http.server import ( + HTTP_STATUS_REQUEST_CANCELLED, + respond_with_html_bytes, + respond_with_json, +) +from synapse.types import JsonDict + +from tests import unittest +from tests.server import FakeChannel, ThreadedMemoryReactorClock + + +class EndpointCancellationTestHelperMixin(unittest.TestCase): + """Provides helper methods for testing cancellation of endpoints.""" + + def _test_disconnect( + self, + reactor: ThreadedMemoryReactorClock, + channel: FakeChannel, + expect_cancellation: bool, + expected_body: Union[bytes, JsonDict], + expected_code: Optional[int] = None, + ) -> None: + """Disconnects an in-flight request and checks the response. + + Args: + reactor: The twisted reactor running the request handler. + channel: The `FakeChannel` for the request. + expect_cancellation: `True` if request processing is expected to be + cancelled, `False` if the request should run to completion. + expected_body: The expected response for the request. + expected_code: The expected status code for the request. Defaults to `200` + or `499` depending on `expect_cancellation`. + """ + # Determine the expected status code. + if expected_code is None: + if expect_cancellation: + expected_code = HTTP_STATUS_REQUEST_CANCELLED + else: + expected_code = HTTPStatus.OK + + request = channel.request + self.assertFalse( + channel.is_finished(), + "Request finished before we could disconnect - " + "was `await_result=False` passed to `make_request`?", + ) + + # We're about to disconnect the request. This also disconnects the channel, so + # we have to rely on mocks to extract the response. + respond_method: Callable[..., Any] + if isinstance(expected_body, bytes): + respond_method = respond_with_html_bytes + else: + respond_method = respond_with_json + + with mock.patch( + f"synapse.http.server.{respond_method.__name__}", wraps=respond_method + ) as respond_mock: + # Disconnect the request. + request.connectionLost(reason=ConnectionDone()) + + if expect_cancellation: + # An immediate cancellation is expected. + respond_mock.assert_called_once() + args, _kwargs = respond_mock.call_args + code, body = args[1], args[2] + self.assertEqual(code, expected_code) + self.assertEqual(request.code, expected_code) + self.assertEqual(body, expected_body) + else: + respond_mock.assert_not_called() + + # The handler is expected to run to completion. + reactor.pump([1.0]) + respond_mock.assert_called_once() + args, _kwargs = respond_mock.call_args + code, body = args[1], args[2] + self.assertEqual(code, expected_code) + self.assertEqual(request.code, expected_code) + self.assertEqual(body, expected_body) diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py index a80bfb9f4e..ad521525cf 100644 --- a/tests/http/test_servlet.py +++ b/tests/http/test_servlet.py @@ -12,16 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +from http import HTTPStatus from io import BytesIO +from typing import Tuple from unittest.mock import Mock -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError +from synapse.http.server import cancellable from synapse.http.servlet import ( + RestServlet, parse_json_object_from_request, parse_json_value_from_request, ) +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns +from synapse.server import HomeServer +from synapse.types import JsonDict from tests import unittest +from tests.http.server._base import EndpointCancellationTestHelperMixin def make_request(content): @@ -76,3 +85,52 @@ class TestServletUtils(unittest.TestCase): # Test not an object with self.assertRaises(SynapseError): parse_json_object_from_request(make_request(b'["foo"]')) + + +class CancellableRestServlet(RestServlet): + """A `RestServlet` with a mix of cancellable and uncancellable handlers.""" + + PATTERNS = client_patterns("/sleep$") + + def __init__(self, hs: HomeServer): + super().__init__() + self.clock = hs.get_clock() + + @cancellable + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + +class TestRestServletCancellation( + unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin +): + """Tests for `RestServlet` cancellation.""" + + servlets = [ + lambda hs, http_server: CancellableRestServlet(hs).register(http_server) + ] + + def test_cancellable_disconnect(self) -> None: + """Test that handlers with the `@cancellable` flag can be cancelled.""" + channel = self.make_request("GET", "/sleep", await_result=False) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN}, + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + channel = self.make_request("POST", "/sleep", await_result=False) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=False, + expected_body={"result": True}, + ) diff --git a/tests/replication/http/__init__.py b/tests/replication/http/__init__.py new file mode 100644 index 0000000000..3a5f22c022 --- /dev/null +++ b/tests/replication/http/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py new file mode 100644 index 0000000000..a5ab093a27 --- /dev/null +++ b/tests/replication/http/test__base.py @@ -0,0 +1,106 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http import HTTPStatus +from typing import Tuple + +from twisted.web.server import Request + +from synapse.api.errors import Codes +from synapse.http.server import JsonResource, cancellable +from synapse.replication.http import REPLICATION_PREFIX +from synapse.replication.http._base import ReplicationEndpoint +from synapse.server import HomeServer +from synapse.types import JsonDict + +from tests import unittest +from tests.http.server._base import EndpointCancellationTestHelperMixin + + +class CancellableReplicationEndpoint(ReplicationEndpoint): + NAME = "cancellable_sleep" + PATH_ARGS = () + CACHE = False + + def __init__(self, hs: HomeServer): + super().__init__(hs) + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload() -> JsonDict: + return {} + + @cancellable + async def _handle_request( # type: ignore[override] + self, request: Request + ) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + +class UncancellableReplicationEndpoint(ReplicationEndpoint): + NAME = "uncancellable_sleep" + PATH_ARGS = () + CACHE = False + + def __init__(self, hs: HomeServer): + super().__init__(hs) + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload() -> JsonDict: + return {} + + async def _handle_request( # type: ignore[override] + self, request: Request + ) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + +class ReplicationEndpointCancellationTestCase( + unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin +): + """Tests for `ReplicationEndpoint` cancellation.""" + + def create_test_resource(self): + """Overrides `HomeserverTestCase.create_test_resource`.""" + resource = JsonResource(self.hs) + + CancellableReplicationEndpoint(self.hs).register(resource) + UncancellableReplicationEndpoint(self.hs).register(resource) + + return resource + + def test_cancellable_disconnect(self) -> None: + """Test that handlers with the `@cancellable` flag can be cancelled.""" + path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/" + channel = self.make_request("POST", path, await_result=False) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN}, + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/" + channel = self.make_request("POST", path, await_result=False) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=False, + expected_body={"result": True}, + ) diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 5f142e84c3..a7ca68069e 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -14,7 +14,6 @@ import logging from unittest.mock import patch -from synapse.api.room_versions import RoomVersion from synapse.rest import admin from synapse.rest.client import login, room, sync from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -64,21 +63,10 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # We control the room ID generation by patching out the # `_generate_room_id` method - async def generate_room( - creator_id: str, is_public: bool, room_version: RoomVersion - ): - await self.store.store_room( - room_id=room_id, - room_creator_user_id=creator_id, - is_public=is_public, - room_version=room_version, - ) - return room_id - with patch( "synapse.handlers.room.RoomCreationHandler._generate_room_id" ) as mock: - mock.side_effect = generate_room + mock.side_effect = lambda: room_id self.helper.create_room_as(user_id, tok=tok) def test_basic(self): diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 9443daa056..d0197aca94 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -925,7 +925,7 @@ class RoomJoinTestCase(RoomBase): ) -> bool: return return_value - callback_mock = Mock(side_effect=user_may_join_room) + callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None) self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock) # Join a first room, without being invited to it. @@ -1116,6 +1116,264 @@ class RoomMessagesTestCase(RoomBase): self.assertEqual(200, channel.code, msg=channel.result["body"]) +class RoomPowerLevelOverridesTestCase(RoomBase): + """Tests that the power levels can be overridden with server config.""" + + user_id = "@sid1:red" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.admin_user_id = self.register_user("admin", "pass") + self.admin_access_token = self.login("admin", "pass") + + def power_levels(self, room_id: str) -> Dict[str, Any]: + return self.helper.get_state( + room_id, "m.room.power_levels", self.admin_access_token + ) + + def test_default_power_levels_with_room_override(self) -> None: + """ + Create a room, providing power level overrides. + Confirm that the room's power levels reflect the overrides. + + See https://github.com/matrix-org/matrix-spec/issues/492 + - currently we overwrite each key of power_level_content_override + completely. + """ + + room_id = self.helper.create_room_as( + self.user_id, + extra_content={ + "power_level_content_override": {"events": {"custom.event": 0}} + }, + ) + self.assertEqual( + { + "custom.event": 0, + }, + self.power_levels(room_id)["events"], + ) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": {"events": {"custom.event": 0}}, + } + }, + ) + def test_power_levels_with_server_override(self) -> None: + """ + With a server configured to modify the room-level defaults, + Create a room, without providing any extra power level overrides. + Confirm that the room's power levels reflect the server-level overrides. + + Similar to https://github.com/matrix-org/matrix-spec/issues/492, + we overwrite each key of power_level_content_override completely. + """ + + room_id = self.helper.create_room_as(self.user_id) + self.assertEqual( + { + "custom.event": 0, + }, + self.power_levels(room_id)["events"], + ) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": { + "events": {"server.event": 0}, + "ban": 13, + }, + } + }, + ) + def test_power_levels_with_server_and_room_overrides(self) -> None: + """ + With a server configured to modify the room-level defaults, + create a room, providing different overrides. + Confirm that the room's power levels reflect both overrides, and + choose the room overrides where they clash. + """ + + room_id = self.helper.create_room_as( + self.user_id, + extra_content={ + "power_level_content_override": {"events": {"room.event": 0}} + }, + ) + + # Room override wins over server config + self.assertEqual( + {"room.event": 0}, + self.power_levels(room_id)["events"], + ) + + # But where there is no room override, server config wins + self.assertEqual(13, self.power_levels(room_id)["ban"]) + + +class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): + """ + Tests that we can really do various otherwise-prohibited actions + based on overriding the power levels in config. + """ + + user_id = "@sid1:red" + + def test_creator_can_post_state_event(self) -> None: + # Given I am the creator of a room + room_id = self.helper.create_room_as(self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am allowed + self.assertEqual(200, channel.code, msg=channel.result["body"]) + + def test_normal_user_can_not_post_state_event(self) -> None: + # Given I am a normal member of a room + room_id = self.helper.create_room_as("@some_other_guy:red") + self.helper.join(room=room_id, user=self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am not allowed because state events require PL>=50 + self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual( + "You don't have permission to post that to the room. " + "user_level (0) < send_level (50)", + channel.json_body["error"], + ) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": {"events": {"custom.event": 0}}, + } + }, + ) + def test_with_config_override_normal_user_can_post_state_event(self) -> None: + # Given the server has config allowing normal users to post my event type, + # and I am a normal member of a room + room_id = self.helper.create_room_as("@some_other_guy:red") + self.helper.join(room=room_id, user=self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am allowed + self.assertEqual(200, channel.code, msg=channel.result["body"]) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": {"events": {"custom.event": 0}}, + } + }, + ) + def test_any_room_override_defeats_config_override(self) -> None: + # Given the server has config allowing normal users to post my event type + # And I am a normal member of a room + # But the room was created with special permissions + extra_content: Dict[str, Any] = { + "power_level_content_override": {"events": {}}, + } + room_id = self.helper.create_room_as( + "@some_other_guy:red", extra_content=extra_content + ) + self.helper.join(room=room_id, user=self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am not allowed + self.assertEqual(403, channel.code, msg=channel.result["body"]) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": {"events": {"custom.event": 0}}, + } + }, + ) + def test_specific_room_override_defeats_config_override(self) -> None: + # Given the server has config allowing normal users to post my event type, + # and I am a normal member of a room, + # but the room was created with special permissions for this event type + extra_content = { + "power_level_content_override": {"events": {"custom.event": 1}}, + } + room_id = self.helper.create_room_as( + "@some_other_guy:red", extra_content=extra_content + ) + self.helper.join(room=room_id, user=self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am not allowed + self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual( + "You don't have permission to post that to the room. " + + "user_level (0) < send_level (1)", + channel.json_body["error"], + ) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": {"events": {"custom.event": 0}}, + "private_chat": None, + "trusted_private_chat": None, + } + }, + ) + def test_config_override_applies_only_to_specific_preset(self) -> None: + # Given the server has config for public_chats, + # and I am a normal member of a private_chat room + room_id = self.helper.create_room_as("@some_other_guy:red", is_public=False) + self.helper.invite(room=room_id, src="@some_other_guy:red", targ=self.user_id) + self.helper.join(room=room_id, user=self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am not allowed because the public_chat config does not + # affect this room, because this room is a private_chat + self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual( + "You don't have permission to post that to the room. " + + "user_level (0) < send_level (50)", + channel.json_body["error"], + ) + + class RoomInitialSyncTestCase(RoomBase): """Tests /rooms/$room_id/initialSync.""" @@ -2598,7 +2856,9 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it # allow everything for now. - mock = Mock(return_value=make_awaitable(True)) + # `spec` argument is needed for this function mock to have `__qualname__`, which + # is needed for `Measure` metrics buried in SpamChecker. + mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None) self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock) # Send a 3PID invite into the room and check that it succeeded. diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 0108337649..74b6560cbc 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +from http import HTTPStatus from typing import List, Optional from parameterized import parameterized @@ -485,30 +486,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that we didn't override the public read receipt self.assertIsNone(self._get_read_receipt()) - @parameterized.expand( - [ - # Old Element version, expected to send an empty body - ( - "agent1", - "Element/1.2.2 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)", - 200, - ), - # Old SchildiChat version, expected to send an empty body - ("agent2", "SchildiChat/1.2.1 (Android 10)", 200), - # Expected 400: Denies empty body starting at version 1.3+ - ("agent3", "Element/1.3.6 (Android 10)", 400), - ("agent4", "SchildiChat/1.3.6 (Android 11)", 400), - # Contains "Riot": Receipts with empty bodies expected - ("agent5", "Element (Riot.im) (Android 9)", 200), - # Expected 400: Does not contain "Android" - ("agent6", "Element/1.2.1", 400), - # Expected 400: Different format, missing "/" after Element; existing build that should allow empty bodies, but minimal ongoing usage - ("agent7", "Element dbg/1.1.8-dev (Android)", 400), - ] - ) - def test_read_receipt_with_empty_body( - self, name: str, user_agent: str, expected_status_code: int - ) -> None: + def test_read_receipt_with_empty_body_is_rejected(self) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) @@ -517,9 +495,9 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): "POST", f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}", access_token=self.tok2, - custom_headers=[("User-Agent", user_agent)], ) - self.assertEqual(channel.code, expected_status_code) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST) + self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON", channel.json_body) def _get_read_receipt(self) -> Optional[JsonDict]: """Syncs and returns the read receipt.""" @@ -678,12 +656,13 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(3) # Check that custom events with a body increase the unread counter. - self.helper.send_event( + result = self.helper.send_event( self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, ) + event_id = result["event_id"] self._check_unread_count(4) # Check that edits don't increase the unread counter. @@ -693,7 +672,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): content={ "body": "hello", "msgtype": "m.text", - "m.relates_to": {"rel_type": RelationTypes.REPLACE}, + "m.relates_to": { + "rel_type": RelationTypes.REPLACE, + "event_id": event_id, + }, }, tok=self.tok2, ) diff --git a/tests/server.py b/tests/server.py index 8f30e250c8..b9f465971f 100644 --- a/tests/server.py +++ b/tests/server.py @@ -109,6 +109,17 @@ class FakeChannel: _ip: str = "127.0.0.1" _producer: Optional[Union[IPullProducer, IPushProducer]] = None resource_usage: Optional[ContextResourceUsage] = None + _request: Optional[Request] = None + + @property + def request(self) -> Request: + assert self._request is not None + return self._request + + @request.setter + def request(self, request: Request) -> None: + assert self._request is None + self._request = request @property def json_body(self): @@ -322,6 +333,8 @@ def make_request( channel = FakeChannel(site, reactor, ip=client_ip) req = request(channel, site) + channel.request = req + req.content = BytesIO(content) # Twisted expects to be at the end of the content when parsing the request. req.content.seek(0, SEEK_END) @@ -736,6 +749,7 @@ def setup_test_homeserver( if config is None: config = default_config(name, parse=True) + config.caches.resize_all_caches() config.ldap_enabled = False if "clock" not in kwargs: diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 9ee9509d3a..07e29788e5 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -75,6 +75,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( return_value=make_awaitable("!something:localhost") ) + self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = Mock( + return_value=make_awaitable("!something:localhost") + ) self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None)) self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) @@ -102,6 +105,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event + self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once() self._send_notice.assert_called_once() def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): @@ -300,7 +304,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): hasn't been reached (since it's the only user and the limit is 5), so users shouldn't receive a server notice. """ - self.register_user("user", "password") + m = Mock(return_value=make_awaitable(None)) + self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = m + + user_id = self.register_user("user", "password") tok = self.login("user", "password") channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) @@ -309,6 +316,8 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): "rooms", channel.json_body, "Got invites without server notice" ) + m.assert_called_once_with(user_id) + def test_invite_with_notice(self): """Tests that, if the MAU limit is hit, the server notices user invites each user to a room in which it has sent a notice. diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 401020fd63..c7661e7186 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -393,7 +393,7 @@ class EventChainStoreTestCase(HomeserverTestCase): # We need to persist the events to the events and state_events # tables. persist_events_store._store_event_txn( - txn, [(e, EventContext()) for e in events] + txn, [(e, EventContext(self.hs.get_storage())) for e in events] ) # Actually call the function that calculates the auth chain stuff. diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 645d564d1c..d92a9ac5b7 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -58,15 +58,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): (room_id, event_id), ) - txn.execute( - ( - "INSERT INTO event_reference_hashes " - "(event_id, algorithm, hash) " - "VALUES (?, 'sha256', ?)" - ), - (event_id, bytearray(b"ffff")), - ) - for i in range(0, 20): self.get_success( self.store.db_pool.runInteraction("insert", insert_event, i) diff --git a/tests/test_server.py b/tests/test_server.py index f2ffbc895b..0f1eb43cbc 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -13,18 +13,28 @@ # limitations under the License. import re +from http import HTTPStatus +from typing import Tuple from twisted.internet.defer import Deferred from twisted.web.resource import Resource from synapse.api.errors import Codes, RedirectException, SynapseError from synapse.config.server import parse_listener_def -from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource -from synapse.http.site import SynapseSite +from synapse.http.server import ( + DirectServeHtmlResource, + DirectServeJsonResource, + JsonResource, + OptionsResource, + cancellable, +) +from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import make_deferred_yieldable +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest +from tests.http.server._base import EndpointCancellationTestHelperMixin from tests.server import ( FakeSite, ThreadedMemoryReactorClock, @@ -363,3 +373,100 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) + + +class CancellableDirectServeJsonResource(DirectServeJsonResource): + def __init__(self, clock: Clock): + super().__init__() + self.clock = clock + + @cancellable + async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + +class CancellableDirectServeHtmlResource(DirectServeHtmlResource): + ERROR_TEMPLATE = "{code} {msg}" + + def __init__(self, clock: Clock): + super().__init__() + self.clock = clock + + @cancellable + async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, bytes]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, b"ok" + + async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, bytes]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, b"ok" + + +class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin): + """Tests for `DirectServeJsonResource` cancellation.""" + + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + self.clock = Clock(self.reactor) + self.resource = CancellableDirectServeJsonResource(self.clock) + self.site = FakeSite(self.resource, self.reactor) + + def test_cancellable_disconnect(self) -> None: + """Test that handlers with the `@cancellable` flag can be cancelled.""" + channel = make_request( + self.reactor, self.site, "GET", "/sleep", await_result=False + ) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN}, + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + channel = make_request( + self.reactor, self.site, "POST", "/sleep", await_result=False + ) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=False, + expected_body={"result": True}, + ) + + +class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin): + """Tests for `DirectServeHtmlResource` cancellation.""" + + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + self.clock = Clock(self.reactor) + self.resource = CancellableDirectServeHtmlResource(self.clock) + self.site = FakeSite(self.resource, self.reactor) + + def test_cancellable_disconnect(self) -> None: + """Test that handlers with the `@cancellable` flag can be cancelled.""" + channel = make_request( + self.reactor, self.site, "GET", "/sleep", await_result=False + ) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body=b"499 Request cancelled", + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + channel = make_request( + self.reactor, self.site, "POST", "/sleep", await_result=False + ) + self._test_disconnect( + self.reactor, channel, expect_cancellation=False, expected_body=b"ok" + ) diff --git a/tests/test_state.py b/tests/test_state.py index e4baa69137..651ec1c7d4 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -88,6 +88,9 @@ class _DummyStore: return groups + async def get_state_ids_for_group(self, state_group): + return self._group_to_state[state_group] + async def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids ): diff --git a/tests/test_visibility.py b/tests/test_visibility.py index d0230f9ebb..7a9b01ef9d 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -234,7 +234,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) event.internal_metadata.outlier = True self.get_success( - self.storage.persistence.persist_event(event, EventContext.for_outlier()) + self.storage.persistence.persist_event( + event, EventContext.for_outlier(self.storage) + ) ) return event diff --git a/tests/unittest.py b/tests/unittest.py index 9afa68c164..e7f255b4fa 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -831,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): self.site, method=method, path=path, - content=content or "", + content=content if content is not None else "", shorthand=False, await_result=await_result, custom_headers=custom_headers, diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 321fc1776f..67173a4f5b 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -14,8 +14,9 @@ from typing import List -from unittest.mock import Mock +from unittest.mock import Mock, patch +from synapse.metrics.jemalloc import JemallocStats from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries from synapse.util.caches.treecache import TreeCache @@ -316,3 +317,58 @@ class TimeEvictionTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get("key1"), None) self.assertEqual(cache.get("key2"), 3) + + +class MemoryEvictionTestCase(unittest.HomeserverTestCase): + @override_config( + { + "caches": { + "cache_autotuning": { + "max_cache_memory_usage": "700M", + "target_cache_memory_usage": "500M", + "min_cache_ttl": "5m", + } + } + } + ) + @patch("synapse.util.caches.lrucache.get_jemalloc_stats") + def test_evict_memory(self, jemalloc_interface) -> None: + mock_jemalloc_class = Mock(spec=JemallocStats) + jemalloc_interface.return_value = mock_jemalloc_class + + # set the return value of get_stat() to be greater than max_cache_memory_usage + mock_jemalloc_class.get_stat.return_value = 924288000 + + setup_expire_lru_cache_entries(self.hs) + cache = LruCache(4, clock=self.hs.get_clock()) + + cache["key1"] = 1 + cache["key2"] = 2 + + # advance the reactor less than the min_cache_ttl + self.reactor.advance(60 * 2) + + # our items should still be in the cache + self.assertEqual(cache.get("key1"), 1) + self.assertEqual(cache.get("key2"), 2) + + # advance the reactor past the min_cache_ttl + self.reactor.advance(60 * 6) + + # the items should be cleared from cache + self.assertEqual(cache.get("key1"), None) + self.assertEqual(cache.get("key2"), None) + + # add more stuff to caches + cache["key1"] = 1 + cache["key2"] = 2 + + # set the return value of get_stat() to be lower than target_cache_memory_usage + mock_jemalloc_class.get_stat.return_value = 10000 + + # advance the reactor past the min_cache_ttl + self.reactor.advance(60 * 6) + + # the items should still be in the cache + self.assertEqual(cache.get("key1"), 1) + self.assertEqual(cache.get("key2"), 2) |