diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index cb6f29d670..0fd55f428a 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -29,6 +29,7 @@ from synapse.api.errors import (
MissingClientTokenError,
ResourceLimitError,
)
+from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import UserID
from tests import unittest
@@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
- user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
+ user_info = TokenLookupResult(
+ user_id=self.test_user, token_id=5, device_id="device"
+ )
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
@@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
- user_info = {"name": self.test_user, "token_id": "ditto"}
+ user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
@@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(
- {"name": "@baldrick:matrix.org", "device_id": "device"}
+ TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
)
)
@@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(macaroon.serialize())
)
- user = user_info["user"]
- self.assertEqual(UserID.from_string(user_id), user)
+ self.assertEqual(user_id, user_info.user_id)
# TODO: device_id should come from the macaroon, but currently comes
# from the db.
- self.assertEqual(user_info["device_id"], "device")
+ self.assertEqual(user_info.device_id, "device")
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
@@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(serialized)
)
- user = user_info["user"]
- is_guest = user_info["is_guest"]
- self.assertEqual(UserID.from_string(user_id), user)
- self.assertTrue(is_guest)
+ self.assertEqual(user_id, user_info.user_id)
+ self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id)
@defer.inlineCallbacks
@@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase):
if token != tok:
return defer.succeed(None)
return defer.succeed(
- {
- "name": USER_ID,
- "is_guest": False,
- "token_id": 1234,
- "device_id": "DEVICE",
- }
+ TokenLookupResult(
+ user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
+ )
)
self.store.get_user_by_access_token = get_user
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 1e1f30d790..fe504d0869 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -43,7 +43,7 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
appservice = ApplicationService(
- None, "example.com", id="foo", rate_limited=True,
+ None, "example.com", id="foo", rate_limited=True, sender="@as:example.com",
)
as_requester = create_requester("@user:example.com", app_service=appservice)
@@ -68,7 +68,7 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_via_can_requester_do_action(self):
appservice = ApplicationService(
- None, "example.com", id="foo", rate_limited=False,
+ None, "example.com", id="foo", rate_limited=False, sender="@as:example.com",
)
as_requester = create_requester("@user:example.com", app_service=appservice)
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 236b608d58..0bffeb1150 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -31,6 +31,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
def setUp(self):
self.service = ApplicationService(
id="unique_identifier",
+ sender="@as:test",
url="some_url",
token="some_token",
hostname="matrix.org", # only used by get_groups_for_user
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index ee4f3da31c..53763cd0f9 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -42,7 +42,6 @@ class AppServiceHandlerTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock()
self.handler = ApplicationServicesHandler(hs)
- @defer.inlineCallbacks
def test_notify_interested_services(self):
interested_service = self._mkservice(is_interested=True)
services = [
@@ -62,14 +61,12 @@ class AppServiceHandlerTestCase(unittest.TestCase):
defer.succeed((0, [event])),
defer.succeed((0, [])),
]
- yield defer.ensureDeferred(
- self.handler.notify_interested_services(RoomStreamToken(None, 0))
- )
+ self.handler.notify_interested_services(RoomStreamToken(None, 0))
+
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
interested_service, event
)
- @defer.inlineCallbacks
def test_query_user_exists_unknown_user(self):
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested=True)]
@@ -83,12 +80,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
defer.succeed((0, [event])),
defer.succeed((0, [])),
]
- yield defer.ensureDeferred(
- self.handler.notify_interested_services(RoomStreamToken(None, 0))
- )
+
+ self.handler.notify_interested_services(RoomStreamToken(None, 0))
+
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
- @defer.inlineCallbacks
def test_query_user_exists_known_user(self):
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested=True)]
@@ -102,9 +98,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
defer.succeed((0, [event])),
defer.succeed((0, [])),
]
- yield defer.ensureDeferred(
- self.handler.notify_interested_services(RoomStreamToken(None, 0))
- )
+
+ self.handler.notify_interested_services(RoomStreamToken(None, 0))
+
self.assertFalse(
self.mock_as_api.query_user.called,
"query_user called when it shouldn't have been.",
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 4512c51311..875aaec2c6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
# make sure that our device ID has changed
user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
- self.assertEqual(user_info["device_id"], retrieved_device_id)
+ self.assertEqual(user_info.device_id, retrieved_device_id)
# make sure the device has the display name that was set from the login
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 9f6f21a6e2..2e0fea04af 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.info = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token,)
)
- self.token_id = self.info["token_id"]
+ self.token_id = self.info.token_id
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index b6f436c016..0d51705849 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -394,7 +394,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
request = Mock(
- spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ spec=[
+ "args",
+ "getCookie",
+ "addCookie",
+ "requestHeaders",
+ "getClientIP",
+ "get_user_agent",
+ ]
)
code = "code"
@@ -414,9 +421,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
- request.requestHeaders = Mock(spec=["getRawHeaders"])
- request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
request.getClientIP.return_value = ip_address
+ request.get_user_agent.return_value = user_agent
self.get_success(self.handler.handle_oidc_callback(request))
@@ -621,7 +627,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
request = Mock(
- spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ spec=[
+ "args",
+ "getCookie",
+ "addCookie",
+ "requestHeaders",
+ "getClientIP",
+ "get_user_agent",
+ ]
)
state = "state"
@@ -637,9 +650,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args[b"code"] = [b"code"]
request.args[b"state"] = [state.encode("utf-8")]
- request.requestHeaders = Mock(spec=["getRawHeaders"])
- request.requestHeaders.getRawHeaders.return_value = [b"Browser"]
request.getClientIP.return_value = "10.0.0.1"
+ request.get_user_agent.return_value = "Browser"
self.get_success(self.handler.handle_oidc_callback(request))
diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py
index e69de29bb2..a58d51441c 100644
--- a/tests/logging/__init__.py
+++ b/tests/logging/__init__.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+
+class LoggerCleanupMixin:
+ def get_logger(self, handler):
+ """
+ Attach a handler to a logger and add clean-ups to remove revert this.
+ """
+ # Create a logger and add the handler to it.
+ logger = logging.getLogger(__name__)
+ logger.addHandler(handler)
+
+ # Ensure the logger actually logs something.
+ logger.setLevel(logging.INFO)
+
+ # Ensure the logger gets cleaned-up appropriately.
+ self.addCleanup(logger.removeHandler, handler)
+ self.addCleanup(logger.setLevel, logging.NOTSET)
+
+ return logger
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
new file mode 100644
index 0000000000..4bc27a1d7d
--- /dev/null
+++ b/tests/logging/test_remote_handler.py
@@ -0,0 +1,169 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.test.proto_helpers import AccumulatingProtocol
+
+from synapse.logging import RemoteHandler
+
+from tests.logging import LoggerCleanupMixin
+from tests.server import FakeTransport, get_clock
+from tests.unittest import TestCase
+
+
+def connect_logging_client(reactor, client_id):
+ # This is essentially tests.server.connect_client, but disabling autoflush on
+ # the client transport. This is necessary to avoid an infinite loop due to
+ # sending of data via the logging transport causing additional logs to be
+ # written.
+ factory = reactor.tcpClients.pop(client_id)[2]
+ client = factory.buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, reactor))
+ client.makeConnection(FakeTransport(server, reactor, autoflush=False))
+
+ return client, server
+
+
+class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
+ def setUp(self):
+ self.reactor, _ = get_clock()
+
+ def test_log_output(self):
+ """
+ The remote handler delivers logs over TCP.
+ """
+ handler = RemoteHandler("127.0.0.1", 9000, _reactor=self.reactor)
+ logger = self.get_logger(handler)
+
+ logger.info("Hello there, %s!", "wally")
+
+ # Trigger the connection
+ client, server = connect_logging_client(self.reactor, 0)
+
+ # Trigger data being sent
+ client.transport.flush()
+
+ # One log message, with a single trailing newline
+ logs = server.data.decode("utf8").splitlines()
+ self.assertEqual(len(logs), 1)
+ self.assertEqual(server.data.count(b"\n"), 1)
+
+ # Ensure the data passed through properly.
+ self.assertEqual(logs[0], "Hello there, wally!")
+
+ def test_log_backpressure_debug(self):
+ """
+ When backpressure is hit, DEBUG logs will be shed.
+ """
+ handler = RemoteHandler(
+ "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
+ )
+ logger = self.get_logger(handler)
+
+ # Send some debug messages
+ for i in range(0, 3):
+ logger.debug("debug %s" % (i,))
+
+ # Send a bunch of useful messages
+ for i in range(0, 7):
+ logger.info("info %s" % (i,))
+
+ # The last debug message pushes it past the maximum buffer
+ logger.debug("too much debug")
+
+ # Allow the reconnection
+ client, server = connect_logging_client(self.reactor, 0)
+ client.transport.flush()
+
+ # Only the 7 infos made it through, the debugs were elided
+ logs = server.data.splitlines()
+ self.assertEqual(len(logs), 7)
+ self.assertNotIn(b"debug", server.data)
+
+ def test_log_backpressure_info(self):
+ """
+ When backpressure is hit, DEBUG and INFO logs will be shed.
+ """
+ handler = RemoteHandler(
+ "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
+ )
+ logger = self.get_logger(handler)
+
+ # Send some debug messages
+ for i in range(0, 3):
+ logger.debug("debug %s" % (i,))
+
+ # Send a bunch of useful messages
+ for i in range(0, 10):
+ logger.warning("warn %s" % (i,))
+
+ # Send a bunch of info messages
+ for i in range(0, 3):
+ logger.info("info %s" % (i,))
+
+ # The last debug message pushes it past the maximum buffer
+ logger.debug("too much debug")
+
+ # Allow the reconnection
+ client, server = connect_logging_client(self.reactor, 0)
+ client.transport.flush()
+
+ # The 10 warnings made it through, the debugs and infos were elided
+ logs = server.data.splitlines()
+ self.assertEqual(len(logs), 10)
+ self.assertNotIn(b"debug", server.data)
+ self.assertNotIn(b"info", server.data)
+
+ def test_log_backpressure_cut_middle(self):
+ """
+ When backpressure is hit, and no more DEBUG and INFOs cannot be culled,
+ it will cut the middle messages out.
+ """
+ handler = RemoteHandler(
+ "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
+ )
+ logger = self.get_logger(handler)
+
+ # Send a bunch of useful messages
+ for i in range(0, 20):
+ logger.warning("warn %s" % (i,))
+
+ # Allow the reconnection
+ client, server = connect_logging_client(self.reactor, 0)
+ client.transport.flush()
+
+ # The first five and last five warnings made it through, the debugs and
+ # infos were elided
+ logs = server.data.decode("utf8").splitlines()
+ self.assertEqual(
+ ["warn %s" % (i,) for i in range(5)]
+ + ["warn %s" % (i,) for i in range(15, 20)],
+ logs,
+ )
+
+ def test_cancel_connection(self):
+ """
+ Gracefully handle the connection being cancelled.
+ """
+ handler = RemoteHandler(
+ "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor
+ )
+ logger = self.get_logger(handler)
+
+ # Send a message.
+ logger.info("Hello there, %s!", "wally")
+
+ # Do not accept the connection and shutdown. This causes the pending
+ # connection to be cancelled (and should not raise any exceptions).
+ handler.close()
diff --git a/tests/logging/test_structured.py b/tests/logging/test_structured.py
deleted file mode 100644
index d36f5f426c..0000000000
--- a/tests/logging/test_structured.py
+++ /dev/null
@@ -1,214 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import os
-import os.path
-import shutil
-import sys
-import textwrap
-
-from twisted.logger import Logger, eventAsText, eventsFromJSONLogFile
-
-from synapse.config.logger import setup_logging
-from synapse.logging._structured import setup_structured_logging
-from synapse.logging.context import LoggingContext
-
-from tests.unittest import DEBUG, HomeserverTestCase
-
-
-class FakeBeginner:
- def beginLoggingTo(self, observers, **kwargs):
- self.observers = observers
-
-
-class StructuredLoggingTestBase:
- """
- Test base that registers a cleanup handler to reset the stdlib log handler
- to 'unset'.
- """
-
- def prepare(self, reactor, clock, hs):
- def _cleanup():
- logging.getLogger("synapse").setLevel(logging.NOTSET)
-
- self.addCleanup(_cleanup)
-
-
-class StructuredLoggingTestCase(StructuredLoggingTestBase, HomeserverTestCase):
- """
- Tests for Synapse's structured logging support.
- """
-
- def test_output_to_json_round_trip(self):
- """
- Synapse logs can be outputted to JSON and then read back again.
- """
- temp_dir = self.mktemp()
- os.mkdir(temp_dir)
- self.addCleanup(shutil.rmtree, temp_dir)
-
- json_log_file = os.path.abspath(os.path.join(temp_dir, "out.json"))
-
- log_config = {
- "drains": {"jsonfile": {"type": "file_json", "location": json_log_file}}
- }
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs, self.hs.config, log_config, logBeginner=beginner
- )
-
- # Make a logger and send an event
- logger = Logger(
- namespace="tests.logging.test_structured", observer=beginner.observers[0]
- )
- logger.info("Hello there, {name}!", name="wally")
-
- # Read the log file and check it has the event we sent
- with open(json_log_file, "r") as f:
- logged_events = list(eventsFromJSONLogFile(f))
- self.assertEqual(len(logged_events), 1)
-
- # The event pulled from the file should render fine
- self.assertEqual(
- eventAsText(logged_events[0], includeTimestamp=False),
- "[tests.logging.test_structured#info] Hello there, wally!",
- )
-
- def test_output_to_text(self):
- """
- Synapse logs can be outputted to text.
- """
- temp_dir = self.mktemp()
- os.mkdir(temp_dir)
- self.addCleanup(shutil.rmtree, temp_dir)
-
- log_file = os.path.abspath(os.path.join(temp_dir, "out.log"))
-
- log_config = {"drains": {"file": {"type": "file", "location": log_file}}}
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs, self.hs.config, log_config, logBeginner=beginner
- )
-
- # Make a logger and send an event
- logger = Logger(
- namespace="tests.logging.test_structured", observer=beginner.observers[0]
- )
- logger.info("Hello there, {name}!", name="wally")
-
- # Read the log file and check it has the event we sent
- with open(log_file, "r") as f:
- logged_events = f.read().strip().split("\n")
- self.assertEqual(len(logged_events), 1)
-
- # The event pulled from the file should render fine
- self.assertTrue(
- logged_events[0].endswith(
- " - tests.logging.test_structured - INFO - None - Hello there, wally!"
- )
- )
-
- def test_collects_logcontext(self):
- """
- Test that log outputs have the attached logging context.
- """
- log_config = {"drains": {}}
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- publisher = setup_structured_logging(
- self.hs, self.hs.config, log_config, logBeginner=beginner
- )
-
- logs = []
-
- publisher.addObserver(logs.append)
-
- # Make a logger and send an event
- logger = Logger(
- namespace="tests.logging.test_structured", observer=beginner.observers[0]
- )
-
- with LoggingContext("testcontext", request="somereq"):
- logger.info("Hello there, {name}!", name="steve")
-
- self.assertEqual(len(logs), 1)
- self.assertEqual(logs[0]["request"], "somereq")
-
-
-class StructuredLoggingConfigurationFileTestCase(
- StructuredLoggingTestBase, HomeserverTestCase
-):
- def make_homeserver(self, reactor, clock):
-
- tempdir = self.mktemp()
- os.mkdir(tempdir)
- log_config_file = os.path.abspath(os.path.join(tempdir, "log.config.yaml"))
- self.homeserver_log = os.path.abspath(os.path.join(tempdir, "homeserver.log"))
-
- config = self.default_config()
- config["log_config"] = log_config_file
-
- with open(log_config_file, "w") as f:
- f.write(
- textwrap.dedent(
- """\
- structured: true
-
- drains:
- file:
- type: file_json
- location: %s
- """
- % (self.homeserver_log,)
- )
- )
-
- self.addCleanup(self._sys_cleanup)
-
- return self.setup_test_homeserver(config=config)
-
- def _sys_cleanup(self):
- sys.stdout = sys.__stdout__
- sys.stderr = sys.__stderr__
-
- # Do not remove! We need the logging system to be set other than WARNING.
- @DEBUG
- def test_log_output(self):
- """
- When a structured logging config is given, Synapse will use it.
- """
- beginner = FakeBeginner()
- publisher = setup_logging(self.hs, self.hs.config, logBeginner=beginner)
-
- # Make a logger and send an event
- logger = Logger(namespace="tests.logging.test_structured", observer=publisher)
-
- with LoggingContext("testcontext", request="somereq"):
- logger.info("Hello there, {name}!", name="steve")
-
- with open(self.homeserver_log, "r") as f:
- logged_events = [
- eventAsText(x, includeTimestamp=False) for x in eventsFromJSONLogFile(f)
- ]
-
- logs = "\n".join(logged_events)
- self.assertTrue("***** STARTING SERVER *****" in logs)
- self.assertTrue("Hello there, steve!" in logs)
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index fd128b88e0..73f469b802 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -14,57 +14,33 @@
# limitations under the License.
import json
-from collections import Counter
+import logging
+from io import StringIO
-from twisted.logger import Logger
+from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
-from synapse.logging._structured import setup_structured_logging
+from tests.logging import LoggerCleanupMixin
+from tests.unittest import TestCase
-from tests.server import connect_client
-from tests.unittest import HomeserverTestCase
-from .test_structured import FakeBeginner, StructuredLoggingTestBase
-
-
-class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
- def test_log_output(self):
+class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
+ def test_terse_json_output(self):
"""
- The Terse JSON outputter delivers simplified structured logs over TCP.
+ The Terse JSON formatter converts log messages to JSON.
"""
- log_config = {
- "drains": {
- "tersejson": {
- "type": "network_json_terse",
- "host": "127.0.0.1",
- "port": 8000,
- }
- }
- }
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs, self.hs.config, log_config, logBeginner=beginner
- )
-
- logger = Logger(
- namespace="tests.logging.test_terse_json", observer=beginner.observers[0]
- )
- logger.info("Hello there, {name}!", name="wally")
-
- # Trigger the connection
- self.pump()
+ output = StringIO()
- _, server = connect_client(self.reactor, 0)
+ handler = logging.StreamHandler(output)
+ handler.setFormatter(TerseJsonFormatter())
+ logger = self.get_logger(handler)
- # Trigger data being sent
- self.pump()
+ logger.info("Hello there, %s!", "wally")
- # One log message, with a single trailing newline
- logs = server.data.decode("utf8").splitlines()
+ # One log message, with a single trailing newline.
+ data = output.getvalue()
+ logs = data.splitlines()
self.assertEqual(len(logs), 1)
- self.assertEqual(server.data.count(b"\n"), 1)
-
+ self.assertEqual(data.count("\n"), 1)
log = json.loads(logs[0])
# The terse logger should give us these keys.
@@ -72,163 +48,74 @@ class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
"log",
"time",
"level",
- "log_namespace",
- "request",
- "scope",
- "server_name",
- "name",
+ "namespace",
]
self.assertCountEqual(log.keys(), expected_log_keys)
+ self.assertEqual(log["log"], "Hello there, wally!")
- # It contains the data we expect.
- self.assertEqual(log["name"], "wally")
-
- def test_log_backpressure_debug(self):
+ def test_extra_data(self):
"""
- When backpressure is hit, DEBUG logs will be shed.
+ Additional information can be included in the structured logging.
"""
- log_config = {
- "loggers": {"synapse": {"level": "DEBUG"}},
- "drains": {
- "tersejson": {
- "type": "network_json_terse",
- "host": "127.0.0.1",
- "port": 8000,
- "maximum_buffer": 10,
- }
- },
- }
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs,
- self.hs.config,
- log_config,
- logBeginner=beginner,
- redirect_stdlib_logging=False,
- )
-
- logger = Logger(
- namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
- )
+ output = StringIO()
- # Send some debug messages
- for i in range(0, 3):
- logger.debug("debug %s" % (i,))
+ handler = logging.StreamHandler(output)
+ handler.setFormatter(TerseJsonFormatter())
+ logger = self.get_logger(handler)
- # Send a bunch of useful messages
- for i in range(0, 7):
- logger.info("test message %s" % (i,))
-
- # The last debug message pushes it past the maximum buffer
- logger.debug("too much debug")
-
- # Allow the reconnection
- _, server = connect_client(self.reactor, 0)
- self.pump()
-
- # Only the 7 infos made it through, the debugs were elided
- logs = server.data.splitlines()
- self.assertEqual(len(logs), 7)
-
- def test_log_backpressure_info(self):
- """
- When backpressure is hit, DEBUG and INFO logs will be shed.
- """
- log_config = {
- "loggers": {"synapse": {"level": "DEBUG"}},
- "drains": {
- "tersejson": {
- "type": "network_json_terse",
- "host": "127.0.0.1",
- "port": 8000,
- "maximum_buffer": 10,
- }
- },
- }
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs,
- self.hs.config,
- log_config,
- logBeginner=beginner,
- redirect_stdlib_logging=False,
- )
-
- logger = Logger(
- namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
+ logger.info(
+ "Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True}
)
- # Send some debug messages
- for i in range(0, 3):
- logger.debug("debug %s" % (i,))
-
- # Send a bunch of useful messages
- for i in range(0, 10):
- logger.warn("test warn %s" % (i,))
-
- # Send a bunch of info messages
- for i in range(0, 3):
- logger.info("test message %s" % (i,))
-
- # The last debug message pushes it past the maximum buffer
- logger.debug("too much debug")
-
- # Allow the reconnection
- client, server = connect_client(self.reactor, 0)
- self.pump()
+ # One log message, with a single trailing newline.
+ data = output.getvalue()
+ logs = data.splitlines()
+ self.assertEqual(len(logs), 1)
+ self.assertEqual(data.count("\n"), 1)
+ log = json.loads(logs[0])
- # The 10 warnings made it through, the debugs and infos were elided
- logs = list(map(json.loads, server.data.decode("utf8").splitlines()))
- self.assertEqual(len(logs), 10)
+ # The terse logger should give us these keys.
+ expected_log_keys = [
+ "log",
+ "time",
+ "level",
+ "namespace",
+ # The additional keys given via extra.
+ "foo",
+ "int",
+ "bool",
+ ]
+ self.assertCountEqual(log.keys(), expected_log_keys)
- self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10})
+ # Check the values of the extra fields.
+ self.assertEqual(log["foo"], "bar")
+ self.assertEqual(log["int"], 3)
+ self.assertIs(log["bool"], True)
- def test_log_backpressure_cut_middle(self):
+ def test_json_output(self):
"""
- When backpressure is hit, and no more DEBUG and INFOs cannot be culled,
- it will cut the middle messages out.
+ The Terse JSON formatter converts log messages to JSON.
"""
- log_config = {
- "loggers": {"synapse": {"level": "DEBUG"}},
- "drains": {
- "tersejson": {
- "type": "network_json_terse",
- "host": "127.0.0.1",
- "port": 8000,
- "maximum_buffer": 10,
- }
- },
- }
-
- # Begin the logger with our config
- beginner = FakeBeginner()
- setup_structured_logging(
- self.hs,
- self.hs.config,
- log_config,
- logBeginner=beginner,
- redirect_stdlib_logging=False,
- )
+ output = StringIO()
- logger = Logger(
- namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
- )
+ handler = logging.StreamHandler(output)
+ handler.setFormatter(JsonFormatter())
+ logger = self.get_logger(handler)
- # Send a bunch of useful messages
- for i in range(0, 20):
- logger.warn("test warn", num=i)
+ logger.info("Hello there, %s!", "wally")
- # Allow the reconnection
- client, server = connect_client(self.reactor, 0)
- self.pump()
+ # One log message, with a single trailing newline.
+ data = output.getvalue()
+ logs = data.splitlines()
+ self.assertEqual(len(logs), 1)
+ self.assertEqual(data.count("\n"), 1)
+ log = json.loads(logs[0])
- # The first five and last five warnings made it through, the debugs and
- # infos were elided
- logs = list(map(json.loads, server.data.decode("utf8").splitlines()))
- self.assertEqual(len(logs), 10)
- self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10})
- self.assertEqual([0, 1, 2, 3, 4, 15, 16, 17, 18, 19], [x["num"] for x in logs])
+ # The terse logger should give us these keys.
+ expected_log_keys = [
+ "log",
+ "level",
+ "namespace",
+ ]
+ self.assertCountEqual(log.keys(), expected_log_keys)
+ self.assertEqual(log["log"], "Hello there, wally!")
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 55545d9341..bcdcafa5a9 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.pusher = self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -131,6 +131,35 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about that message
self._check_for_mail()
+ def test_invite_sends_email(self):
+ # Create a room and invite the user to it
+ room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
+ self.helper.invite(
+ room=room,
+ src=self.others[0].id,
+ tok=self.others[0].token,
+ targ=self.user_id,
+ )
+
+ # We should get emailed about the invite
+ self._check_for_mail()
+
+ def test_invite_to_empty_room_sends_email(self):
+ # Create a room and invite the user to it
+ room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
+ self.helper.invite(
+ room=room,
+ src=self.others[0].id,
+ tok=self.others[0].token,
+ targ=self.user_id,
+ )
+
+ # Then have the original user leave
+ self.helper.leave(room, self.others[0].id, tok=self.others[0].token)
+
+ # We should get emailed about the invite
+ self._check_for_mail()
+
def test_multiple_members_email(self):
# We want to test multiple notifications, so we pause processing of push
# while we send messages.
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index b567868b02..8571924b29 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 093e2faac7..5c633ac6df 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -16,7 +16,6 @@ import logging
from typing import Any, Callable, List, Optional, Tuple
import attr
-import hiredis
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
@@ -39,12 +38,22 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport, render
+try:
+ import hiredis
+except ImportError:
+ hiredis = None
+
logger = logging.getLogger(__name__)
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
+ # hiredis is an optional dependency so we don't want to require it for running
+ # the tests.
+ if not hiredis:
+ skip = "Requires hiredis"
+
servlets = [
streams.register_servlets,
]
@@ -269,7 +278,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
homeserver_to_use=GenericWorkerServer,
config=config,
reactor=self.reactor,
- **kwargs
+ **kwargs,
)
# If the instance is in the `instance_map` config then workers may try
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index c9998e88e6..bad0df08cf 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -449,7 +449,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
sender=sender,
type="test_event",
content={"body": body},
- **kwargs
+ **kwargs,
)
)
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
new file mode 100644
index 0000000000..77c261dbf7
--- /dev/null
+++ b/tests/replication/test_multi_media_repo.py
@@ -0,0 +1,277 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+from binascii import unhexlify
+from typing import Tuple
+
+from twisted.internet.protocol import Factory
+from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.web.http import HTTPChannel
+from twisted.web.server import Request
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
+from synapse.server import HomeServer
+
+from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import FakeChannel, FakeTransport
+
+logger = logging.getLogger(__name__)
+
+test_server_connection_factory = None
+
+
+class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
+ """Checks running multiple media repos work correctly.
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("user", "pass")
+ self.access_token = self.login("user", "pass")
+
+ self.reactor.lookups["example.com"] = "127.0.0.2"
+
+ def default_config(self):
+ conf = super().default_config()
+ conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
+ return conf
+
+ def _get_media_req(
+ self, hs: HomeServer, target: str, media_id: str
+ ) -> Tuple[FakeChannel, Request]:
+ """Request some remote media from the given HS by calling the download
+ API.
+
+ This then triggers an outbound request from the HS to the target.
+
+ Returns:
+ The channel for the *client* request and the *outbound* request for
+ the media which the caller should respond to.
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ "/{}/{}".format(target, media_id),
+ shorthand=False,
+ access_token=self.access_token,
+ )
+ request.render(hs.get_media_repository_resource().children[b"download"])
+ self.pump()
+
+ clients = self.reactor.tcpClients
+ self.assertGreaterEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+
+ # build the test server
+ server_tls_protocol = _build_test_server(get_connection_factory())
+
+ # now, tell the client protocol factory to build the client protocol (it will be a
+ # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
+ # HTTP11ClientProtocol) and wire the output of said protocol up to the server via
+ # a FakeTransport.
+ #
+ # Normally this would be done by the TCP socket code in Twisted, but we are
+ # stubbing that out here.
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(
+ FakeTransport(server_tls_protocol, self.reactor, client_protocol)
+ )
+
+ # tell the server tls protocol to send its stuff back to the client, too
+ server_tls_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, server_tls_protocol)
+ )
+
+ # fish the test server back out of the server-side TLS protocol.
+ http_server = server_tls_protocol.wrappedProtocol
+
+ # give the reactor a pump to get the TLS juices flowing.
+ self.reactor.pump((0.1,))
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(
+ request.path,
+ "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"),
+ )
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
+ )
+
+ return channel, request
+
+ def test_basic(self):
+ """Test basic fetching of remote media from a single worker.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+
+ channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
+
+ request.setResponseCode(200)
+ request.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+ request.write(b"Hello!")
+ request.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.result["body"], b"Hello!")
+
+ def test_download_simple_file_race(self):
+ """Test that fetching remote media from two different processes at the
+ same time works.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+ hs2 = self.make_worker_hs("synapse.app.generic_worker")
+
+ start_count = self._count_remote_media()
+
+ # Make two requests without responding to the outbound media requests.
+ channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
+ channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
+
+ # Respond to the first outbound media request and check that the client
+ # request is successful
+ request1.setResponseCode(200)
+ request1.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+ request1.write(b"Hello!")
+ request1.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel1.code, 200, channel1.result["body"])
+ self.assertEqual(channel1.result["body"], b"Hello!")
+
+ # Now respond to the second with the same content.
+ request2.setResponseCode(200)
+ request2.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+ request2.write(b"Hello!")
+ request2.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel2.code, 200, channel2.result["body"])
+ self.assertEqual(channel2.result["body"], b"Hello!")
+
+ # We expect only one new file to have been persisted.
+ self.assertEqual(start_count + 1, self._count_remote_media())
+
+ def test_download_image_race(self):
+ """Test that fetching remote *images* from two different processes at
+ the same time works.
+
+ This checks that races generating thumbnails are handled correctly.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+ hs2 = self.make_worker_hs("synapse.app.generic_worker")
+
+ start_count = self._count_remote_thumbnails()
+
+ channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
+ channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
+
+ png_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ request1.setResponseCode(200)
+ request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
+ request1.write(png_data)
+ request1.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel1.code, 200, channel1.result["body"])
+ self.assertEqual(channel1.result["body"], png_data)
+
+ request2.setResponseCode(200)
+ request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
+ request2.write(png_data)
+ request2.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel2.code, 200, channel2.result["body"])
+ self.assertEqual(channel2.result["body"], png_data)
+
+ # We expect only three new thumbnails to have been persisted.
+ self.assertEqual(start_count + 3, self._count_remote_thumbnails())
+
+ def _count_remote_media(self) -> int:
+ """Count the number of files in our remote media directory.
+ """
+ path = os.path.join(
+ self.hs.get_media_repository().primary_base_path, "remote_content"
+ )
+ return sum(len(files) for _, _, files in os.walk(path))
+
+ def _count_remote_thumbnails(self) -> int:
+ """Count the number of files in our remote thumbnails directory.
+ """
+ path = os.path.join(
+ self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
+ )
+ return sum(len(files) for _, _, files in os.walk(path))
+
+
+def get_connection_factory():
+ # this needs to happen once, but not until we are ready to run the first test
+ global test_server_connection_factory
+ if test_server_connection_factory is None:
+ test_server_connection_factory = TestServerTLSConnectionFactory(
+ sanlist=[b"DNS:example.com"]
+ )
+ return test_server_connection_factory
+
+
+def _build_test_server(connection_creator):
+ """Construct a test server
+
+ This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
+
+ Args:
+ connection_creator (IOpenSSLServerConnectionCreator): thing to build
+ SSL connections
+ sanlist (list[bytes]): list of the SAN entries for the cert returned
+ by the server
+
+ Returns:
+ TLSMemoryBIOProtocol
+ """
+ server_factory = Factory.forProtocol(HTTPChannel)
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ server_tls_factory = TLSMemoryBIOFactory(
+ connection_creator, isClient=False, wrappedFactory=server_factory
+ )
+
+ return server_tls_factory.buildProtocol(None)
+
+
+def _log_request(request):
+ """Implements Factory.log, which is expected by Request.finish"""
+ logger.info("Completed request %s", request)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 2bdc6edbb1..67c27a089f 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
user_dict = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_dict["token_id"]
+ token_id = user_dict.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index 92c9058887..d89eb90cfe 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -393,6 +393,22 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
+ def test_user_has_no_devices(self):
+ """
+ Tests that a normal lookup for devices is successfully
+ if user has no devices
+ """
+
+ # Get devices
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+ self.assertEqual(0, len(channel.json_body["devices"]))
+
def test_get_devices(self):
"""
Tests that a normal lookup for devices is successfully
@@ -409,6 +425,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(number_devices, channel.json_body["total"])
self.assertEqual(number_devices, len(channel.json_body["devices"]))
self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"])
# Check that all fields are available
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index bf79086f78..303622217f 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -70,6 +70,16 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/event_reports"
+ def test_no_auth(self):
+ """
+ Try to get an event report without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
def test_requester_is_no_admin(self):
"""
If the user is not a server admin, an error 403 is returned.
@@ -266,7 +276,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
def test_limit_is_negative(self):
"""
- Testing that a negative list parameter returns a 400
+ Testing that a negative limit parameter returns a 400
"""
request, channel = self.make_request(
@@ -360,7 +370,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
def _check_fields(self, content):
- """Checks that all attributes are present in a event report
+ """Checks that all attributes are present in an event report
"""
for c in content:
self.assertIn("id", c)
@@ -368,15 +378,175 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertIn("room_id", c)
self.assertIn("event_id", c)
self.assertIn("user_id", c)
- self.assertIn("reason", c)
- self.assertIn("content", c)
self.assertIn("sender", c)
- self.assertIn("room_alias", c)
- self.assertIn("event_json", c)
- self.assertIn("score", c["content"])
- self.assertIn("reason", c["content"])
- self.assertIn("auth_events", c["event_json"])
- self.assertIn("type", c["event_json"])
- self.assertIn("room_id", c["event_json"])
- self.assertIn("sender", c["event_json"])
- self.assertIn("content", c["event_json"])
+ self.assertIn("canonical_alias", c)
+ self.assertIn("name", c)
+ self.assertIn("score", c)
+ self.assertIn("reason", c)
+
+
+class EventReportDetailTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ report_event.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.room_id1 = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok, is_public=True
+ )
+ self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok)
+
+ self._create_event_and_report(
+ room_id=self.room_id1, user_tok=self.other_user_tok,
+ )
+
+ # first created event report gets `id`=2
+ self.url = "/_synapse/admin/v1/event_reports/2"
+
+ def test_no_auth(self):
+ """
+ Try to get event report without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.other_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_default_success(self):
+ """
+ Testing get a reported event
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self._check_fields(channel.json_body)
+
+ def test_invalid_report_id(self):
+ """
+ Testing that an invalid `report_id` returns a 400.
+ """
+
+ # `report_id` is negative
+ request, channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/event_reports/-123",
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "The report_id parameter must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ # `report_id` is a non-numerical string
+ request, channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/event_reports/abcdef",
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "The report_id parameter must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ # `report_id` is undefined
+ request, channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/event_reports/",
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "The report_id parameter must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ def test_report_id_not_found(self):
+ """
+ Testing that a not existing `report_id` returns a 404.
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/event_reports/123",
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+ self.assertEqual("Event report not found", channel.json_body["error"])
+
+ def _create_event_and_report(self, room_id, user_tok):
+ """Create and report events
+ """
+ resp = self.helper.send(room_id, tok=user_tok)
+ event_id = resp["event_id"]
+
+ request, channel = self.make_request(
+ "POST",
+ "rooms/%s/report/%s" % (room_id, event_id),
+ json.dumps({"score": -100, "reason": "this makes me sad"}),
+ access_token=user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def _check_fields(self, content):
+ """Checks that all attributes are present in a event report
+ """
+ self.assertIn("id", content)
+ self.assertIn("received_ts", content)
+ self.assertIn("room_id", content)
+ self.assertIn("event_id", content)
+ self.assertIn("user_id", content)
+ self.assertIn("sender", content)
+ self.assertIn("canonical_alias", content)
+ self.assertIn("name", content)
+ self.assertIn("event_json", content)
+ self.assertIn("score", content)
+ self.assertIn("reason", content)
+ self.assertIn("auth_events", content["event_json"])
+ self.assertIn("type", content["event_json"])
+ self.assertIn("room_id", content["event_json"])
+ self.assertIn("sender", content["event_json"])
+ self.assertIn("content", content["event_json"])
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
new file mode 100644
index 0000000000..721fa1ed51
--- /dev/null
+++ b/tests/rest/admin/test_media.py
@@ -0,0 +1,568 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from binascii import unhexlify
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login, profile, room
+from synapse.rest.media.v1.filepath import MediaFilePaths
+
+from tests import unittest
+
+
+class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_media_repo,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_device_handler()
+ self.media_repo = hs.get_media_repository_resource()
+ self.server_name = hs.hostname
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.filepaths = MediaFilePaths(hs.config.media_store_path)
+
+ def test_no_auth(self):
+ """
+ Try to delete media without authentication.
+ """
+ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
+
+ request, channel = self.make_request("DELETE", url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+
+ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
+
+ request, channel = self.make_request(
+ "DELETE", url, access_token=self.other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_media_does_not_exist(self):
+ """
+ Tests that a lookup for a media that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
+
+ request, channel = self.make_request(
+ "DELETE", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_media_is_not_local(self):
+ """
+ Tests that a lookup for a media that is not a local returns a 400
+ """
+ url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345")
+
+ request, channel = self.make_request(
+ "DELETE", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only delete local media", channel.json_body["error"])
+
+ def test_delete_media(self):
+ """
+ Tests that delete a media is successfully
+ """
+
+ download_resource = self.media_repo.children[b"download"]
+ upload_resource = self.media_repo.children[b"upload"]
+ image_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ # Upload some media into the room
+ response = self.helper.upload_media(
+ upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
+ )
+ # Extract media ID from the response
+ server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
+ server_name, media_id = server_and_media_id.split("/")
+
+ self.assertEqual(server_name, self.server_name)
+
+ # Attempt to access media
+ request, channel = self.make_request(
+ "GET",
+ server_and_media_id,
+ shorthand=False,
+ access_token=self.admin_user_tok,
+ )
+ request.render(download_resource)
+ self.pump(1.0)
+
+ # Should be successful
+ self.assertEqual(
+ 200,
+ channel.code,
+ msg=(
+ "Expected to receive a 200 on accessing media: %s" % server_and_media_id
+ ),
+ )
+
+ # Test if the file exists
+ local_path = self.filepaths.local_media_filepath(media_id)
+ self.assertTrue(os.path.exists(local_path))
+
+ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id)
+
+ # Delete media
+ request, channel = self.make_request(
+ "DELETE", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+ self.assertEqual(
+ media_id, channel.json_body["deleted_media"][0],
+ )
+
+ # Attempt to access media
+ request, channel = self.make_request(
+ "GET",
+ server_and_media_id,
+ shorthand=False,
+ access_token=self.admin_user_tok,
+ )
+ request.render(download_resource)
+ self.pump(1.0)
+ self.assertEqual(
+ 404,
+ channel.code,
+ msg=(
+ "Expected to receive a 404 on accessing deleted media: %s"
+ % server_and_media_id
+ ),
+ )
+
+ # Test if the file is deleted
+ self.assertFalse(os.path.exists(local_path))
+
+
+class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_media_repo,
+ login.register_servlets,
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_device_handler()
+ self.media_repo = hs.get_media_repository_resource()
+ self.server_name = hs.hostname
+ self.clock = hs.clock
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.filepaths = MediaFilePaths(hs.config.media_store_path)
+ self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
+
+ def test_no_auth(self):
+ """
+ Try to delete media without authentication.
+ """
+
+ request, channel = self.make_request("POST", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+
+ request, channel = self.make_request(
+ "POST", self.url, access_token=self.other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_media_is_not_local(self):
+ """
+ Tests that a lookup for media that is not local returns a 400
+ """
+ url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain"
+
+ request, channel = self.make_request(
+ "POST", url + "?before_ts=1234", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only delete local media", channel.json_body["error"])
+
+ def test_missing_parameter(self):
+ """
+ If the parameter `before_ts` is missing, an error is returned.
+ """
+ request, channel = self.make_request(
+ "POST", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "Missing integer query parameter b'before_ts'", channel.json_body["error"]
+ )
+
+ def test_invalid_parameter(self):
+ """
+ If parameters are invalid, an error is returned.
+ """
+ request, channel = self.make_request(
+ "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "Query parameter before_ts must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=1234&size_gt=-1234",
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "Query parameter size_gt must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=1234&keep_profiles=not_bool",
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(
+ "Boolean query parameter b'keep_profiles' must be one of ['true', 'false']",
+ channel.json_body["error"],
+ )
+
+ def test_delete_media_never_accessed(self):
+ """
+ Tests that media deleted if it is older than `before_ts` and never accessed
+ `last_access_ts` is `NULL` and `created_ts` < `before_ts`
+ """
+
+ # upload and do not access
+ server_and_media_id = self._create_media()
+ self.pump(1.0)
+
+ # test that the file exists
+ media_id = server_and_media_id.split("/")[1]
+ local_path = self.filepaths.local_media_filepath(media_id)
+ self.assertTrue(os.path.exists(local_path))
+
+ # timestamp after upload/create
+ now_ms = self.clock.time_msec()
+ request, channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+ self.assertEqual(
+ media_id, channel.json_body["deleted_media"][0],
+ )
+
+ self._access_media(server_and_media_id, False)
+
+ def test_keep_media_by_date(self):
+ """
+ Tests that media is not deleted if it is newer than `before_ts`
+ """
+
+ # timestamp before upload
+ now_ms = self.clock.time_msec()
+ server_and_media_id = self._create_media()
+
+ self._access_media(server_and_media_id)
+
+ request, channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+
+ self._access_media(server_and_media_id)
+
+ # timestamp after upload
+ now_ms = self.clock.time_msec()
+ request, channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+ self.assertEqual(
+ server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+ )
+
+ self._access_media(server_and_media_id, False)
+
+ def test_keep_media_by_size(self):
+ """
+ Tests that media is not deleted if its size is smaller than or equal
+ to `size_gt`
+ """
+ server_and_media_id = self._create_media()
+
+ self._access_media(server_and_media_id)
+
+ now_ms = self.clock.time_msec()
+ request, channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms) + "&size_gt=67",
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+
+ self._access_media(server_and_media_id)
+
+ now_ms = self.clock.time_msec()
+ request, channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms) + "&size_gt=66",
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+ self.assertEqual(
+ server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+ )
+
+ self._access_media(server_and_media_id, False)
+
+ def test_keep_media_by_user_avatar(self):
+ """
+ Tests that we do not delete media if is used as a user avatar
+ Tests parameter `keep_profiles`
+ """
+ server_and_media_id = self._create_media()
+
+ self._access_media(server_and_media_id)
+
+ # set media as avatar
+ request, channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.admin_user,),
+ content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ now_ms = self.clock.time_msec()
+ request, channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+
+ self._access_media(server_and_media_id)
+
+ now_ms = self.clock.time_msec()
+ request, channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+ self.assertEqual(
+ server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+ )
+
+ self._access_media(server_and_media_id, False)
+
+ def test_keep_media_by_room_avatar(self):
+ """
+ Tests that we do not delete media if it is used as a room avatar
+ Tests parameter `keep_profiles`
+ """
+ server_and_media_id = self._create_media()
+
+ self._access_media(server_and_media_id)
+
+ # set media as room avatar
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ request, channel = self.make_request(
+ "PUT",
+ "/rooms/%s/state/m.room.avatar" % (room_id,),
+ content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ now_ms = self.clock.time_msec()
+ request, channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+
+ self._access_media(server_and_media_id)
+
+ now_ms = self.clock.time_msec()
+ request, channel = self.make_request(
+ "POST",
+ self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+ self.assertEqual(
+ server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+ )
+
+ self._access_media(server_and_media_id, False)
+
+ def _create_media(self):
+ """
+ Create a media and return media_id and server_and_media_id
+ """
+ upload_resource = self.media_repo.children[b"upload"]
+ # file size is 67 Byte
+ image_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ # Upload some media into the room
+ response = self.helper.upload_media(
+ upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
+ )
+ # Extract media ID from the response
+ server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
+ server_name = server_and_media_id.split("/")[0]
+
+ # Check that new media is a local and not remote
+ self.assertEqual(server_name, self.server_name)
+
+ return server_and_media_id
+
+ def _access_media(self, server_and_media_id, expect_success=True):
+ """
+ Try to access a media and check the result
+ """
+ download_resource = self.media_repo.children[b"download"]
+
+ media_id = server_and_media_id.split("/")[1]
+ local_path = self.filepaths.local_media_filepath(media_id)
+
+ request, channel = self.make_request(
+ "GET",
+ server_and_media_id,
+ shorthand=False,
+ access_token=self.admin_user_tok,
+ )
+ request.render(download_resource)
+ self.pump(1.0)
+
+ if expect_success:
+ self.assertEqual(
+ 200,
+ channel.code,
+ msg=(
+ "Expected to receive a 200 on accessing media: %s"
+ % server_and_media_id
+ ),
+ )
+ # Test that the file exists
+ self.assertTrue(os.path.exists(local_path))
+ else:
+ self.assertEqual(
+ 404,
+ channel.code,
+ msg=(
+ "Expected to receive a 404 on accessing deleted media: %s"
+ % (server_and_media_id)
+ ),
+ )
+ # Test that the file is deleted
+ self.assertFalse(os.path.exists(local_path))
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 98d0623734..7df32e5093 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -17,6 +17,7 @@ import hashlib
import hmac
import json
import urllib.parse
+from binascii import unhexlify
from mock import Mock
@@ -1016,7 +1017,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
- sync.register_servlets,
room.register_servlets,
]
@@ -1082,6 +1082,21 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
+ def test_no_memberships(self):
+ """
+ Tests that a normal lookup for rooms is successfully
+ if user has no memberships
+ """
+ # Get rooms
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+ self.assertEqual(0, len(channel.json_body["joined_rooms"]))
+
def test_get_rooms(self):
"""
Tests that a normal lookup for rooms is successfully
@@ -1101,3 +1116,408 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_rooms, channel.json_body["total"])
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
+
+
+class PushersRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.url = "/_synapse/admin/v1/users/%s/pushers" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to list pushers of an user without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_get_pushers(self):
+ """
+ Tests that a normal lookup for pushers is successfully
+ """
+
+ # Get pushers
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+
+ # Register the pusher
+ other_user_token = self.login("user", "pass")
+ user_tuple = self.get_success(
+ self.store.get_user_by_access_token(other_user_token)
+ )
+ token_id = user_tuple.token_id
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=self.other_user,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "example.com"},
+ )
+ )
+
+ # Get pushers
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["total"])
+
+ for p in channel.json_body["pushers"]:
+ self.assertIn("pushkey", p)
+ self.assertIn("kind", p)
+ self.assertIn("app_id", p)
+ self.assertIn("app_display_name", p)
+ self.assertIn("device_display_name", p)
+ self.assertIn("profile_tag", p)
+ self.assertIn("lang", p)
+ self.assertIn("url", p["data"])
+
+
+class UserMediaRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.media_repo = hs.get_media_repository_resource()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.url = "/_synapse/admin/v1/users/%s/media" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to list media of an user without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:test/media"
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_limit(self):
+ """
+ Testing list of media with limit
+ """
+
+ number_media = 20
+ other_user_tok = self.login("user", "pass")
+ self._create_media(other_user_tok, number_media)
+
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_media)
+ self.assertEqual(len(channel.json_body["media"]), 5)
+ self.assertEqual(channel.json_body["next_token"], 5)
+ self._check_fields(channel.json_body["media"])
+
+ def test_from(self):
+ """
+ Testing list of media with a defined starting point (from)
+ """
+
+ number_media = 20
+ other_user_tok = self.login("user", "pass")
+ self._create_media(other_user_tok, number_media)
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_media)
+ self.assertEqual(len(channel.json_body["media"]), 15)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["media"])
+
+ def test_limit_and_from(self):
+ """
+ Testing list of media with a defined starting point and limit
+ """
+
+ number_media = 20
+ other_user_tok = self.login("user", "pass")
+ self._create_media(other_user_tok, number_media)
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_media)
+ self.assertEqual(channel.json_body["next_token"], 15)
+ self.assertEqual(len(channel.json_body["media"]), 10)
+ self._check_fields(channel.json_body["media"])
+
+ def test_limit_is_negative(self):
+ """
+ Testing that a negative limit parameter returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_from_is_negative(self):
+ """
+ Testing that a negative from parameter returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_next_token(self):
+ """
+ Testing that `next_token` appears at the right place
+ """
+
+ number_media = 20
+ other_user_tok = self.login("user", "pass")
+ self._create_media(other_user_tok, number_media)
+
+ # `next_token` does not appear
+ # Number of results is the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_media)
+ self.assertEqual(len(channel.json_body["media"]), number_media)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does not appear
+ # Number of max results is larger than the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_media)
+ self.assertEqual(len(channel.json_body["media"]), number_media)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does appear
+ # Number of max results is smaller than the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_media)
+ self.assertEqual(len(channel.json_body["media"]), 19)
+ self.assertEqual(channel.json_body["next_token"], 19)
+
+ # Check
+ # Set `from` to value of `next_token` for request remaining entries
+ # `next_token` does not appear
+ request, channel = self.make_request(
+ "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_media)
+ self.assertEqual(len(channel.json_body["media"]), 1)
+ self.assertNotIn("next_token", channel.json_body)
+
+ def test_user_has_no_media(self):
+ """
+ Tests that a normal lookup for media is successfully
+ if user has no media created
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+ self.assertEqual(0, len(channel.json_body["media"]))
+
+ def test_get_media(self):
+ """
+ Tests that a normal lookup for media is successfully
+ """
+
+ number_media = 5
+ other_user_tok = self.login("user", "pass")
+ self._create_media(other_user_tok, number_media)
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(number_media, channel.json_body["total"])
+ self.assertEqual(number_media, len(channel.json_body["media"]))
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["media"])
+
+ def _create_media(self, user_token, number_media):
+ """
+ Create a number of media for a specific user
+ """
+ upload_resource = self.media_repo.children[b"upload"]
+ for i in range(number_media):
+ # file size is 67 Byte
+ image_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ # Upload some media into the room
+ self.helper.upload_media(
+ upload_resource, image_data, tok=user_token, expect_code=200
+ )
+
+ def _check_fields(self, content):
+ """Checks that all attributes are present in content
+ """
+ for m in content:
+ self.assertIn("media_id", m)
+ self.assertIn("media_type", m)
+ self.assertIn("media_length", m)
+ self.assertIn("upload_name", m)
+ self.assertIn("created_ts", m)
+ self.assertIn("last_access_ts", m)
+ self.assertIn("quarantined_by", m)
+ self.assertIn("safe_from_quarantine", m)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 2fc3a60fc5..98c3887bbf 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -55,6 +55,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.hs.config.server_name,
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+ sender="@as:test",
)
self.hs.get_datastore().services_cache.append(appservice)
diff --git a/tests/server.py b/tests/server.py
index 4d33b84097..3dd2cfc072 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -46,7 +46,7 @@ class FakeChannel:
site = attr.ib(type=Site)
_reactor = attr.ib()
- result = attr.ib(default=attr.Factory(dict))
+ result = attr.ib(type=dict, default=attr.Factory(dict))
_producer = None
@property
@@ -380,7 +380,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
pool._runWithConnection,
func,
*args,
- **kwargs
+ **kwargs,
)
def runInteraction(interaction, *args, **kwargs):
@@ -390,7 +390,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
pool._runInteraction,
interaction,
*args,
- **kwargs
+ **kwargs,
)
pool.runWithConnection = runWithConnection
@@ -571,12 +571,10 @@ def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol
reactor
factory: The connecting factory to build.
"""
- factory = reactor.tcpClients[client_id][2]
+ factory = reactor.tcpClients.pop(client_id)[2]
client = factory.buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, reactor))
client.makeConnection(FakeTransport(server, reactor))
- reactor.tcpClients.pop(client_id)
-
return client, server
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 080761d1d2..5a1e5c4e66 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -22,7 +22,7 @@ import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client.v1 import login, room
from synapse.storage import prepare_database
-from synapse.types import Requester, UserID
+from synapse.types import UserID, create_requester
from tests.unittest import HomeserverTestCase
@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
- self.requester = Requester(self.user, None, False, False, None, None)
+ self.requester = create_requester(self.user)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
- self.requester = Requester(self.user, None, False, False, None, None)
+ self.requester = create_requester(self.user)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 755c70db31..e96ca1c8ca 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -412,7 +412,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
"GET",
"/_matrix/client/r0/admin/users/" + self.user_id,
access_token=access_token,
- **make_request_args
+ **make_request_args,
)
request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 3957471f3f..7691f2d790 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -14,7 +14,7 @@
# limitations under the License.
from synapse.metrics import REGISTRY, generate_latest
-from synapse.types import Requester, UserID
+from synapse.types import UserID, create_requester
from tests.unittest import HomeserverTestCase
@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
room_creator = self.hs.get_room_creation_handler()
user = UserID("alice", "test")
- requester = Requester(user, None, False, False, None, None)
+ requester = create_requester(user)
# Real events, forward extremities
events = [(3, 2), (6, 2), (4, 6)]
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 6b582771fe..c8c7a90e5d 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.store.get_user_by_access_token(self.tokens[1])
)
- self.assertDictContainsSubset(
- {"name": self.user_id, "device_id": self.device_id}, result
- )
-
- self.assertTrue("token_id" in result)
+ self.assertEqual(result.user_id, self.user_id)
+ self.assertEqual(result.device_id, self.device_id)
+ self.assertIsNotNone(result.token_id)
@defer.inlineCallbacks
def test_user_delete_access_tokens(self):
@@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
user = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[0])
)
- self.assertEqual(self.user_id, user["name"])
+ self.assertEqual(self.user_id, user.user_id)
# now delete the rest
yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 12ccc1f53e..ff972daeaa 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -19,7 +19,7 @@ from unittest.mock import Mock
from synapse.api.constants import Membership
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
-from synapse.types import Requester, UserID
+from synapse.types import UserID, create_requester
from tests import unittest
from tests.test_utils import event_injection
@@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
- requester = Requester(user, None, False, False, None, None)
+ requester = create_requester(user)
self.get_success(self.room_creator.create_room(requester, {}))
# Register the background update to run again.
diff --git a/tests/test_federation.py b/tests/test_federation.py
index d39e792580..1ce4ea3a01 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -20,7 +20,7 @@ from twisted.internet.defer import succeed
from synapse.api.errors import FederationError
from synapse.events import make_event_from_dict
from synapse.logging.context import LoggingContext
-from synapse.types import Requester, UserID
+from synapse.types import UserID, create_requester
from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination
@@ -43,7 +43,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
user_id = UserID("us", "test")
- our_user = Requester(user_id, None, False, False, None, None)
+ our_user = create_requester(user_id)
room_creator = self.homeserver.get_room_creation_handler()
self.room_id = self.get_success(
room_creator.create_room(
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index a298cc0fd3..d232b72264 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,8 +17,10 @@
"""
Utilities for running the unit tests
"""
+import sys
+import warnings
from asyncio import Future
-from typing import Any, Awaitable, TypeVar
+from typing import Any, Awaitable, Callable, TypeVar
TV = TypeVar("TV")
@@ -48,3 +50,33 @@ def make_awaitable(result: Any) -> Awaitable[Any]:
future = Future() # type: ignore
future.set_result(result)
return future
+
+
+def setup_awaitable_errors() -> Callable[[], None]:
+ """
+ Convert warnings from a non-awaited coroutines into errors.
+ """
+ warnings.simplefilter("error", RuntimeWarning)
+
+ # unraisablehook was added in Python 3.8.
+ if not hasattr(sys, "unraisablehook"):
+ return lambda: None
+
+ # State shared between unraisablehook and check_for_unraisable_exceptions.
+ unraisable_exceptions = []
+ orig_unraisablehook = sys.unraisablehook # type: ignore
+
+ def unraisablehook(unraisable):
+ unraisable_exceptions.append(unraisable.exc_value)
+
+ def cleanup():
+ """
+ A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
+ """
+ sys.unraisablehook = orig_unraisablehook # type: ignore
+ if unraisable_exceptions:
+ raise unraisable_exceptions.pop()
+
+ sys.unraisablehook = unraisablehook # type: ignore
+
+ return cleanup
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index e93aa84405..c3c4a93e1f 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -50,7 +50,7 @@ async def inject_member_event(
sender=sender,
state_key=target,
content=content,
- **kwargs
+ **kwargs,
)
diff --git a/tests/unittest.py b/tests/unittest.py
index 040b126a27..08cf9b10c5 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -44,7 +44,7 @@ from synapse.logging.context import (
set_current_context,
)
from synapse.server import HomeServer
-from synapse.types import Requester, UserID, create_requester
+from synapse.types import UserID, create_requester
from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import (
@@ -54,7 +54,7 @@ from tests.server import (
render,
setup_test_homeserver,
)
-from tests.test_utils import event_injection
+from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
@@ -119,6 +119,10 @@ class TestCase(unittest.TestCase):
logging.getLogger().setLevel(level)
+ # Trial messes with the warnings configuration, thus this has to be
+ # done in the context of an individual TestCase.
+ self.addCleanup(setup_awaitable_errors())
+
return orig()
@around(self)
@@ -627,7 +631,7 @@ class HomeserverTestCase(TestCase):
"""
event_creator = self.hs.get_event_creation_handler()
secrets = self.hs.get_secrets()
- requester = Requester(user, None, False, False, None, None)
+ requester = create_requester(user)
event, context = self.get_success(
event_creator.create_event(
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 2ad08f541b..cf1e3203a4 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -29,13 +29,46 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, lru_cache
from tests import unittest
+from tests.test_utils import get_awaitable_result
logger = logging.getLogger(__name__)
+class LruCacheDecoratorTestCase(unittest.TestCase):
+ def test_base(self):
+ class Cls:
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @lru_cache()
+ def fn(self, arg1, arg2):
+ return self.mock(arg1, arg2)
+
+ obj = Cls()
+ obj.mock.return_value = "fish"
+ r = obj.fn(1, 2)
+ self.assertEqual(r, "fish")
+ obj.mock.assert_called_once_with(1, 2)
+ obj.mock.reset_mock()
+
+ # a call with different params should call the mock again
+ obj.mock.return_value = "chips"
+ r = obj.fn(1, 3)
+ self.assertEqual(r, "chips")
+ obj.mock.assert_called_once_with(1, 3)
+ obj.mock.reset_mock()
+
+ # the two values should now be cached
+ r = obj.fn(1, 2)
+ self.assertEqual(r, "fish")
+ r = obj.fn(1, 3)
+ self.assertEqual(r, "chips")
+ obj.mock.assert_not_called()
+
+
def run_on_reactor():
d = defer.Deferred()
reactor.callLater(0, d.callback, 0)
@@ -362,6 +395,31 @@ class DescriptorTestCase(unittest.TestCase):
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
+ def test_invalidate_cascade(self):
+ """Invalidations should cascade up through cache contexts"""
+
+ class Cls:
+ @cached(cache_context=True)
+ async def func1(self, key, cache_context):
+ return await self.func2(key, on_invalidate=cache_context.invalidate)
+
+ @cached(cache_context=True)
+ async def func2(self, key, cache_context):
+ return self.func3(key, on_invalidate=cache_context.invalidate)
+
+ @lru_cache(cache_context=True)
+ def func3(self, key, cache_context):
+ self.invalidate = cache_context.invalidate
+ return 42
+
+ obj = Cls()
+
+ top_invalidate = mock.Mock()
+ r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate))
+ self.assertEqual(r, 42)
+ obj.invalidate()
+ top_invalidate.assert_called_once()
+
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
"""More tests for @cached
|