diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
index 4bb82e810e..d2b3c299e3 100644
--- a/tests/config/test_cache.py
+++ b/tests/config/test_cache.py
@@ -38,6 +38,7 @@ class CacheConfigTests(TestCase):
"SYNAPSE_NOT_CACHE": "BLAH",
}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0})
@@ -52,6 +53,7 @@ class CacheConfigTests(TestCase):
"SYNAPSE_CACHE_FACTOR_FOO": 1,
}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
self.assertEqual(
dict(self.config.cache_factors),
@@ -71,6 +73,7 @@ class CacheConfigTests(TestCase):
config = {"caches": {"per_cache_factors": {"foo": 3}}}
self.config.read_config(config)
+ self.config.resize_all_caches()
self.assertEqual(cache.max_size, 300)
@@ -82,6 +85,7 @@ class CacheConfigTests(TestCase):
"""
config = {"caches": {"per_cache_factors": {"foo": 2}}}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
cache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
@@ -99,6 +103,7 @@ class CacheConfigTests(TestCase):
config = {"caches": {"global_factor": 4}}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
self.assertEqual(cache.max_size, 400)
@@ -110,6 +115,7 @@ class CacheConfigTests(TestCase):
"""
config = {"caches": {"global_factor": 1.5}}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
cache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
@@ -128,6 +134,7 @@ class CacheConfigTests(TestCase):
"SYNAPSE_CACHE_FACTOR_CACHE_B": 3,
}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
cache_a = LruCache(100)
add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
@@ -148,6 +155,7 @@ class CacheConfigTests(TestCase):
config = {"caches": {"event_cache_size": "10k"}}
self.config.read_config(config, config_dir_path="", data_dir_path="")
+ self.config.resize_all_caches()
cache = LruCache(
max_size=self.config.event_cache_size,
diff --git a/tests/federation/transport/server/__init__.py b/tests/federation/transport/server/__init__.py
new file mode 100644
index 0000000000..3a5f22c022
--- /dev/null
+++ b/tests/federation/transport/server/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py
new file mode 100644
index 0000000000..ac3695a8cc
--- /dev/null
+++ b/tests/federation/transport/server/test__base.py
@@ -0,0 +1,114 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+from typing import Dict, List, Tuple
+
+from synapse.api.errors import Codes
+from synapse.federation.transport.server import BaseFederationServlet
+from synapse.federation.transport.server._base import Authenticator
+from synapse.http.server import JsonResource, cancellable
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
+
+
+class CancellableFederationServlet(BaseFederationServlet):
+ PATH = "/sleep"
+
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.clock = hs.get_clock()
+
+ @cancellable
+ async def on_GET(
+ self, origin: str, content: None, query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+ async def on_POST(
+ self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+
+class BaseFederationServletCancellationTests(
+ unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin
+):
+ """Tests for `BaseFederationServlet` cancellation."""
+
+ skip = "`BaseFederationServlet` does not support cancellation yet."
+
+ path = f"{CancellableFederationServlet.PREFIX}{CancellableFederationServlet.PATH}"
+
+ def create_test_resource(self):
+ """Overrides `HomeserverTestCase.create_test_resource`."""
+ resource = JsonResource(self.hs)
+
+ CancellableFederationServlet(
+ hs=self.hs,
+ authenticator=Authenticator(self.hs),
+ ratelimiter=self.hs.get_federation_ratelimiter(),
+ server_name=self.hs.hostname,
+ ).register(resource)
+
+ return resource
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ channel = self.make_signed_federation_request(
+ "GET", self.path, await_result=False
+ )
+
+ # Advance past all the rate limiting logic. If we disconnect too early, the
+ # request won't be processed.
+ self.pump()
+
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ channel = self.make_signed_federation_request(
+ "POST",
+ self.path,
+ content={},
+ await_result=False,
+ )
+
+ # Advance past all the rate limiting logic. If we disconnect too early, the
+ # request won't be processed.
+ self.pump()
+
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=False,
+ expected_body={"result": True},
+ )
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 489ba57736..e64b28f28b 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -148,7 +148,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
prev_event.internal_metadata.outlier = True
persistence = self.hs.get_storage().persistence
self.get_success(
- persistence.persist_event(prev_event, EventContext.for_outlier())
+ persistence.persist_event(
+ prev_event, EventContext.for_outlier(self.hs.get_storage())
+ )
)
else:
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index 0482a1ea34..78807cdcfc 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from copy import deepcopy
from typing import List
from synapse.api.constants import ReceiptTypes
@@ -125,42 +125,6 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_handles_missing_content_of_m_read(self):
- self._test_filters_private(
- [
- {
- "content": {
- "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}},
- "$1435641916114394fHBLK:matrix.org": {
- ReceiptTypes.READ: {
- "@user:jki.re": {
- "ts": 1436451550453,
- }
- }
- },
- },
- "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
- }
- ],
- [
- {
- "content": {
- "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}},
- "$1435641916114394fHBLK:matrix.org": {
- ReceiptTypes.READ: {
- "@user:jki.re": {
- "ts": 1436451550453,
- }
- }
- },
- },
- "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
- "type": "m.receipt",
- }
- ],
- )
-
def test_handles_empty_event(self):
self._test_filters_private(
[
@@ -332,9 +296,33 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
+ def test_we_do_not_mutate(self):
+ """Ensure the input values are not modified."""
+ events = [
+ {
+ "content": {
+ "$1435641916114394fHBLK:matrix.org": {
+ ReceiptTypes.READ_PRIVATE: {
+ "@rikj:jki.re": {
+ "ts": 1436451550453,
+ }
+ }
+ }
+ },
+ "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
+ "type": "m.receipt",
+ }
+ ]
+ original_events = deepcopy(events)
+ self._test_filters_private(events, [])
+ # Since the events are fed in from a cache they should not be modified.
+ self.assertEqual(events, original_events)
+
def _test_filters_private(
self, events: List[JsonDict], expected_output: List[JsonDict]
):
"""Tests that the _filter_out_private returns the expected output"""
- filtered_events = self.event_source.filter_out_private(events, "@me:server.org")
+ filtered_events = self.event_source.filter_out_private_receipts(
+ events, "@me:server.org"
+ )
self.assertEqual(filtered_events, expected_output)
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 865b8b7e47..db3302a4c7 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -160,6 +160,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Blow away caches (supported room versions can only change due to a restart).
self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
self.store._get_event_cache.clear()
+ self.store._event_ref.clear()
# The rooms should be excluded from the sync response.
# Get a new request key.
diff --git a/tests/http/server/__init__.py b/tests/http/server/__init__.py
new file mode 100644
index 0000000000..3a5f22c022
--- /dev/null
+++ b/tests/http/server/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
new file mode 100644
index 0000000000..b9f1a381aa
--- /dev/null
+++ b/tests/http/server/_base.py
@@ -0,0 +1,100 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unles4s required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+from typing import Any, Callable, Optional, Union
+from unittest import mock
+
+from twisted.internet.error import ConnectionDone
+
+from synapse.http.server import (
+ HTTP_STATUS_REQUEST_CANCELLED,
+ respond_with_html_bytes,
+ respond_with_json,
+)
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.server import FakeChannel, ThreadedMemoryReactorClock
+
+
+class EndpointCancellationTestHelperMixin(unittest.TestCase):
+ """Provides helper methods for testing cancellation of endpoints."""
+
+ def _test_disconnect(
+ self,
+ reactor: ThreadedMemoryReactorClock,
+ channel: FakeChannel,
+ expect_cancellation: bool,
+ expected_body: Union[bytes, JsonDict],
+ expected_code: Optional[int] = None,
+ ) -> None:
+ """Disconnects an in-flight request and checks the response.
+
+ Args:
+ reactor: The twisted reactor running the request handler.
+ channel: The `FakeChannel` for the request.
+ expect_cancellation: `True` if request processing is expected to be
+ cancelled, `False` if the request should run to completion.
+ expected_body: The expected response for the request.
+ expected_code: The expected status code for the request. Defaults to `200`
+ or `499` depending on `expect_cancellation`.
+ """
+ # Determine the expected status code.
+ if expected_code is None:
+ if expect_cancellation:
+ expected_code = HTTP_STATUS_REQUEST_CANCELLED
+ else:
+ expected_code = HTTPStatus.OK
+
+ request = channel.request
+ self.assertFalse(
+ channel.is_finished(),
+ "Request finished before we could disconnect - "
+ "was `await_result=False` passed to `make_request`?",
+ )
+
+ # We're about to disconnect the request. This also disconnects the channel, so
+ # we have to rely on mocks to extract the response.
+ respond_method: Callable[..., Any]
+ if isinstance(expected_body, bytes):
+ respond_method = respond_with_html_bytes
+ else:
+ respond_method = respond_with_json
+
+ with mock.patch(
+ f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
+ ) as respond_mock:
+ # Disconnect the request.
+ request.connectionLost(reason=ConnectionDone())
+
+ if expect_cancellation:
+ # An immediate cancellation is expected.
+ respond_mock.assert_called_once()
+ args, _kwargs = respond_mock.call_args
+ code, body = args[1], args[2]
+ self.assertEqual(code, expected_code)
+ self.assertEqual(request.code, expected_code)
+ self.assertEqual(body, expected_body)
+ else:
+ respond_mock.assert_not_called()
+
+ # The handler is expected to run to completion.
+ reactor.pump([1.0])
+ respond_mock.assert_called_once()
+ args, _kwargs = respond_mock.call_args
+ code, body = args[1], args[2]
+ self.assertEqual(code, expected_code)
+ self.assertEqual(request.code, expected_code)
+ self.assertEqual(body, expected_body)
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index a80bfb9f4e..ad521525cf 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -12,16 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
+from http import HTTPStatus
from io import BytesIO
+from typing import Tuple
from unittest.mock import Mock
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import cancellable
from synapse.http.servlet import (
+ RestServlet,
parse_json_object_from_request,
parse_json_value_from_request,
)
+from synapse.http.site import SynapseRequest
+from synapse.rest.client._base import client_patterns
+from synapse.server import HomeServer
+from synapse.types import JsonDict
from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
def make_request(content):
@@ -76,3 +85,52 @@ class TestServletUtils(unittest.TestCase):
# Test not an object
with self.assertRaises(SynapseError):
parse_json_object_from_request(make_request(b'["foo"]'))
+
+
+class CancellableRestServlet(RestServlet):
+ """A `RestServlet` with a mix of cancellable and uncancellable handlers."""
+
+ PATTERNS = client_patterns("/sleep$")
+
+ def __init__(self, hs: HomeServer):
+ super().__init__()
+ self.clock = hs.get_clock()
+
+ @cancellable
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+
+class TestRestServletCancellation(
+ unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
+):
+ """Tests for `RestServlet` cancellation."""
+
+ servlets = [
+ lambda hs, http_server: CancellableRestServlet(hs).register(http_server)
+ ]
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ channel = self.make_request("GET", "/sleep", await_result=False)
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ channel = self.make_request("POST", "/sleep", await_result=False)
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=False,
+ expected_body={"result": True},
+ )
diff --git a/tests/replication/http/__init__.py b/tests/replication/http/__init__.py
new file mode 100644
index 0000000000..3a5f22c022
--- /dev/null
+++ b/tests/replication/http/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
new file mode 100644
index 0000000000..a5ab093a27
--- /dev/null
+++ b/tests/replication/http/test__base.py
@@ -0,0 +1,106 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+from typing import Tuple
+
+from twisted.web.server import Request
+
+from synapse.api.errors import Codes
+from synapse.http.server import JsonResource, cancellable
+from synapse.replication.http import REPLICATION_PREFIX
+from synapse.replication.http._base import ReplicationEndpoint
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
+
+
+class CancellableReplicationEndpoint(ReplicationEndpoint):
+ NAME = "cancellable_sleep"
+ PATH_ARGS = ()
+ CACHE = False
+
+ def __init__(self, hs: HomeServer):
+ super().__init__(hs)
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload() -> JsonDict:
+ return {}
+
+ @cancellable
+ async def _handle_request( # type: ignore[override]
+ self, request: Request
+ ) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+
+class UncancellableReplicationEndpoint(ReplicationEndpoint):
+ NAME = "uncancellable_sleep"
+ PATH_ARGS = ()
+ CACHE = False
+
+ def __init__(self, hs: HomeServer):
+ super().__init__(hs)
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload() -> JsonDict:
+ return {}
+
+ async def _handle_request( # type: ignore[override]
+ self, request: Request
+ ) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+
+class ReplicationEndpointCancellationTestCase(
+ unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
+):
+ """Tests for `ReplicationEndpoint` cancellation."""
+
+ def create_test_resource(self):
+ """Overrides `HomeserverTestCase.create_test_resource`."""
+ resource = JsonResource(self.hs)
+
+ CancellableReplicationEndpoint(self.hs).register(resource)
+ UncancellableReplicationEndpoint(self.hs).register(resource)
+
+ return resource
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
+ channel = self.make_request("POST", path, await_result=False)
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
+ channel = self.make_request("POST", path, await_result=False)
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=False,
+ expected_body={"result": True},
+ )
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 5f142e84c3..a7ca68069e 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -14,7 +14,6 @@
import logging
from unittest.mock import patch
-from synapse.api.room_versions import RoomVersion
from synapse.rest import admin
from synapse.rest.client import login, room, sync
from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -64,21 +63,10 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# We control the room ID generation by patching out the
# `_generate_room_id` method
- async def generate_room(
- creator_id: str, is_public: bool, room_version: RoomVersion
- ):
- await self.store.store_room(
- room_id=room_id,
- room_creator_user_id=creator_id,
- is_public=is_public,
- room_version=room_version,
- )
- return room_id
-
with patch(
"synapse.handlers.room.RoomCreationHandler._generate_room_id"
) as mock:
- mock.side_effect = generate_room
+ mock.side_effect = lambda: room_id
self.helper.create_room_as(user_id, tok=tok)
def test_basic(self):
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 9443daa056..d0197aca94 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -925,7 +925,7 @@ class RoomJoinTestCase(RoomBase):
) -> bool:
return return_value
- callback_mock = Mock(side_effect=user_may_join_room)
+ callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
# Join a first room, without being invited to it.
@@ -1116,6 +1116,264 @@ class RoomMessagesTestCase(RoomBase):
self.assertEqual(200, channel.code, msg=channel.result["body"])
+class RoomPowerLevelOverridesTestCase(RoomBase):
+ """Tests that the power levels can be overridden with server config."""
+
+ user_id = "@sid1:red"
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin_user_id = self.register_user("admin", "pass")
+ self.admin_access_token = self.login("admin", "pass")
+
+ def power_levels(self, room_id: str) -> Dict[str, Any]:
+ return self.helper.get_state(
+ room_id, "m.room.power_levels", self.admin_access_token
+ )
+
+ def test_default_power_levels_with_room_override(self) -> None:
+ """
+ Create a room, providing power level overrides.
+ Confirm that the room's power levels reflect the overrides.
+
+ See https://github.com/matrix-org/matrix-spec/issues/492
+ - currently we overwrite each key of power_level_content_override
+ completely.
+ """
+
+ room_id = self.helper.create_room_as(
+ self.user_id,
+ extra_content={
+ "power_level_content_override": {"events": {"custom.event": 0}}
+ },
+ )
+ self.assertEqual(
+ {
+ "custom.event": 0,
+ },
+ self.power_levels(room_id)["events"],
+ )
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {"events": {"custom.event": 0}},
+ }
+ },
+ )
+ def test_power_levels_with_server_override(self) -> None:
+ """
+ With a server configured to modify the room-level defaults,
+ Create a room, without providing any extra power level overrides.
+ Confirm that the room's power levels reflect the server-level overrides.
+
+ Similar to https://github.com/matrix-org/matrix-spec/issues/492,
+ we overwrite each key of power_level_content_override completely.
+ """
+
+ room_id = self.helper.create_room_as(self.user_id)
+ self.assertEqual(
+ {
+ "custom.event": 0,
+ },
+ self.power_levels(room_id)["events"],
+ )
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {
+ "events": {"server.event": 0},
+ "ban": 13,
+ },
+ }
+ },
+ )
+ def test_power_levels_with_server_and_room_overrides(self) -> None:
+ """
+ With a server configured to modify the room-level defaults,
+ create a room, providing different overrides.
+ Confirm that the room's power levels reflect both overrides, and
+ choose the room overrides where they clash.
+ """
+
+ room_id = self.helper.create_room_as(
+ self.user_id,
+ extra_content={
+ "power_level_content_override": {"events": {"room.event": 0}}
+ },
+ )
+
+ # Room override wins over server config
+ self.assertEqual(
+ {"room.event": 0},
+ self.power_levels(room_id)["events"],
+ )
+
+ # But where there is no room override, server config wins
+ self.assertEqual(13, self.power_levels(room_id)["ban"])
+
+
+class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
+ """
+ Tests that we can really do various otherwise-prohibited actions
+ based on overriding the power levels in config.
+ """
+
+ user_id = "@sid1:red"
+
+ def test_creator_can_post_state_event(self) -> None:
+ # Given I am the creator of a room
+ room_id = self.helper.create_room_as(self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am allowed
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+
+ def test_normal_user_can_not_post_state_event(self) -> None:
+ # Given I am a normal member of a room
+ room_id = self.helper.create_room_as("@some_other_guy:red")
+ self.helper.join(room=room_id, user=self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am not allowed because state events require PL>=50
+ self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ "You don't have permission to post that to the room. "
+ "user_level (0) < send_level (50)",
+ channel.json_body["error"],
+ )
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {"events": {"custom.event": 0}},
+ }
+ },
+ )
+ def test_with_config_override_normal_user_can_post_state_event(self) -> None:
+ # Given the server has config allowing normal users to post my event type,
+ # and I am a normal member of a room
+ room_id = self.helper.create_room_as("@some_other_guy:red")
+ self.helper.join(room=room_id, user=self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am allowed
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {"events": {"custom.event": 0}},
+ }
+ },
+ )
+ def test_any_room_override_defeats_config_override(self) -> None:
+ # Given the server has config allowing normal users to post my event type
+ # And I am a normal member of a room
+ # But the room was created with special permissions
+ extra_content: Dict[str, Any] = {
+ "power_level_content_override": {"events": {}},
+ }
+ room_id = self.helper.create_room_as(
+ "@some_other_guy:red", extra_content=extra_content
+ )
+ self.helper.join(room=room_id, user=self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am not allowed
+ self.assertEqual(403, channel.code, msg=channel.result["body"])
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {"events": {"custom.event": 0}},
+ }
+ },
+ )
+ def test_specific_room_override_defeats_config_override(self) -> None:
+ # Given the server has config allowing normal users to post my event type,
+ # and I am a normal member of a room,
+ # but the room was created with special permissions for this event type
+ extra_content = {
+ "power_level_content_override": {"events": {"custom.event": 1}},
+ }
+ room_id = self.helper.create_room_as(
+ "@some_other_guy:red", extra_content=extra_content
+ )
+ self.helper.join(room=room_id, user=self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am not allowed
+ self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ "You don't have permission to post that to the room. "
+ + "user_level (0) < send_level (1)",
+ channel.json_body["error"],
+ )
+
+ @unittest.override_config(
+ {
+ "default_power_level_content_override": {
+ "public_chat": {"events": {"custom.event": 0}},
+ "private_chat": None,
+ "trusted_private_chat": None,
+ }
+ },
+ )
+ def test_config_override_applies_only_to_specific_preset(self) -> None:
+ # Given the server has config for public_chats,
+ # and I am a normal member of a private_chat room
+ room_id = self.helper.create_room_as("@some_other_guy:red", is_public=False)
+ self.helper.invite(room=room_id, src="@some_other_guy:red", targ=self.user_id)
+ self.helper.join(room=room_id, user=self.user_id)
+
+ # When I send a state event
+ path = "/rooms/{room_id}/state/custom.event/my_state_key".format(
+ room_id=urlparse.quote(room_id),
+ )
+ channel = self.make_request("PUT", path, "{}")
+
+ # Then I am not allowed because the public_chat config does not
+ # affect this room, because this room is a private_chat
+ self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ "You don't have permission to post that to the room. "
+ + "user_level (0) < send_level (50)",
+ channel.json_body["error"],
+ )
+
+
class RoomInitialSyncTestCase(RoomBase):
"""Tests /rooms/$room_id/initialSync."""
@@ -2598,7 +2856,9 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
# allow everything for now.
- mock = Mock(return_value=make_awaitable(True))
+ # `spec` argument is needed for this function mock to have `__qualname__`, which
+ # is needed for `Measure` metrics buried in SpamChecker.
+ mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None)
self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
# Send a 3PID invite into the room and check that it succeeded.
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 0108337649..74b6560cbc 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
+from http import HTTPStatus
from typing import List, Optional
from parameterized import parameterized
@@ -485,30 +486,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that we didn't override the public read receipt
self.assertIsNone(self._get_read_receipt())
- @parameterized.expand(
- [
- # Old Element version, expected to send an empty body
- (
- "agent1",
- "Element/1.2.2 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)",
- 200,
- ),
- # Old SchildiChat version, expected to send an empty body
- ("agent2", "SchildiChat/1.2.1 (Android 10)", 200),
- # Expected 400: Denies empty body starting at version 1.3+
- ("agent3", "Element/1.3.6 (Android 10)", 400),
- ("agent4", "SchildiChat/1.3.6 (Android 11)", 400),
- # Contains "Riot": Receipts with empty bodies expected
- ("agent5", "Element (Riot.im) (Android 9)", 200),
- # Expected 400: Does not contain "Android"
- ("agent6", "Element/1.2.1", 400),
- # Expected 400: Different format, missing "/" after Element; existing build that should allow empty bodies, but minimal ongoing usage
- ("agent7", "Element dbg/1.1.8-dev (Android)", 400),
- ]
- )
- def test_read_receipt_with_empty_body(
- self, name: str, user_agent: str, expected_status_code: int
- ) -> None:
+ def test_read_receipt_with_empty_body_is_rejected(self) -> None:
# Send a message as the first user
res = self.helper.send(self.room_id, body="hello", tok=self.tok)
@@ -517,9 +495,9 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
"POST",
f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}",
access_token=self.tok2,
- custom_headers=[("User-Agent", user_agent)],
)
- self.assertEqual(channel.code, expected_status_code)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON", channel.json_body)
def _get_read_receipt(self) -> Optional[JsonDict]:
"""Syncs and returns the read receipt."""
@@ -678,12 +656,13 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self._check_unread_count(3)
# Check that custom events with a body increase the unread counter.
- self.helper.send_event(
+ result = self.helper.send_event(
self.room_id,
"org.matrix.custom_type",
{"body": "hello"},
tok=self.tok2,
)
+ event_id = result["event_id"]
self._check_unread_count(4)
# Check that edits don't increase the unread counter.
@@ -693,7 +672,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
content={
"body": "hello",
"msgtype": "m.text",
- "m.relates_to": {"rel_type": RelationTypes.REPLACE},
+ "m.relates_to": {
+ "rel_type": RelationTypes.REPLACE,
+ "event_id": event_id,
+ },
},
tok=self.tok2,
)
diff --git a/tests/server.py b/tests/server.py
index 8f30e250c8..b9f465971f 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -109,6 +109,17 @@ class FakeChannel:
_ip: str = "127.0.0.1"
_producer: Optional[Union[IPullProducer, IPushProducer]] = None
resource_usage: Optional[ContextResourceUsage] = None
+ _request: Optional[Request] = None
+
+ @property
+ def request(self) -> Request:
+ assert self._request is not None
+ return self._request
+
+ @request.setter
+ def request(self, request: Request) -> None:
+ assert self._request is None
+ self._request = request
@property
def json_body(self):
@@ -322,6 +333,8 @@ def make_request(
channel = FakeChannel(site, reactor, ip=client_ip)
req = request(channel, site)
+ channel.request = req
+
req.content = BytesIO(content)
# Twisted expects to be at the end of the content when parsing the request.
req.content.seek(0, SEEK_END)
@@ -736,6 +749,7 @@ def setup_test_homeserver(
if config is None:
config = default_config(name, parse=True)
+ config.caches.resize_all_caches()
config.ldap_enabled = False
if "clock" not in kwargs:
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 9ee9509d3a..07e29788e5 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -75,6 +75,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
return_value=make_awaitable("!something:localhost")
)
+ self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = Mock(
+ return_value=make_awaitable("!something:localhost")
+ )
self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None))
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
@@ -102,6 +105,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
+ self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once()
self._send_notice.assert_called_once()
def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self):
@@ -300,7 +304,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
hasn't been reached (since it's the only user and the limit is 5), so users
shouldn't receive a server notice.
"""
- self.register_user("user", "password")
+ m = Mock(return_value=make_awaitable(None))
+ self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = m
+
+ user_id = self.register_user("user", "password")
tok = self.login("user", "password")
channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
@@ -309,6 +316,8 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
"rooms", channel.json_body, "Got invites without server notice"
)
+ m.assert_called_once_with(user_id)
+
def test_invite_with_notice(self):
"""Tests that, if the MAU limit is hit, the server notices user invites each user
to a room in which it has sent a notice.
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index c237a8c7e2..38963ce4a7 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -154,6 +154,31 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# We should have fetched the event from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+ def test_event_ref(self):
+ """Test that we reuse events that are still in memory but have fallen
+ out of the cache, rather than requesting them from the DB.
+ """
+
+ # Reset the event cache
+ self.store._get_event_cache.clear()
+
+ with LoggingContext("test") as ctx:
+ # We keep hold of the event event though we never use it.
+ event = self.get_success(self.store.get_event(self.event_id)) # noqa: F841
+
+ # We should have fetched the event from the DB
+ self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+
+ # Reset the event cache
+ self.store._get_event_cache.clear()
+
+ with LoggingContext("test") as ctx:
+ self.get_success(self.store.get_event(self.event_id))
+
+ # Since the event is still in memory we shouldn't have fetched it
+ # from the DB
+ self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0)
+
def test_dedupe(self):
"""Test that if we request the same event multiple times we only pull it
out once.
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 401020fd63..c7661e7186 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -393,7 +393,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
# We need to persist the events to the events and state_events
# tables.
persist_events_store._store_event_txn(
- txn, [(e, EventContext()) for e in events]
+ txn, [(e, EventContext(self.hs.get_storage())) for e in events]
)
# Actually call the function that calculates the auth chain stuff.
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 645d564d1c..d92a9ac5b7 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -58,15 +58,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
(room_id, event_id),
)
- txn.execute(
- (
- "INSERT INTO event_reference_hashes "
- "(event_id, algorithm, hash) "
- "VALUES (?, 'sha256', ?)"
- ),
- (event_id, bytearray(b"ffff")),
- )
-
for i in range(0, 20):
self.get_success(
self.store.db_pool.runInteraction("insert", insert_event, i)
diff --git a/tests/test_server.py b/tests/test_server.py
index f2ffbc895b..0f1eb43cbc 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -13,18 +13,28 @@
# limitations under the License.
import re
+from http import HTTPStatus
+from typing import Tuple
from twisted.internet.defer import Deferred
from twisted.web.resource import Resource
from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.config.server import parse_listener_def
-from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
-from synapse.http.site import SynapseSite
+from synapse.http.server import (
+ DirectServeHtmlResource,
+ DirectServeJsonResource,
+ JsonResource,
+ OptionsResource,
+ cancellable,
+)
+from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import make_deferred_yieldable
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
from tests.server import (
FakeSite,
ThreadedMemoryReactorClock,
@@ -363,3 +373,100 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
+
+
+class CancellableDirectServeJsonResource(DirectServeJsonResource):
+ def __init__(self, clock: Clock):
+ super().__init__()
+ self.clock = clock
+
+ @cancellable
+ async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+ async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+
+class CancellableDirectServeHtmlResource(DirectServeHtmlResource):
+ ERROR_TEMPLATE = "{code} {msg}"
+
+ def __init__(self, clock: Clock):
+ super().__init__()
+ self.clock = clock
+
+ @cancellable
+ async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, bytes]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, b"ok"
+
+ async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, bytes]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, b"ok"
+
+
+class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin):
+ """Tests for `DirectServeJsonResource` cancellation."""
+
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+ self.clock = Clock(self.reactor)
+ self.resource = CancellableDirectServeJsonResource(self.clock)
+ self.site = FakeSite(self.resource, self.reactor)
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ channel = make_request(
+ self.reactor, self.site, "GET", "/sleep", await_result=False
+ )
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ channel = make_request(
+ self.reactor, self.site, "POST", "/sleep", await_result=False
+ )
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=False,
+ expected_body={"result": True},
+ )
+
+
+class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin):
+ """Tests for `DirectServeHtmlResource` cancellation."""
+
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+ self.clock = Clock(self.reactor)
+ self.resource = CancellableDirectServeHtmlResource(self.clock)
+ self.site = FakeSite(self.resource, self.reactor)
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ channel = make_request(
+ self.reactor, self.site, "GET", "/sleep", await_result=False
+ )
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body=b"499 Request cancelled",
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ channel = make_request(
+ self.reactor, self.site, "POST", "/sleep", await_result=False
+ )
+ self._test_disconnect(
+ self.reactor, channel, expect_cancellation=False, expected_body=b"ok"
+ )
diff --git a/tests/test_state.py b/tests/test_state.py
index e4baa69137..651ec1c7d4 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -88,6 +88,9 @@ class _DummyStore:
return groups
+ async def get_state_ids_for_group(self, state_group):
+ return self._group_to_state[state_group]
+
async def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
):
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index d0230f9ebb..7a9b01ef9d 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -234,7 +234,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
event.internal_metadata.outlier = True
self.get_success(
- self.storage.persistence.persist_event(event, EventContext.for_outlier())
+ self.storage.persistence.persist_event(
+ event, EventContext.for_outlier(self.storage)
+ )
)
return event
diff --git a/tests/unittest.py b/tests/unittest.py
index 9afa68c164..e7f255b4fa 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -831,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
self.site,
method=method,
path=path,
- content=content or "",
+ content=content if content is not None else "",
shorthand=False,
await_result=await_result,
custom_headers=custom_headers,
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 321fc1776f..67173a4f5b 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -14,8 +14,9 @@
from typing import List
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
+from synapse.metrics.jemalloc import JemallocStats
from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries
from synapse.util.caches.treecache import TreeCache
@@ -316,3 +317,58 @@ class TimeEvictionTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key1"), None)
self.assertEqual(cache.get("key2"), 3)
+
+
+class MemoryEvictionTestCase(unittest.HomeserverTestCase):
+ @override_config(
+ {
+ "caches": {
+ "cache_autotuning": {
+ "max_cache_memory_usage": "700M",
+ "target_cache_memory_usage": "500M",
+ "min_cache_ttl": "5m",
+ }
+ }
+ }
+ )
+ @patch("synapse.util.caches.lrucache.get_jemalloc_stats")
+ def test_evict_memory(self, jemalloc_interface) -> None:
+ mock_jemalloc_class = Mock(spec=JemallocStats)
+ jemalloc_interface.return_value = mock_jemalloc_class
+
+ # set the return value of get_stat() to be greater than max_cache_memory_usage
+ mock_jemalloc_class.get_stat.return_value = 924288000
+
+ setup_expire_lru_cache_entries(self.hs)
+ cache = LruCache(4, clock=self.hs.get_clock())
+
+ cache["key1"] = 1
+ cache["key2"] = 2
+
+ # advance the reactor less than the min_cache_ttl
+ self.reactor.advance(60 * 2)
+
+ # our items should still be in the cache
+ self.assertEqual(cache.get("key1"), 1)
+ self.assertEqual(cache.get("key2"), 2)
+
+ # advance the reactor past the min_cache_ttl
+ self.reactor.advance(60 * 6)
+
+ # the items should be cleared from cache
+ self.assertEqual(cache.get("key1"), None)
+ self.assertEqual(cache.get("key2"), None)
+
+ # add more stuff to caches
+ cache["key1"] = 1
+ cache["key2"] = 2
+
+ # set the return value of get_stat() to be lower than target_cache_memory_usage
+ mock_jemalloc_class.get_stat.return_value = 10000
+
+ # advance the reactor past the min_cache_ttl
+ self.reactor.advance(60 * 6)
+
+ # the items should still be in the cache
+ self.assertEqual(cache.get("key1"), 1)
+ self.assertEqual(cache.get("key2"), 2)
|