summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py29
-rw-r--r--tests/api/test_ratelimiting.py4
-rw-r--r--tests/appservice/test_appservice.py1
-rw-r--r--tests/handlers/test_appservice.py20
-rw-r--r--tests/handlers/test_device.py2
-rw-r--r--tests/handlers/test_message.py2
-rw-r--r--tests/handlers/test_oidc.py24
-rw-r--r--tests/logging/__init__.py34
-rw-r--r--tests/logging/test_remote_handler.py169
-rw-r--r--tests/logging/test_structured.py214
-rw-r--r--tests/logging/test_terse_json.py253
-rw-r--r--tests/push/test_email.py31
-rw-r--r--tests/push/test_http.py10
-rw-r--r--tests/replication/_base.py13
-rw-r--r--tests/replication/tcp/streams/test_events.py2
-rw-r--r--tests/replication/test_multi_media_repo.py277
-rw-r--r--tests/replication/test_pusher_shard.py2
-rw-r--r--tests/rest/admin/test_device.py17
-rw-r--r--tests/rest/admin/test_event_reports.py196
-rw-r--r--tests/rest/admin/test_media.py568
-rw-r--r--tests/rest/admin/test_user.py422
-rw-r--r--tests/rest/client/v2_alpha/test_register.py1
-rw-r--r--tests/server.py10
-rw-r--r--tests/storage/test_cleanup_extrems.py6
-rw-r--r--tests/storage/test_client_ips.py2
-rw-r--r--tests/storage/test_event_metrics.py4
-rw-r--r--tests/storage/test_registration.py10
-rw-r--r--tests/storage/test_roommember.py4
-rw-r--r--tests/test_federation.py4
-rw-r--r--tests/test_utils/__init__.py34
-rw-r--r--tests/test_utils/event_injection.py2
-rw-r--r--tests/unittest.py10
-rw-r--r--tests/util/caches/test_descriptors.py60
33 files changed, 1950 insertions, 487 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py

index cb6f29d670..0fd55f428a 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py
@@ -29,6 +29,7 @@ from synapse.api.errors import ( MissingClientTokenError, ResourceLimitError, ) +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import UserID from tests import unittest @@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): - user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} + user_info = TokenLookupResult( + user_id=self.test_user, token_id=5, device_id="device" + ) self.store.get_user_by_access_token = Mock( return_value=defer.succeed(user_info) ) @@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase): self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_user_missing_token(self): - user_info = {"name": self.test_user, "token_id": "ditto"} + user_info = TokenLookupResult(user_id=self.test_user, token_id=5) self.store.get_user_by_access_token = Mock( return_value=defer.succeed(user_info) ) @@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = Mock( return_value=defer.succeed( - {"name": "@baldrick:matrix.org", "device_id": "device"} + TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device") ) ) @@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase): user_info = yield defer.ensureDeferred( self.auth.get_user_by_access_token(macaroon.serialize()) ) - user = user_info["user"] - self.assertEqual(UserID.from_string(user_id), user) + self.assertEqual(user_id, user_info.user_id) # TODO: device_id should come from the macaroon, but currently comes # from the db. - self.assertEqual(user_info["device_id"], "device") + self.assertEqual(user_info.device_id, "device") @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): @@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase): user_info = yield defer.ensureDeferred( self.auth.get_user_by_access_token(serialized) ) - user = user_info["user"] - is_guest = user_info["is_guest"] - self.assertEqual(UserID.from_string(user_id), user) - self.assertTrue(is_guest) + self.assertEqual(user_id, user_info.user_id) + self.assertTrue(user_info.is_guest) self.store.get_user_by_id.assert_called_with(user_id) @defer.inlineCallbacks @@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase): if token != tok: return defer.succeed(None) return defer.succeed( - { - "name": USER_ID, - "is_guest": False, - "token_id": 1234, - "device_id": "DEVICE", - } + TokenLookupResult( + user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE", + ) ) self.store.get_user_by_access_token = get_user diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 1e1f30d790..fe504d0869 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py
@@ -43,7 +43,7 @@ class TestRatelimiter(unittest.TestCase): def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): appservice = ApplicationService( - None, "example.com", id="foo", rate_limited=True, + None, "example.com", id="foo", rate_limited=True, sender="@as:example.com", ) as_requester = create_requester("@user:example.com", app_service=appservice) @@ -68,7 +68,7 @@ class TestRatelimiter(unittest.TestCase): def test_allowed_appservice_via_can_requester_do_action(self): appservice = ApplicationService( - None, "example.com", id="foo", rate_limited=False, + None, "example.com", id="foo", rate_limited=False, sender="@as:example.com", ) as_requester = create_requester("@user:example.com", app_service=appservice) diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 236b608d58..0bffeb1150 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py
@@ -31,6 +31,7 @@ class ApplicationServiceTestCase(unittest.TestCase): def setUp(self): self.service = ApplicationService( id="unique_identifier", + sender="@as:test", url="some_url", token="some_token", hostname="matrix.org", # only used by get_groups_for_user diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index ee4f3da31c..53763cd0f9 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py
@@ -42,7 +42,6 @@ class AppServiceHandlerTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() self.handler = ApplicationServicesHandler(hs) - @defer.inlineCallbacks def test_notify_interested_services(self): interested_service = self._mkservice(is_interested=True) services = [ @@ -62,14 +61,12 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred( - self.handler.notify_interested_services(RoomStreamToken(None, 0)) - ) + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.mock_scheduler.submit_event_for_as.assert_called_once_with( interested_service, event ) - @defer.inlineCallbacks def test_query_user_exists_unknown_user(self): user_id = "@someone:anywhere" services = [self._mkservice(is_interested=True)] @@ -83,12 +80,11 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred( - self.handler.notify_interested_services(RoomStreamToken(None, 0)) - ) + + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) - @defer.inlineCallbacks def test_query_user_exists_known_user(self): user_id = "@someone:anywhere" services = [self._mkservice(is_interested=True)] @@ -102,9 +98,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred( - self.handler.notify_interested_services(RoomStreamToken(None, 0)) - ) + + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.assertFalse( self.mock_as_api.query_user.called, "query_user called when it shouldn't have been.", diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 4512c51311..875aaec2c6 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py
@@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): # make sure that our device ID has changed user_info = self.get_success(self.auth.get_user_by_access_token(access_token)) - self.assertEqual(user_info["device_id"], retrieved_device_id) + self.assertEqual(user_info.device_id, retrieved_device_id) # make sure the device has the display name that was set from the login res = self.get_success(self.handler.get_device(user_id, retrieved_device_id)) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 9f6f21a6e2..2e0fea04af 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py
@@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.info = self.get_success( self.hs.get_datastore().get_user_by_access_token(self.access_token,) ) - self.token_id = self.info["token_id"] + self.token_id = self.info.token_id self.requester = create_requester(self.user_id, access_token_id=self.token_id) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index b6f436c016..0d51705849 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py
@@ -394,7 +394,14 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) self.handler._auth_handler.complete_sso_login = simple_async_mock() request = Mock( - spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + spec=[ + "args", + "getCookie", + "addCookie", + "requestHeaders", + "getClientIP", + "get_user_agent", + ] ) code = "code" @@ -414,9 +421,8 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args[b"code"] = [code.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")] - request.requestHeaders = Mock(spec=["getRawHeaders"]) - request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")] request.getClientIP.return_value = ip_address + request.get_user_agent.return_value = user_agent self.get_success(self.handler.handle_oidc_callback(request)) @@ -621,7 +627,14 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) self.handler._auth_handler.complete_sso_login = simple_async_mock() request = Mock( - spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + spec=[ + "args", + "getCookie", + "addCookie", + "requestHeaders", + "getClientIP", + "get_user_agent", + ] ) state = "state" @@ -637,9 +650,8 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args[b"code"] = [b"code"] request.args[b"state"] = [state.encode("utf-8")] - request.requestHeaders = Mock(spec=["getRawHeaders"]) - request.requestHeaders.getRawHeaders.return_value = [b"Browser"] request.getClientIP.return_value = "10.0.0.1" + request.get_user_agent.return_value = "Browser" self.get_success(self.handler.handle_oidc_callback(request)) diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py
index e69de29bb2..a58d51441c 100644 --- a/tests/logging/__init__.py +++ b/tests/logging/__init__.py
@@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + + +class LoggerCleanupMixin: + def get_logger(self, handler): + """ + Attach a handler to a logger and add clean-ups to remove revert this. + """ + # Create a logger and add the handler to it. + logger = logging.getLogger(__name__) + logger.addHandler(handler) + + # Ensure the logger actually logs something. + logger.setLevel(logging.INFO) + + # Ensure the logger gets cleaned-up appropriately. + self.addCleanup(logger.removeHandler, handler) + self.addCleanup(logger.setLevel, logging.NOTSET) + + return logger diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py new file mode 100644
index 0000000000..4bc27a1d7d --- /dev/null +++ b/tests/logging/test_remote_handler.py
@@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.test.proto_helpers import AccumulatingProtocol + +from synapse.logging import RemoteHandler + +from tests.logging import LoggerCleanupMixin +from tests.server import FakeTransport, get_clock +from tests.unittest import TestCase + + +def connect_logging_client(reactor, client_id): + # This is essentially tests.server.connect_client, but disabling autoflush on + # the client transport. This is necessary to avoid an infinite loop due to + # sending of data via the logging transport causing additional logs to be + # written. + factory = reactor.tcpClients.pop(client_id)[2] + client = factory.buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, reactor)) + client.makeConnection(FakeTransport(server, reactor, autoflush=False)) + + return client, server + + +class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): + def setUp(self): + self.reactor, _ = get_clock() + + def test_log_output(self): + """ + The remote handler delivers logs over TCP. + """ + handler = RemoteHandler("127.0.0.1", 9000, _reactor=self.reactor) + logger = self.get_logger(handler) + + logger.info("Hello there, %s!", "wally") + + # Trigger the connection + client, server = connect_logging_client(self.reactor, 0) + + # Trigger data being sent + client.transport.flush() + + # One log message, with a single trailing newline + logs = server.data.decode("utf8").splitlines() + self.assertEqual(len(logs), 1) + self.assertEqual(server.data.count(b"\n"), 1) + + # Ensure the data passed through properly. + self.assertEqual(logs[0], "Hello there, wally!") + + def test_log_backpressure_debug(self): + """ + When backpressure is hit, DEBUG logs will be shed. + """ + handler = RemoteHandler( + "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor + ) + logger = self.get_logger(handler) + + # Send some debug messages + for i in range(0, 3): + logger.debug("debug %s" % (i,)) + + # Send a bunch of useful messages + for i in range(0, 7): + logger.info("info %s" % (i,)) + + # The last debug message pushes it past the maximum buffer + logger.debug("too much debug") + + # Allow the reconnection + client, server = connect_logging_client(self.reactor, 0) + client.transport.flush() + + # Only the 7 infos made it through, the debugs were elided + logs = server.data.splitlines() + self.assertEqual(len(logs), 7) + self.assertNotIn(b"debug", server.data) + + def test_log_backpressure_info(self): + """ + When backpressure is hit, DEBUG and INFO logs will be shed. + """ + handler = RemoteHandler( + "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor + ) + logger = self.get_logger(handler) + + # Send some debug messages + for i in range(0, 3): + logger.debug("debug %s" % (i,)) + + # Send a bunch of useful messages + for i in range(0, 10): + logger.warning("warn %s" % (i,)) + + # Send a bunch of info messages + for i in range(0, 3): + logger.info("info %s" % (i,)) + + # The last debug message pushes it past the maximum buffer + logger.debug("too much debug") + + # Allow the reconnection + client, server = connect_logging_client(self.reactor, 0) + client.transport.flush() + + # The 10 warnings made it through, the debugs and infos were elided + logs = server.data.splitlines() + self.assertEqual(len(logs), 10) + self.assertNotIn(b"debug", server.data) + self.assertNotIn(b"info", server.data) + + def test_log_backpressure_cut_middle(self): + """ + When backpressure is hit, and no more DEBUG and INFOs cannot be culled, + it will cut the middle messages out. + """ + handler = RemoteHandler( + "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor + ) + logger = self.get_logger(handler) + + # Send a bunch of useful messages + for i in range(0, 20): + logger.warning("warn %s" % (i,)) + + # Allow the reconnection + client, server = connect_logging_client(self.reactor, 0) + client.transport.flush() + + # The first five and last five warnings made it through, the debugs and + # infos were elided + logs = server.data.decode("utf8").splitlines() + self.assertEqual( + ["warn %s" % (i,) for i in range(5)] + + ["warn %s" % (i,) for i in range(15, 20)], + logs, + ) + + def test_cancel_connection(self): + """ + Gracefully handle the connection being cancelled. + """ + handler = RemoteHandler( + "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor + ) + logger = self.get_logger(handler) + + # Send a message. + logger.info("Hello there, %s!", "wally") + + # Do not accept the connection and shutdown. This causes the pending + # connection to be cancelled (and should not raise any exceptions). + handler.close() diff --git a/tests/logging/test_structured.py b/tests/logging/test_structured.py deleted file mode 100644
index d36f5f426c..0000000000 --- a/tests/logging/test_structured.py +++ /dev/null
@@ -1,214 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -import os.path -import shutil -import sys -import textwrap - -from twisted.logger import Logger, eventAsText, eventsFromJSONLogFile - -from synapse.config.logger import setup_logging -from synapse.logging._structured import setup_structured_logging -from synapse.logging.context import LoggingContext - -from tests.unittest import DEBUG, HomeserverTestCase - - -class FakeBeginner: - def beginLoggingTo(self, observers, **kwargs): - self.observers = observers - - -class StructuredLoggingTestBase: - """ - Test base that registers a cleanup handler to reset the stdlib log handler - to 'unset'. - """ - - def prepare(self, reactor, clock, hs): - def _cleanup(): - logging.getLogger("synapse").setLevel(logging.NOTSET) - - self.addCleanup(_cleanup) - - -class StructuredLoggingTestCase(StructuredLoggingTestBase, HomeserverTestCase): - """ - Tests for Synapse's structured logging support. - """ - - def test_output_to_json_round_trip(self): - """ - Synapse logs can be outputted to JSON and then read back again. - """ - temp_dir = self.mktemp() - os.mkdir(temp_dir) - self.addCleanup(shutil.rmtree, temp_dir) - - json_log_file = os.path.abspath(os.path.join(temp_dir, "out.json")) - - log_config = { - "drains": {"jsonfile": {"type": "file_json", "location": json_log_file}} - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - # Make a logger and send an event - logger = Logger( - namespace="tests.logging.test_structured", observer=beginner.observers[0] - ) - logger.info("Hello there, {name}!", name="wally") - - # Read the log file and check it has the event we sent - with open(json_log_file, "r") as f: - logged_events = list(eventsFromJSONLogFile(f)) - self.assertEqual(len(logged_events), 1) - - # The event pulled from the file should render fine - self.assertEqual( - eventAsText(logged_events[0], includeTimestamp=False), - "[tests.logging.test_structured#info] Hello there, wally!", - ) - - def test_output_to_text(self): - """ - Synapse logs can be outputted to text. - """ - temp_dir = self.mktemp() - os.mkdir(temp_dir) - self.addCleanup(shutil.rmtree, temp_dir) - - log_file = os.path.abspath(os.path.join(temp_dir, "out.log")) - - log_config = {"drains": {"file": {"type": "file", "location": log_file}}} - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - # Make a logger and send an event - logger = Logger( - namespace="tests.logging.test_structured", observer=beginner.observers[0] - ) - logger.info("Hello there, {name}!", name="wally") - - # Read the log file and check it has the event we sent - with open(log_file, "r") as f: - logged_events = f.read().strip().split("\n") - self.assertEqual(len(logged_events), 1) - - # The event pulled from the file should render fine - self.assertTrue( - logged_events[0].endswith( - " - tests.logging.test_structured - INFO - None - Hello there, wally!" - ) - ) - - def test_collects_logcontext(self): - """ - Test that log outputs have the attached logging context. - """ - log_config = {"drains": {}} - - # Begin the logger with our config - beginner = FakeBeginner() - publisher = setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - logs = [] - - publisher.addObserver(logs.append) - - # Make a logger and send an event - logger = Logger( - namespace="tests.logging.test_structured", observer=beginner.observers[0] - ) - - with LoggingContext("testcontext", request="somereq"): - logger.info("Hello there, {name}!", name="steve") - - self.assertEqual(len(logs), 1) - self.assertEqual(logs[0]["request"], "somereq") - - -class StructuredLoggingConfigurationFileTestCase( - StructuredLoggingTestBase, HomeserverTestCase -): - def make_homeserver(self, reactor, clock): - - tempdir = self.mktemp() - os.mkdir(tempdir) - log_config_file = os.path.abspath(os.path.join(tempdir, "log.config.yaml")) - self.homeserver_log = os.path.abspath(os.path.join(tempdir, "homeserver.log")) - - config = self.default_config() - config["log_config"] = log_config_file - - with open(log_config_file, "w") as f: - f.write( - textwrap.dedent( - """\ - structured: true - - drains: - file: - type: file_json - location: %s - """ - % (self.homeserver_log,) - ) - ) - - self.addCleanup(self._sys_cleanup) - - return self.setup_test_homeserver(config=config) - - def _sys_cleanup(self): - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ - - # Do not remove! We need the logging system to be set other than WARNING. - @DEBUG - def test_log_output(self): - """ - When a structured logging config is given, Synapse will use it. - """ - beginner = FakeBeginner() - publisher = setup_logging(self.hs, self.hs.config, logBeginner=beginner) - - # Make a logger and send an event - logger = Logger(namespace="tests.logging.test_structured", observer=publisher) - - with LoggingContext("testcontext", request="somereq"): - logger.info("Hello there, {name}!", name="steve") - - with open(self.homeserver_log, "r") as f: - logged_events = [ - eventAsText(x, includeTimestamp=False) for x in eventsFromJSONLogFile(f) - ] - - logs = "\n".join(logged_events) - self.assertTrue("***** STARTING SERVER *****" in logs) - self.assertTrue("Hello there, steve!" in logs) diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index fd128b88e0..73f469b802 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py
@@ -14,57 +14,33 @@ # limitations under the License. import json -from collections import Counter +import logging +from io import StringIO -from twisted.logger import Logger +from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter -from synapse.logging._structured import setup_structured_logging +from tests.logging import LoggerCleanupMixin +from tests.unittest import TestCase -from tests.server import connect_client -from tests.unittest import HomeserverTestCase -from .test_structured import FakeBeginner, StructuredLoggingTestBase - - -class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase): - def test_log_output(self): +class TerseJsonTestCase(LoggerCleanupMixin, TestCase): + def test_terse_json_output(self): """ - The Terse JSON outputter delivers simplified structured logs over TCP. + The Terse JSON formatter converts log messages to JSON. """ - log_config = { - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - } - } - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - logger = Logger( - namespace="tests.logging.test_terse_json", observer=beginner.observers[0] - ) - logger.info("Hello there, {name}!", name="wally") - - # Trigger the connection - self.pump() + output = StringIO() - _, server = connect_client(self.reactor, 0) + handler = logging.StreamHandler(output) + handler.setFormatter(TerseJsonFormatter()) + logger = self.get_logger(handler) - # Trigger data being sent - self.pump() + logger.info("Hello there, %s!", "wally") - # One log message, with a single trailing newline - logs = server.data.decode("utf8").splitlines() + # One log message, with a single trailing newline. + data = output.getvalue() + logs = data.splitlines() self.assertEqual(len(logs), 1) - self.assertEqual(server.data.count(b"\n"), 1) - + self.assertEqual(data.count("\n"), 1) log = json.loads(logs[0]) # The terse logger should give us these keys. @@ -72,163 +48,74 @@ class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase): "log", "time", "level", - "log_namespace", - "request", - "scope", - "server_name", - "name", + "namespace", ] self.assertCountEqual(log.keys(), expected_log_keys) + self.assertEqual(log["log"], "Hello there, wally!") - # It contains the data we expect. - self.assertEqual(log["name"], "wally") - - def test_log_backpressure_debug(self): + def test_extra_data(self): """ - When backpressure is hit, DEBUG logs will be shed. + Additional information can be included in the structured logging. """ - log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - "maximum_buffer": 10, - } - }, - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, - self.hs.config, - log_config, - logBeginner=beginner, - redirect_stdlib_logging=False, - ) - - logger = Logger( - namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] - ) + output = StringIO() - # Send some debug messages - for i in range(0, 3): - logger.debug("debug %s" % (i,)) + handler = logging.StreamHandler(output) + handler.setFormatter(TerseJsonFormatter()) + logger = self.get_logger(handler) - # Send a bunch of useful messages - for i in range(0, 7): - logger.info("test message %s" % (i,)) - - # The last debug message pushes it past the maximum buffer - logger.debug("too much debug") - - # Allow the reconnection - _, server = connect_client(self.reactor, 0) - self.pump() - - # Only the 7 infos made it through, the debugs were elided - logs = server.data.splitlines() - self.assertEqual(len(logs), 7) - - def test_log_backpressure_info(self): - """ - When backpressure is hit, DEBUG and INFO logs will be shed. - """ - log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - "maximum_buffer": 10, - } - }, - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, - self.hs.config, - log_config, - logBeginner=beginner, - redirect_stdlib_logging=False, - ) - - logger = Logger( - namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] + logger.info( + "Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True} ) - # Send some debug messages - for i in range(0, 3): - logger.debug("debug %s" % (i,)) - - # Send a bunch of useful messages - for i in range(0, 10): - logger.warn("test warn %s" % (i,)) - - # Send a bunch of info messages - for i in range(0, 3): - logger.info("test message %s" % (i,)) - - # The last debug message pushes it past the maximum buffer - logger.debug("too much debug") - - # Allow the reconnection - client, server = connect_client(self.reactor, 0) - self.pump() + # One log message, with a single trailing newline. + data = output.getvalue() + logs = data.splitlines() + self.assertEqual(len(logs), 1) + self.assertEqual(data.count("\n"), 1) + log = json.loads(logs[0]) - # The 10 warnings made it through, the debugs and infos were elided - logs = list(map(json.loads, server.data.decode("utf8").splitlines())) - self.assertEqual(len(logs), 10) + # The terse logger should give us these keys. + expected_log_keys = [ + "log", + "time", + "level", + "namespace", + # The additional keys given via extra. + "foo", + "int", + "bool", + ] + self.assertCountEqual(log.keys(), expected_log_keys) - self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10}) + # Check the values of the extra fields. + self.assertEqual(log["foo"], "bar") + self.assertEqual(log["int"], 3) + self.assertIs(log["bool"], True) - def test_log_backpressure_cut_middle(self): + def test_json_output(self): """ - When backpressure is hit, and no more DEBUG and INFOs cannot be culled, - it will cut the middle messages out. + The Terse JSON formatter converts log messages to JSON. """ - log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - "maximum_buffer": 10, - } - }, - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, - self.hs.config, - log_config, - logBeginner=beginner, - redirect_stdlib_logging=False, - ) + output = StringIO() - logger = Logger( - namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] - ) + handler = logging.StreamHandler(output) + handler.setFormatter(JsonFormatter()) + logger = self.get_logger(handler) - # Send a bunch of useful messages - for i in range(0, 20): - logger.warn("test warn", num=i) + logger.info("Hello there, %s!", "wally") - # Allow the reconnection - client, server = connect_client(self.reactor, 0) - self.pump() + # One log message, with a single trailing newline. + data = output.getvalue() + logs = data.splitlines() + self.assertEqual(len(logs), 1) + self.assertEqual(data.count("\n"), 1) + log = json.loads(logs[0]) - # The first five and last five warnings made it through, the debugs and - # infos were elided - logs = list(map(json.loads, server.data.decode("utf8").splitlines())) - self.assertEqual(len(logs), 10) - self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10}) - self.assertEqual([0, 1, 2, 3, 4, 15, 16, 17, 18, 19], [x["num"] for x in logs]) + # The terse logger should give us these keys. + expected_log_keys = [ + "log", + "level", + "namespace", + ] + self.assertCountEqual(log.keys(), expected_log_keys) + self.assertEqual(log["log"], "Hello there, wally!") diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 55545d9341..bcdcafa5a9 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py
@@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(self.access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.pusher = self.get_success( self.hs.get_pusherpool().add_pusher( @@ -131,6 +131,35 @@ class EmailPusherTests(HomeserverTestCase): # We should get emailed about that message self._check_for_mail() + def test_invite_sends_email(self): + # Create a room and invite the user to it + room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token) + self.helper.invite( + room=room, + src=self.others[0].id, + tok=self.others[0].token, + targ=self.user_id, + ) + + # We should get emailed about the invite + self._check_for_mail() + + def test_invite_to_empty_room_sends_email(self): + # Create a room and invite the user to it + room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token) + self.helper.invite( + room=room, + src=self.others[0].id, + tok=self.others[0].token, + targ=self.user_id, + ) + + # Then have the original user leave + self.helper.leave(room, self.others[0].id, tok=self.others[0].token) + + # We should get emailed about the invite + self._check_for_mail() + def test_multiple_members_email(self): # We want to test multiple notifications, so we pause processing of push # while we send messages. diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index b567868b02..8571924b29 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py
@@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 093e2faac7..5c633ac6df 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py
@@ -16,7 +16,6 @@ import logging from typing import Any, Callable, List, Optional, Tuple import attr -import hiredis from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.protocol import Protocol @@ -39,12 +38,22 @@ from synapse.util import Clock from tests import unittest from tests.server import FakeTransport, render +try: + import hiredis +except ImportError: + hiredis = None + logger = logging.getLogger(__name__) class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + # hiredis is an optional dependency so we don't want to require it for running + # the tests. + if not hiredis: + skip = "Requires hiredis" + servlets = [ streams.register_servlets, ] @@ -269,7 +278,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): homeserver_to_use=GenericWorkerServer, config=config, reactor=self.reactor, - **kwargs + **kwargs, ) # If the instance is in the `instance_map` config then workers may try diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index c9998e88e6..bad0df08cf 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py
@@ -449,7 +449,7 @@ class EventsStreamTestCase(BaseStreamTestCase): sender=sender, type="test_event", content={"body": body}, - **kwargs + **kwargs, ) ) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py new file mode 100644
index 0000000000..77c261dbf7 --- /dev/null +++ b/tests/replication/test_multi_media_repo.py
@@ -0,0 +1,277 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from binascii import unhexlify +from typing import Tuple + +from twisted.internet.protocol import Factory +from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.web.http import HTTPChannel +from twisted.web.server import Request + +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.server import HomeServer + +from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import FakeChannel, FakeTransport + +logger = logging.getLogger(__name__) + +test_server_connection_factory = None + + +class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): + """Checks running multiple media repos work correctly. + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + + self.reactor.lookups["example.com"] = "127.0.0.2" + + def default_config(self): + conf = super().default_config() + conf["federation_custom_ca_list"] = [get_test_ca_cert_file()] + return conf + + def _get_media_req( + self, hs: HomeServer, target: str, media_id: str + ) -> Tuple[FakeChannel, Request]: + """Request some remote media from the given HS by calling the download + API. + + This then triggers an outbound request from the HS to the target. + + Returns: + The channel for the *client* request and the *outbound* request for + the media which the caller should respond to. + """ + + request, channel = self.make_request( + "GET", + "/{}/{}".format(target, media_id), + shorthand=False, + access_token=self.access_token, + ) + request.render(hs.get_media_repository_resource().children[b"download"]) + self.pump() + + clients = self.reactor.tcpClients + self.assertGreaterEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop() + + # build the test server + server_tls_protocol = _build_test_server(get_connection_factory()) + + # now, tell the client protocol factory to build the client protocol (it will be a + # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an + # HTTP11ClientProtocol) and wire the output of said protocol up to the server via + # a FakeTransport. + # + # Normally this would be done by the TCP socket code in Twisted, but we are + # stubbing that out here. + client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection( + FakeTransport(server_tls_protocol, self.reactor, client_protocol) + ) + + # tell the server tls protocol to send its stuff back to the client, too + server_tls_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, server_tls_protocol) + ) + + # fish the test server back out of the server-side TLS protocol. + http_server = server_tls_protocol.wrappedProtocol + + # give the reactor a pump to get the TLS juices flowing. + self.reactor.pump((0.1,)) + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + + self.assertEqual(request.method, b"GET") + self.assertEqual( + request.path, + "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"), + ) + self.assertEqual( + request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] + ) + + return channel, request + + def test_basic(self): + """Test basic fetching of remote media from a single worker. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + + channel, request = self._get_media_req(hs1, "example.com:443", "ABC123") + + request.setResponseCode(200) + request.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request.write(b"Hello!") + request.finish() + + self.pump(0.1) + + self.assertEqual(channel.code, 200) + self.assertEqual(channel.result["body"], b"Hello!") + + def test_download_simple_file_race(self): + """Test that fetching remote media from two different processes at the + same time works. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + hs2 = self.make_worker_hs("synapse.app.generic_worker") + + start_count = self._count_remote_media() + + # Make two requests without responding to the outbound media requests. + channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123") + channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123") + + # Respond to the first outbound media request and check that the client + # request is successful + request1.setResponseCode(200) + request1.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request1.write(b"Hello!") + request1.finish() + + self.pump(0.1) + + self.assertEqual(channel1.code, 200, channel1.result["body"]) + self.assertEqual(channel1.result["body"], b"Hello!") + + # Now respond to the second with the same content. + request2.setResponseCode(200) + request2.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request2.write(b"Hello!") + request2.finish() + + self.pump(0.1) + + self.assertEqual(channel2.code, 200, channel2.result["body"]) + self.assertEqual(channel2.result["body"], b"Hello!") + + # We expect only one new file to have been persisted. + self.assertEqual(start_count + 1, self._count_remote_media()) + + def test_download_image_race(self): + """Test that fetching remote *images* from two different processes at + the same time works. + + This checks that races generating thumbnails are handled correctly. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + hs2 = self.make_worker_hs("synapse.app.generic_worker") + + start_count = self._count_remote_thumbnails() + + channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1") + channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1") + + png_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + request1.setResponseCode(200) + request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"]) + request1.write(png_data) + request1.finish() + + self.pump(0.1) + + self.assertEqual(channel1.code, 200, channel1.result["body"]) + self.assertEqual(channel1.result["body"], png_data) + + request2.setResponseCode(200) + request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"]) + request2.write(png_data) + request2.finish() + + self.pump(0.1) + + self.assertEqual(channel2.code, 200, channel2.result["body"]) + self.assertEqual(channel2.result["body"], png_data) + + # We expect only three new thumbnails to have been persisted. + self.assertEqual(start_count + 3, self._count_remote_thumbnails()) + + def _count_remote_media(self) -> int: + """Count the number of files in our remote media directory. + """ + path = os.path.join( + self.hs.get_media_repository().primary_base_path, "remote_content" + ) + return sum(len(files) for _, _, files in os.walk(path)) + + def _count_remote_thumbnails(self) -> int: + """Count the number of files in our remote thumbnails directory. + """ + path = os.path.join( + self.hs.get_media_repository().primary_base_path, "remote_thumbnail" + ) + return sum(len(files) for _, _, files in os.walk(path)) + + +def get_connection_factory(): + # this needs to happen once, but not until we are ready to run the first test + global test_server_connection_factory + if test_server_connection_factory is None: + test_server_connection_factory = TestServerTLSConnectionFactory( + sanlist=[b"DNS:example.com"] + ) + return test_server_connection_factory + + +def _build_test_server(connection_creator): + """Construct a test server + + This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol + + Args: + connection_creator (IOpenSSLServerConnectionCreator): thing to build + SSL connections + sanlist (list[bytes]): list of the SAN entries for the cert returned + by the server + + Returns: + TLSMemoryBIOProtocol + """ + server_factory = Factory.forProtocol(HTTPChannel) + # Request.finish expects the factory to have a 'log' method. + server_factory.log = _log_request + + server_tls_factory = TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=server_factory + ) + + return server_tls_factory.buildProtocol(None) + + +def _log_request(request): + """Implements Factory.log, which is expected by Request.finish""" + logger.info("Completed request %s", request) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 2bdc6edbb1..67c27a089f 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): user_dict = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_dict["token_id"] + token_id = user_dict.token_id self.get_success( self.hs.get_pusherpool().add_pusher( diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index 92c9058887..d89eb90cfe 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py
@@ -393,6 +393,22 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) + def test_user_has_no_devices(self): + """ + Tests that a normal lookup for devices is successfully + if user has no devices + """ + + # Get devices + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["devices"])) + def test_get_devices(self): """ Tests that a normal lookup for devices is successfully @@ -409,6 +425,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(number_devices, channel.json_body["total"]) self.assertEqual(number_devices, len(channel.json_body["devices"])) self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"]) # Check that all fields are available diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index bf79086f78..303622217f 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py
@@ -70,6 +70,16 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/event_reports" + def test_no_auth(self): + """ + Try to get an event report without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + def test_requester_is_no_admin(self): """ If the user is not a server admin, an error 403 is returned. @@ -266,7 +276,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): def test_limit_is_negative(self): """ - Testing that a negative list parameter returns a 400 + Testing that a negative limit parameter returns a 400 """ request, channel = self.make_request( @@ -360,7 +370,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) def _check_fields(self, content): - """Checks that all attributes are present in a event report + """Checks that all attributes are present in an event report """ for c in content: self.assertIn("id", c) @@ -368,15 +378,175 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertIn("room_id", c) self.assertIn("event_id", c) self.assertIn("user_id", c) - self.assertIn("reason", c) - self.assertIn("content", c) self.assertIn("sender", c) - self.assertIn("room_alias", c) - self.assertIn("event_json", c) - self.assertIn("score", c["content"]) - self.assertIn("reason", c["content"]) - self.assertIn("auth_events", c["event_json"]) - self.assertIn("type", c["event_json"]) - self.assertIn("room_id", c["event_json"]) - self.assertIn("sender", c["event_json"]) - self.assertIn("content", c["event_json"]) + self.assertIn("canonical_alias", c) + self.assertIn("name", c) + self.assertIn("score", c) + self.assertIn("reason", c) + + +class EventReportDetailTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + report_event.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id1 = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok, is_public=True + ) + self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok) + + self._create_event_and_report( + room_id=self.room_id1, user_tok=self.other_user_tok, + ) + + # first created event report gets `id`=2 + self.url = "/_synapse/admin/v1/event_reports/2" + + def test_no_auth(self): + """ + Try to get event report without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + + request, channel = self.make_request( + "GET", self.url, access_token=self.other_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_default_success(self): + """ + Testing get a reported event + """ + + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self._check_fields(channel.json_body) + + def test_invalid_report_id(self): + """ + Testing that an invalid `report_id` returns a 400. + """ + + # `report_id` is negative + request, channel = self.make_request( + "GET", + "/_synapse/admin/v1/event_reports/-123", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is a non-numerical string + request, channel = self.make_request( + "GET", + "/_synapse/admin/v1/event_reports/abcdef", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is undefined + request, channel = self.make_request( + "GET", + "/_synapse/admin/v1/event_reports/", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + def test_report_id_not_found(self): + """ + Testing that a not existing `report_id` returns a 404. + """ + + request, channel = self.make_request( + "GET", + "/_synapse/admin/v1/event_reports/123", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual("Event report not found", channel.json_body["error"]) + + def _create_event_and_report(self, room_id, user_tok): + """Create and report events + """ + resp = self.helper.send(room_id, tok=user_tok) + event_id = resp["event_id"] + + request, channel = self.make_request( + "POST", + "rooms/%s/report/%s" % (room_id, event_id), + json.dumps({"score": -100, "reason": "this makes me sad"}), + access_token=user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def _check_fields(self, content): + """Checks that all attributes are present in a event report + """ + self.assertIn("id", content) + self.assertIn("received_ts", content) + self.assertIn("room_id", content) + self.assertIn("event_id", content) + self.assertIn("user_id", content) + self.assertIn("sender", content) + self.assertIn("canonical_alias", content) + self.assertIn("name", content) + self.assertIn("event_json", content) + self.assertIn("score", content) + self.assertIn("reason", content) + self.assertIn("auth_events", content["event_json"]) + self.assertIn("type", content["event_json"]) + self.assertIn("room_id", content["event_json"]) + self.assertIn("sender", content["event_json"]) + self.assertIn("content", content["event_json"]) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py new file mode 100644
index 0000000000..721fa1ed51 --- /dev/null +++ b/tests/rest/admin/test_media.py
@@ -0,0 +1,568 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from binascii import unhexlify + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login, profile, room +from synapse.rest.media.v1.filepath import MediaFilePaths + +from tests import unittest + + +class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_media_repo, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_device_handler() + self.media_repo = hs.get_media_repository_resource() + self.server_name = hs.hostname + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.filepaths = MediaFilePaths(hs.config.media_store_path) + + def test_no_auth(self): + """ + Try to delete media without authentication. + """ + url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") + + request, channel = self.make_request("DELETE", url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") + + request, channel = self.make_request( + "DELETE", url, access_token=self.other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_media_does_not_exist(self): + """ + Tests that a lookup for a media that does not exist returns a 404 + """ + url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") + + request, channel = self.make_request( + "DELETE", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_media_is_not_local(self): + """ + Tests that a lookup for a media that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345") + + request, channel = self.make_request( + "DELETE", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only delete local media", channel.json_body["error"]) + + def test_delete_media(self): + """ + Tests that delete a media is successfully + """ + + download_resource = self.media_repo.children[b"download"] + upload_resource = self.media_repo.children[b"upload"] + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + response = self.helper.upload_media( + upload_resource, image_data, tok=self.admin_user_tok, expect_code=200 + ) + # Extract media ID from the response + server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' + server_name, media_id = server_and_media_id.split("/") + + self.assertEqual(server_name, self.server_name) + + # Attempt to access media + request, channel = self.make_request( + "GET", + server_and_media_id, + shorthand=False, + access_token=self.admin_user_tok, + ) + request.render(download_resource) + self.pump(1.0) + + # Should be successful + self.assertEqual( + 200, + channel.code, + msg=( + "Expected to receive a 200 on accessing media: %s" % server_and_media_id + ), + ) + + # Test if the file exists + local_path = self.filepaths.local_media_filepath(media_id) + self.assertTrue(os.path.exists(local_path)) + + url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id) + + # Delete media + request, channel = self.make_request( + "DELETE", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + media_id, channel.json_body["deleted_media"][0], + ) + + # Attempt to access media + request, channel = self.make_request( + "GET", + server_and_media_id, + shorthand=False, + access_token=self.admin_user_tok, + ) + request.render(download_resource) + self.pump(1.0) + self.assertEqual( + 404, + channel.code, + msg=( + "Expected to receive a 404 on accessing deleted media: %s" + % server_and_media_id + ), + ) + + # Test if the file is deleted + self.assertFalse(os.path.exists(local_path)) + + +class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_media_repo, + login.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_device_handler() + self.media_repo = hs.get_media_repository_resource() + self.server_name = hs.hostname + self.clock = hs.clock + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.filepaths = MediaFilePaths(hs.config.media_store_path) + self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name + + def test_no_auth(self): + """ + Try to delete media without authentication. + """ + + request, channel = self.make_request("POST", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + request, channel = self.make_request( + "POST", self.url, access_token=self.other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_media_is_not_local(self): + """ + Tests that a lookup for media that is not local returns a 400 + """ + url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain" + + request, channel = self.make_request( + "POST", url + "?before_ts=1234", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only delete local media", channel.json_body["error"]) + + def test_missing_parameter(self): + """ + If the parameter `before_ts` is missing, an error is returned. + """ + request, channel = self.make_request( + "POST", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Missing integer query parameter b'before_ts'", channel.json_body["error"] + ) + + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. + """ + request, channel = self.make_request( + "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Query parameter before_ts must be a string representing a positive integer.", + channel.json_body["error"], + ) + + request, channel = self.make_request( + "POST", + self.url + "?before_ts=1234&size_gt=-1234", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Query parameter size_gt must be a string representing a positive integer.", + channel.json_body["error"], + ) + + request, channel = self.make_request( + "POST", + self.url + "?before_ts=1234&keep_profiles=not_bool", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual( + "Boolean query parameter b'keep_profiles' must be one of ['true', 'false']", + channel.json_body["error"], + ) + + def test_delete_media_never_accessed(self): + """ + Tests that media deleted if it is older than `before_ts` and never accessed + `last_access_ts` is `NULL` and `created_ts` < `before_ts` + """ + + # upload and do not access + server_and_media_id = self._create_media() + self.pump(1.0) + + # test that the file exists + media_id = server_and_media_id.split("/")[1] + local_path = self.filepaths.local_media_filepath(media_id) + self.assertTrue(os.path.exists(local_path)) + + # timestamp after upload/create + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + media_id, channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def test_keep_media_by_date(self): + """ + Tests that media is not deleted if it is newer than `before_ts` + """ + + # timestamp before upload + now_ms = self.clock.time_msec() + server_and_media_id = self._create_media() + + self._access_media(server_and_media_id) + + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + self._access_media(server_and_media_id) + + # timestamp after upload + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def test_keep_media_by_size(self): + """ + Tests that media is not deleted if its size is smaller than or equal + to `size_gt` + """ + server_and_media_id = self._create_media() + + self._access_media(server_and_media_id) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&size_gt=67", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + self._access_media(server_and_media_id) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&size_gt=66", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def test_keep_media_by_user_avatar(self): + """ + Tests that we do not delete media if is used as a user avatar + Tests parameter `keep_profiles` + """ + server_and_media_id = self._create_media() + + self._access_media(server_and_media_id) + + # set media as avatar + request, channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.admin_user,), + content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + self._access_media(server_and_media_id) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def test_keep_media_by_room_avatar(self): + """ + Tests that we do not delete media if it is used as a room avatar + Tests parameter `keep_profiles` + """ + server_and_media_id = self._create_media() + + self._access_media(server_and_media_id) + + # set media as room avatar + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + request, channel = self.make_request( + "PUT", + "/rooms/%s/state/m.room.avatar" % (room_id,), + content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + self._access_media(server_and_media_id) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def _create_media(self): + """ + Create a media and return media_id and server_and_media_id + """ + upload_resource = self.media_repo.children[b"upload"] + # file size is 67 Byte + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + response = self.helper.upload_media( + upload_resource, image_data, tok=self.admin_user_tok, expect_code=200 + ) + # Extract media ID from the response + server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' + server_name = server_and_media_id.split("/")[0] + + # Check that new media is a local and not remote + self.assertEqual(server_name, self.server_name) + + return server_and_media_id + + def _access_media(self, server_and_media_id, expect_success=True): + """ + Try to access a media and check the result + """ + download_resource = self.media_repo.children[b"download"] + + media_id = server_and_media_id.split("/")[1] + local_path = self.filepaths.local_media_filepath(media_id) + + request, channel = self.make_request( + "GET", + server_and_media_id, + shorthand=False, + access_token=self.admin_user_tok, + ) + request.render(download_resource) + self.pump(1.0) + + if expect_success: + self.assertEqual( + 200, + channel.code, + msg=( + "Expected to receive a 200 on accessing media: %s" + % server_and_media_id + ), + ) + # Test that the file exists + self.assertTrue(os.path.exists(local_path)) + else: + self.assertEqual( + 404, + channel.code, + msg=( + "Expected to receive a 404 on accessing deleted media: %s" + % (server_and_media_id) + ), + ) + # Test that the file is deleted + self.assertFalse(os.path.exists(local_path)) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 98d0623734..7df32e5093 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -17,6 +17,7 @@ import hashlib import hmac import json import urllib.parse +from binascii import unhexlify from mock import Mock @@ -1016,7 +1017,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, - sync.register_servlets, room.register_servlets, ] @@ -1082,6 +1082,21 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) + def test_no_memberships(self): + """ + Tests that a normal lookup for rooms is successfully + if user has no memberships + """ + # Get rooms + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["joined_rooms"])) + def test_get_rooms(self): """ Tests that a normal lookup for rooms is successfully @@ -1101,3 +1116,408 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) + + +class PushersRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.url = "/_synapse/admin/v1/users/%s/pushers" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to list pushers of an user without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user", "pass") + + request, channel = self.make_request( + "GET", self.url, access_token=other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = "/_synapse/admin/v1/users/@unknown_person:test/pushers" + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers" + + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + def test_get_pushers(self): + """ + Tests that a normal lookup for pushers is successfully + """ + + # Get pushers + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + # Register the pusher + other_user_token = self.login("user", "pass") + user_tuple = self.get_success( + self.store.get_user_by_access_token(other_user_token) + ) + token_id = user_tuple.token_id + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=self.other_user, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Get pushers + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + + for p in channel.json_body["pushers"]: + self.assertIn("pushkey", p) + self.assertIn("kind", p) + self.assertIn("app_id", p) + self.assertIn("app_display_name", p) + self.assertIn("device_display_name", p) + self.assertIn("profile_tag", p) + self.assertIn("lang", p) + self.assertIn("url", p["data"]) + + +class UserMediaRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.media_repo = hs.get_media_repository_resource() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.url = "/_synapse/admin/v1/users/%s/media" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to list media of an user without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user", "pass") + + request, channel = self.make_request( + "GET", self.url, access_token=other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = "/_synapse/admin/v1/users/@unknown_person:test/media" + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" + + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + def test_limit(self): + """ + Testing list of media with limit + """ + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + request, channel = self.make_request( + "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), 5) + self.assertEqual(channel.json_body["next_token"], 5) + self._check_fields(channel.json_body["media"]) + + def test_from(self): + """ + Testing list of media with a defined starting point (from) + """ + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + request, channel = self.make_request( + "GET", self.url + "?from=5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), 15) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["media"]) + + def test_limit_and_from(self): + """ + Testing list of media with a defined starting point and limit + """ + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + request, channel = self.make_request( + "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(channel.json_body["next_token"], 15) + self.assertEqual(len(channel.json_body["media"]), 10) + self._check_fields(channel.json_body["media"]) + + def test_limit_is_negative(self): + """ + Testing that a negative limit parameter returns a 400 + """ + + request, channel = self.make_request( + "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_from_is_negative(self): + """ + Testing that a negative from parameter returns a 400 + """ + + request, channel = self.make_request( + "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_next_token(self): + """ + Testing that `next_token` appears at the right place + """ + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + # `next_token` does not appear + # Number of results is the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), number_media) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), number_media) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), 19) + self.assertEqual(channel.json_body["next_token"], 19) + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + request, channel = self.make_request( + "GET", self.url + "?from=19", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def test_user_has_no_media(self): + """ + Tests that a normal lookup for media is successfully + if user has no media created + """ + + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["media"])) + + def test_get_media(self): + """ + Tests that a normal lookup for media is successfully + """ + + number_media = 5 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(number_media, channel.json_body["total"]) + self.assertEqual(number_media, len(channel.json_body["media"])) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["media"]) + + def _create_media(self, user_token, number_media): + """ + Create a number of media for a specific user + """ + upload_resource = self.media_repo.children[b"upload"] + for i in range(number_media): + # file size is 67 Byte + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + self.helper.upload_media( + upload_resource, image_data, tok=user_token, expect_code=200 + ) + + def _check_fields(self, content): + """Checks that all attributes are present in content + """ + for m in content: + self.assertIn("media_id", m) + self.assertIn("media_type", m) + self.assertIn("media_length", m) + self.assertIn("upload_name", m) + self.assertIn("created_ts", m) + self.assertIn("last_access_ts", m) + self.assertIn("quarantined_by", m) + self.assertIn("safe_from_quarantine", m) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 2fc3a60fc5..98c3887bbf 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py
@@ -55,6 +55,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.hs.config.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", ) self.hs.get_datastore().services_cache.append(appservice) diff --git a/tests/server.py b/tests/server.py
index 4d33b84097..3dd2cfc072 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -46,7 +46,7 @@ class FakeChannel: site = attr.ib(type=Site) _reactor = attr.ib() - result = attr.ib(default=attr.Factory(dict)) + result = attr.ib(type=dict, default=attr.Factory(dict)) _producer = None @property @@ -380,7 +380,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool._runWithConnection, func, *args, - **kwargs + **kwargs, ) def runInteraction(interaction, *args, **kwargs): @@ -390,7 +390,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool._runInteraction, interaction, *args, - **kwargs + **kwargs, ) pool.runWithConnection = runWithConnection @@ -571,12 +571,10 @@ def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol reactor factory: The connecting factory to build. """ - factory = reactor.tcpClients[client_id][2] + factory = reactor.tcpClients.pop(client_id)[2] client = factory.buildProtocol(None) server = AccumulatingProtocol() server.makeConnection(FakeTransport(client, reactor)) client.makeConnection(FakeTransport(server, reactor)) - reactor.tcpClients.pop(client_id) - return client, server diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 080761d1d2..5a1e5c4e66 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py
@@ -22,7 +22,7 @@ import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client.v1 import login, room from synapse.storage import prepare_database -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from tests.unittest import HomeserverTestCase @@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") - self.requester = Requester(self.user, None, False, False, None, None) + self.requester = create_requester(self.user) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] @@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") - self.requester = Requester(self.user, None, False, False, None, None) + self.requester = create_requester(self.user) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 755c70db31..e96ca1c8ca 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py
@@ -412,7 +412,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/admin/users/" + self.user_id, access_token=access_token, - **make_request_args + **make_request_args, ) request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza") diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 3957471f3f..7691f2d790 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py
@@ -14,7 +14,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, generate_latest -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from tests.unittest import HomeserverTestCase @@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): room_creator = self.hs.get_room_creation_handler() user = UserID("alice", "test") - requester = Requester(user, None, False, False, None, None) + requester = create_requester(user) # Real events, forward extremities events = [(3, 2), (6, 2), (4, 6)] diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 6b582771fe..c8c7a90e5d 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py
@@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase): self.store.get_user_by_access_token(self.tokens[1]) ) - self.assertDictContainsSubset( - {"name": self.user_id, "device_id": self.device_id}, result - ) - - self.assertTrue("token_id" in result) + self.assertEqual(result.user_id, self.user_id) + self.assertEqual(result.device_id, self.device_id) + self.assertIsNotNone(result.token_id) @defer.inlineCallbacks def test_user_delete_access_tokens(self): @@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase): user = yield defer.ensureDeferred( self.store.get_user_by_access_token(self.tokens[0]) ) - self.assertEqual(self.user_id, user["name"]) + self.assertEqual(self.user_id, user.user_id) # now delete the rest yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id)) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 12ccc1f53e..ff972daeaa 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py
@@ -19,7 +19,7 @@ from unittest.mock import Mock from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from tests import unittest from tests.test_utils import event_injection @@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Now let's create a room, which will insert a membership user = UserID("alice", "test") - requester = Requester(user, None, False, False, None, None) + requester = create_requester(user) self.get_success(self.room_creator.create_room(requester, {})) # Register the background update to run again. diff --git a/tests/test_federation.py b/tests/test_federation.py
index d39e792580..1ce4ea3a01 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py
@@ -20,7 +20,7 @@ from twisted.internet.defer import succeed from synapse.api.errors import FederationError from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from synapse.util import Clock from synapse.util.retryutils import NotRetryingDestination @@ -43,7 +43,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) user_id = UserID("us", "test") - our_user = Requester(user_id, None, False, False, None, None) + our_user = create_requester(user_id) room_creator = self.homeserver.get_room_creation_handler() self.room_id = self.get_success( room_creator.create_room( diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index a298cc0fd3..d232b72264 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py
@@ -17,8 +17,10 @@ """ Utilities for running the unit tests """ +import sys +import warnings from asyncio import Future -from typing import Any, Awaitable, TypeVar +from typing import Any, Awaitable, Callable, TypeVar TV = TypeVar("TV") @@ -48,3 +50,33 @@ def make_awaitable(result: Any) -> Awaitable[Any]: future = Future() # type: ignore future.set_result(result) return future + + +def setup_awaitable_errors() -> Callable[[], None]: + """ + Convert warnings from a non-awaited coroutines into errors. + """ + warnings.simplefilter("error", RuntimeWarning) + + # unraisablehook was added in Python 3.8. + if not hasattr(sys, "unraisablehook"): + return lambda: None + + # State shared between unraisablehook and check_for_unraisable_exceptions. + unraisable_exceptions = [] + orig_unraisablehook = sys.unraisablehook # type: ignore + + def unraisablehook(unraisable): + unraisable_exceptions.append(unraisable.exc_value) + + def cleanup(): + """ + A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions. + """ + sys.unraisablehook = orig_unraisablehook # type: ignore + if unraisable_exceptions: + raise unraisable_exceptions.pop() + + sys.unraisablehook = unraisablehook # type: ignore + + return cleanup diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index e93aa84405..c3c4a93e1f 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py
@@ -50,7 +50,7 @@ async def inject_member_event( sender=sender, state_key=target, content=content, - **kwargs + **kwargs, ) diff --git a/tests/unittest.py b/tests/unittest.py
index 040b126a27..08cf9b10c5 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -44,7 +44,7 @@ from synapse.logging.context import ( set_current_context, ) from synapse.server import HomeServer -from synapse.types import Requester, UserID, create_requester +from synapse.types import UserID, create_requester from synapse.util.ratelimitutils import FederationRateLimiter from tests.server import ( @@ -54,7 +54,7 @@ from tests.server import ( render, setup_test_homeserver, ) -from tests.test_utils import event_injection +from tests.test_utils import event_injection, setup_awaitable_errors from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -119,6 +119,10 @@ class TestCase(unittest.TestCase): logging.getLogger().setLevel(level) + # Trial messes with the warnings configuration, thus this has to be + # done in the context of an individual TestCase. + self.addCleanup(setup_awaitable_errors()) + return orig() @around(self) @@ -627,7 +631,7 @@ class HomeserverTestCase(TestCase): """ event_creator = self.hs.get_event_creation_handler() secrets = self.hs.get_secrets() - requester = Requester(user, None, False, False, None, None) + requester = create_requester(user) event, context = self.get_success( event_creator.create_event( diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 2ad08f541b..cf1e3203a4 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py
@@ -29,13 +29,46 @@ from synapse.logging.context import ( make_deferred_yieldable, ) from synapse.util.caches import descriptors -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, lru_cache from tests import unittest +from tests.test_utils import get_awaitable_result logger = logging.getLogger(__name__) +class LruCacheDecoratorTestCase(unittest.TestCase): + def test_base(self): + class Cls: + def __init__(self): + self.mock = mock.Mock() + + @lru_cache() + def fn(self, arg1, arg2): + return self.mock(arg1, arg2) + + obj = Cls() + obj.mock.return_value = "fish" + r = obj.fn(1, 2) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, 2) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = obj.fn(1, 3) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(1, 3) + obj.mock.reset_mock() + + # the two values should now be cached + r = obj.fn(1, 2) + self.assertEqual(r, "fish") + r = obj.fn(1, 3) + self.assertEqual(r, "chips") + obj.mock.assert_not_called() + + def run_on_reactor(): d = defer.Deferred() reactor.callLater(0, d.callback, 0) @@ -362,6 +395,31 @@ class DescriptorTestCase(unittest.TestCase): d = obj.fn(1) self.failureResultOf(d, SynapseError) + def test_invalidate_cascade(self): + """Invalidations should cascade up through cache contexts""" + + class Cls: + @cached(cache_context=True) + async def func1(self, key, cache_context): + return await self.func2(key, on_invalidate=cache_context.invalidate) + + @cached(cache_context=True) + async def func2(self, key, cache_context): + return self.func3(key, on_invalidate=cache_context.invalidate) + + @lru_cache(cache_context=True) + def func3(self, key, cache_context): + self.invalidate = cache_context.invalidate + return 42 + + obj = Cls() + + top_invalidate = mock.Mock() + r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate)) + self.assertEqual(r, 42) + obj.invalidate() + top_invalidate.assert_called_once() + class CacheDecoratorTestCase(unittest.HomeserverTestCase): """More tests for @cached