diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index ba2a2bfd64..07d8105f41 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.appservice import ApplicationService, Namespace
from tests import unittest
+from tests.test_utils import simple_async_mock
def _regex(regex: str, exclusive: bool = True) -> Namespace:
@@ -91,10 +92,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
- self.store.get_aliases_for_room.return_value = defer.succeed(
+ self.store.get_aliases_for_room = simple_async_mock(
["#irc_foobar:matrix.org", "#athing:matrix.org"]
)
- self.store.get_users_in_room.return_value = defer.succeed([])
+ self.store.get_users_in_room = simple_async_mock([])
self.assertTrue(
(
yield defer.ensureDeferred(
@@ -144,10 +145,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
- self.store.get_aliases_for_room.return_value = defer.succeed(
+ self.store.get_aliases_for_room = simple_async_mock(
["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
)
- self.store.get_users_in_room.return_value = defer.succeed([])
+ self.store.get_users_in_room = simple_async_mock([])
self.assertFalse(
(
yield defer.ensureDeferred(
@@ -163,10 +164,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
)
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org"
- self.store.get_aliases_for_room.return_value = defer.succeed(
- ["#irc_barfoo:matrix.org"]
- )
- self.store.get_users_in_room.return_value = defer.succeed([])
+ self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"])
+ self.store.get_users_in_room = simple_async_mock([])
self.assertTrue(
(
yield defer.ensureDeferred(
@@ -191,10 +190,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
# Note that @irc_fo:here is the AS user.
- self.store.get_users_in_room.return_value = defer.succeed(
+ self.store.get_users_in_room = simple_async_mock(
["@alice:here", "@irc_fo:here", "@bob:here"]
)
- self.store.get_aliases_for_room.return_value = defer.succeed([])
+ self.store.get_aliases_for_room = simple_async_mock([])
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue(
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 55f0899bae..8fb6687f89 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -11,23 +11,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
from unittest.mock import Mock
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState
from synapse.appservice.scheduler import (
+ ApplicationServiceScheduler,
_Recoverer,
- _ServiceQueuer,
_TransactionController,
)
from synapse.logging.context import make_deferred_yieldable
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import simple_async_mock
from ..utils import MockClock
+if TYPE_CHECKING:
+ from twisted.internet.testing import MemoryReactor
+
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self):
@@ -58,7 +64,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
- service=service, events=events, ephemeral=[] # txn made and saved
+ service=service,
+ events=events,
+ ephemeral=[],
+ to_device_messages=[], # txn made and saved
)
self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed
@@ -79,7 +88,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
- service=service, events=events, ephemeral=[] # txn made and saved
+ service=service,
+ events=events,
+ ephemeral=[],
+ to_device_messages=[], # txn made and saved
)
self.assertEquals(0, txn.send.call_count) # txn not sent though
self.assertEquals(0, txn.complete.call_count) # or completed
@@ -102,7 +114,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
- service=service, events=events, ephemeral=[]
+ service=service, events=events, ephemeral=[], to_device_messages=[]
)
self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made
self.assertEquals(1, self.recoverer.recover.call_count) # and invoked
@@ -189,38 +201,41 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.callback.assert_called_once_with(self.recoverer)
-class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
- def setUp(self):
+class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer):
+ self.scheduler = ApplicationServiceScheduler(hs)
self.txn_ctrl = Mock()
self.txn_ctrl.send = simple_async_mock()
- self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
+
+ # Replace instantiated _TransactionController instances with our Mock
+ self.scheduler.txn_ctrl = self.txn_ctrl
+ self.scheduler.queuer.txn_ctrl = self.txn_ctrl
def test_send_single_event_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4)
event = Mock()
- self.queuer.enqueue_event(service, event)
- self.txn_ctrl.send.assert_called_once_with(service, [event], [])
+ self.scheduler.enqueue_for_appservice(service, events=[event])
+ self.txn_ctrl.send.assert_called_once_with(service, [event], [], [])
def test_send_single_event_with_queue(self):
d = defer.Deferred()
- self.txn_ctrl.send = Mock(
- side_effect=lambda x, y, z: make_deferred_yieldable(d)
- )
+ self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
service = Mock(id=4)
event = Mock(event_id="first")
event2 = Mock(event_id="second")
event3 = Mock(event_id="third")
# Send an event and don't resolve it just yet.
- self.queuer.enqueue_event(service, event)
+ self.scheduler.enqueue_for_appservice(service, events=[event])
# Send more events: expect send() to NOT be called multiple times.
- self.queuer.enqueue_event(service, event2)
- self.queuer.enqueue_event(service, event3)
- self.txn_ctrl.send.assert_called_with(service, [event], [])
+ # (call enqueue_for_appservice multiple times deliberately)
+ self.scheduler.enqueue_for_appservice(service, events=[event2])
+ self.scheduler.enqueue_for_appservice(service, events=[event3])
+ self.txn_ctrl.send.assert_called_with(service, [event], [], [])
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve the send event: expect the queued events to be sent
d.callback(service)
- self.txn_ctrl.send.assert_called_with(service, [event2, event3], [])
+ self.txn_ctrl.send.assert_called_with(service, [event2, event3], [], [])
self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_multiple_service_queues(self):
@@ -238,23 +253,23 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
send_return_list = [srv_1_defer, srv_2_defer]
- def do_send(x, y, z):
+ def do_send(*args, **kwargs):
return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send)
# send events for different ASes and make sure they are sent
- self.queuer.enqueue_event(srv1, srv_1_event)
- self.queuer.enqueue_event(srv1, srv_1_event2)
- self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [])
- self.queuer.enqueue_event(srv2, srv_2_event)
- self.queuer.enqueue_event(srv2, srv_2_event2)
- self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [])
+ self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event])
+ self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2])
+ self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [])
+ self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event])
+ self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2])
+ self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [])
# make sure callbacks for a service only send queued events for THAT
# service
srv_2_defer.callback(srv2)
- self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [])
+ self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [])
self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_large_txns(self):
@@ -262,7 +277,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
srv_2_defer = defer.Deferred()
send_return_list = [srv_1_defer, srv_2_defer]
- def do_send(x, y, z):
+ def do_send(*args, **kwargs):
return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send)
@@ -270,67 +285,65 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
service = Mock(id=4, name="service")
event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)]
for event in event_list:
- self.queuer.enqueue_event(service, event)
+ self.scheduler.enqueue_for_appservice(service, [event], [])
# Expect the first event to be sent immediately.
- self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [])
+ self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [], [])
srv_1_defer.callback(service)
# Then send the next 100 events
- self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [])
+ self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [], [])
srv_2_defer.callback(service)
# Then the final 99 events
- self.txn_ctrl.send.assert_called_with(service, event_list[101:], [])
+ self.txn_ctrl.send.assert_called_with(service, event_list[101:], [], [])
self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_single_ephemeral_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
event_list = [Mock(name="event")]
- self.queuer.enqueue_ephemeral(service, event_list)
- self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
+ self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
+ self.txn_ctrl.send.assert_called_once_with(service, [], event_list, [])
def test_send_multiple_ephemeral_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
- self.queuer.enqueue_ephemeral(service, event_list)
- self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
+ self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
+ self.txn_ctrl.send.assert_called_once_with(service, [], event_list, [])
def test_send_single_ephemeral_with_queue(self):
d = defer.Deferred()
- self.txn_ctrl.send = Mock(
- side_effect=lambda x, y, z: make_deferred_yieldable(d)
- )
+ self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
service = Mock(id=4)
event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")]
event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")]
event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")]
# Send an event and don't resolve it just yet.
- self.queuer.enqueue_ephemeral(service, event_list_1)
+ self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_1)
# Send more events: expect send() to NOT be called multiple times.
- self.queuer.enqueue_ephemeral(service, event_list_2)
- self.queuer.enqueue_ephemeral(service, event_list_3)
- self.txn_ctrl.send.assert_called_with(service, [], event_list_1)
+ self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2)
+ self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3)
+ self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [])
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve txn_ctrl.send
d.callback(service)
# Expect the queued events to be sent
- self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3)
+ self.txn_ctrl.send.assert_called_with(
+ service, [], event_list_2 + event_list_3, []
+ )
self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_send_large_txns_ephemeral(self):
d = defer.Deferred()
- self.txn_ctrl.send = Mock(
- side_effect=lambda x, y, z: make_deferred_yieldable(d)
- )
+ self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)]
second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
event_list = first_chunk + second_chunk
- self.queuer.enqueue_ephemeral(service, event_list)
- self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk)
+ self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
+ self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk, [])
d.callback(service)
- self.txn_ctrl.send.assert_called_with(service, [], second_chunk)
+ self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [])
self.assertEquals(2, self.txn_ctrl.send.call_count)
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index d6f14e2dba..fe57ff2671 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -1,4 +1,4 @@
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,18 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock
from twisted.internet import defer
+import synapse.rest.admin
+import synapse.storage
+from synapse.appservice import ApplicationService
from synapse.handlers.appservice import ApplicationServicesHandler
+from synapse.rest.client import login, receipts, room, sendtodevice
from synapse.types import RoomStreamToken
+from synapse.util.stringutils import random_string
-from tests.test_utils import make_awaitable
+from tests import unittest
+from tests.test_utils import make_awaitable, simple_async_mock
from tests.utils import MockClock
-from .. import unittest
-
class AppServiceHandlerTestCase(unittest.TestCase):
"""Tests the ApplicationServicesHandler."""
@@ -36,6 +41,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
hs.get_datastore.return_value = self.mock_store
self.mock_store.get_received_ts.return_value = make_awaitable(0)
self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
+ self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable(
+ None
+ )
hs.get_application_service_api.return_value = self.mock_as_api
hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock()
@@ -63,8 +71,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
]
self.handler.notify_interested_services(RoomStreamToken(None, 1))
- self.mock_scheduler.submit_event_for_as.assert_called_once_with(
- interested_service, event
+ self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
+ interested_service, events=[event]
)
def test_query_user_exists_unknown_user(self):
@@ -261,7 +269,6 @@ class AppServiceHandlerTestCase(unittest.TestCase):
"""
interested_service = self._mkservice(is_interested=True)
services = [interested_service]
-
self.mock_store.get_app_services.return_value = services
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
579
@@ -275,10 +282,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"]
)
- self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with(
- interested_service, [event]
+ self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
+ interested_service, ephemeral=[event]
)
- self.mock_store.set_type_stream_id_for_appservice.assert_called_once_with(
+ self.mock_store.set_appservice_stream_type_pos.assert_called_once_with(
interested_service,
"read_receipt",
580,
@@ -305,7 +312,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"]
)
- self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called()
+ # This method will be called, but with an empty list of events
+ self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
+ interested_service, ephemeral=[]
+ )
def _mkservice(self, is_interested, protocols=None):
service = Mock()
@@ -321,3 +331,252 @@ class AppServiceHandlerTestCase(unittest.TestCase):
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
+
+
+class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
+ """
+ Tests that the ApplicationServicesHandler sends events to application
+ services correctly.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ room.register_servlets,
+ sendtodevice.register_servlets,
+ receipts.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ # Mock the ApplicationServiceScheduler's _TransactionController's send method so that
+ # we can track any outgoing ephemeral events
+ self.send_mock = simple_async_mock()
+ hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock
+
+ # Mock out application services, and allow defining our own in tests
+ self._services: List[ApplicationService] = []
+ self.hs.get_datastore().get_app_services = Mock(return_value=self._services)
+
+ # A user on the homeserver.
+ self.local_user_device_id = "local_device"
+ self.local_user = self.register_user("local_user", "password")
+ self.local_user_token = self.login(
+ "local_user", "password", self.local_user_device_id
+ )
+
+ # A user on the homeserver which lies within an appservice's exclusive user namespace.
+ self.exclusive_as_user_device_id = "exclusive_as_device"
+ self.exclusive_as_user = self.register_user("exclusive_as_user", "password")
+ self.exclusive_as_user_token = self.login(
+ "exclusive_as_user", "password", self.exclusive_as_user_device_id
+ )
+
+ @unittest.override_config(
+ {"experimental_features": {"msc2409_to_device_messages_enabled": True}}
+ )
+ def test_application_services_receive_local_to_device(self):
+ """
+ Test that when a user sends a to-device message to another user
+ that is an application service's user namespace, the
+ application service will receive it.
+ """
+ interested_appservice = self._register_application_service(
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": "@exclusive_as_user:.+",
+ "exclusive": True,
+ }
+ ],
+ },
+ )
+
+ # Have local_user send a to-device message to exclusive_as_user
+ message_content = {"some_key": "some really interesting value"}
+ chan = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/sendToDevice/m.room_key_request/3",
+ content={
+ "messages": {
+ self.exclusive_as_user: {
+ self.exclusive_as_user_device_id: message_content
+ }
+ }
+ },
+ access_token=self.local_user_token,
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+
+ # Have exclusive_as_user send a to-device message to local_user
+ chan = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/sendToDevice/m.room_key_request/4",
+ content={
+ "messages": {
+ self.local_user: {self.local_user_device_id: message_content}
+ }
+ },
+ access_token=self.exclusive_as_user_token,
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+
+ # Check if our application service - that is interested in exclusive_as_user - received
+ # the to-device message as part of an AS transaction.
+ # Only the local_user -> exclusive_as_user to-device message should have been forwarded to the AS.
+ #
+ # The uninterested application service should not have been notified at all.
+ self.send_mock.assert_called_once()
+ service, _events, _ephemeral, to_device_messages = self.send_mock.call_args[0]
+
+ # Assert that this was the same to-device message that local_user sent
+ self.assertEqual(service, interested_appservice)
+ self.assertEqual(to_device_messages[0]["type"], "m.room_key_request")
+ self.assertEqual(to_device_messages[0]["sender"], self.local_user)
+
+ # Additional fields 'to_user_id' and 'to_device_id' specifically for
+ # to-device messages via the AS API
+ self.assertEqual(to_device_messages[0]["to_user_id"], self.exclusive_as_user)
+ self.assertEqual(
+ to_device_messages[0]["to_device_id"], self.exclusive_as_user_device_id
+ )
+ self.assertEqual(to_device_messages[0]["content"], message_content)
+
+ @unittest.override_config(
+ {"experimental_features": {"msc2409_to_device_messages_enabled": True}}
+ )
+ def test_application_services_receive_bursts_of_to_device(self):
+ """
+ Test that when a user sends >100 to-device messages at once, any
+ interested AS's will receive them in separate transactions.
+
+ Also tests that uninterested application services do not receive messages.
+ """
+ # Register two application services with exclusive interest in a user
+ interested_appservices = []
+ for _ in range(2):
+ appservice = self._register_application_service(
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": "@exclusive_as_user:.+",
+ "exclusive": True,
+ }
+ ],
+ },
+ )
+ interested_appservices.append(appservice)
+
+ # ...and an application service which does not have any user interest.
+ self._register_application_service()
+
+ to_device_message_content = {
+ "some key": "some interesting value",
+ }
+
+ # We need to send a large burst of to-device messages. We also would like to
+ # include them all in the same application service transaction so that we can
+ # test large transactions.
+ #
+ # To do this, we can send a single to-device message to many user devices at
+ # once.
+ #
+ # We insert number_of_messages - 1 messages into the database directly. We'll then
+ # send a final to-device message to the real device, which will also kick off
+ # an AS transaction (as just inserting messages into the DB won't).
+ number_of_messages = 150
+ fake_device_ids = [f"device_{num}" for num in range(number_of_messages - 1)]
+ messages = {
+ self.exclusive_as_user: {
+ device_id: to_device_message_content for device_id in fake_device_ids
+ }
+ }
+
+ # Create a fake device per message. We can't send to-device messages to
+ # a device that doesn't exist.
+ self.get_success(
+ self.hs.get_datastore().db_pool.simple_insert_many(
+ desc="test_application_services_receive_burst_of_to_device",
+ table="devices",
+ keys=("user_id", "device_id"),
+ values=[
+ (
+ self.exclusive_as_user,
+ device_id,
+ )
+ for device_id in fake_device_ids
+ ],
+ )
+ )
+
+ # Seed the device_inbox table with our fake messages
+ self.get_success(
+ self.hs.get_datastore().add_messages_to_device_inbox(messages, {})
+ )
+
+ # Now have local_user send a final to-device message to exclusive_as_user. All unsent
+ # to-device messages should be sent to any application services
+ # interested in exclusive_as_user.
+ chan = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/sendToDevice/m.room_key_request/4",
+ content={
+ "messages": {
+ self.exclusive_as_user: {
+ self.exclusive_as_user_device_id: to_device_message_content
+ }
+ }
+ },
+ access_token=self.local_user_token,
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+
+ self.send_mock.assert_called()
+
+ # Count the total number of to-device messages that were sent out per-service.
+ # Ensure that we only sent to-device messages to interested services, and that
+ # each interested service received the full count of to-device messages.
+ service_id_to_message_count: Dict[str, int] = {}
+
+ for call in self.send_mock.call_args_list:
+ service, _events, _ephemeral, to_device_messages = call[0]
+
+ # Check that this was made to an interested service
+ self.assertIn(service, interested_appservices)
+
+ # Add to the count of messages for this application service
+ service_id_to_message_count.setdefault(service.id, 0)
+ service_id_to_message_count[service.id] += len(to_device_messages)
+
+ # Assert that each interested service received the full count of messages
+ for count in service_id_to_message_count.values():
+ self.assertEqual(count, number_of_messages)
+
+ def _register_application_service(
+ self,
+ namespaces: Optional[Dict[str, Iterable[Dict]]] = None,
+ ) -> ApplicationService:
+ """
+ Register a new application service, with the given namespaces of interest.
+
+ Args:
+ namespaces: A dictionary containing any user, room or alias namespaces that
+ the application service is interested in.
+
+ Returns:
+ The registered application service.
+ """
+ # Create an application service
+ appservice = ApplicationService(
+ token=random_string(10),
+ hostname="example.com",
+ id=random_string(10),
+ sender="@as:example.com",
+ rate_limited=False,
+ namespaces=namespaces,
+ supports_ephemeral=True,
+ )
+
+ # Register the application service
+ self._services.append(appservice)
+
+ return appservice
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
new file mode 100644
index 0000000000..01096a1581
--- /dev/null
+++ b/tests/handlers/test_deactivate_account.py
@@ -0,0 +1,325 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from http import HTTPStatus
+from typing import Any, Dict
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import AccountDataTypes
+from synapse.push.rulekinds import PRIORITY_CLASS_MAP
+from synapse.rest import admin
+from synapse.rest.client import account, login
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+
+class DeactivateAccountTestCase(HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ admin.register_servlets,
+ account.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self._store = hs.get_datastore()
+
+ self.user = self.register_user("user", "pass")
+ self.token = self.login("user", "pass")
+
+ def _deactivate_my_account(self):
+ """
+ Deactivates the account `self.user` using `self.token` and asserts
+ that it returns a 200 success code.
+ """
+ req = self.get_success(
+ self.make_request(
+ "POST",
+ "account/deactivate",
+ {
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user,
+ "password": "pass",
+ },
+ "erase": True,
+ },
+ access_token=self.token,
+ )
+ )
+ self.assertEqual(req.code, HTTPStatus.OK, req)
+
+ def test_global_account_data_deleted_upon_deactivation(self) -> None:
+ """
+ Tests that global account data is removed upon deactivation.
+ """
+ # Add some account data
+ self.get_success(
+ self._store.add_account_data_for_user(
+ self.user,
+ AccountDataTypes.DIRECT,
+ {"@someone:remote": ["!somewhere:remote"]},
+ )
+ )
+
+ # Check that we actually added some.
+ self.assertIsNotNone(
+ self.get_success(
+ self._store.get_global_account_data_by_type_for_user(
+ self.user, AccountDataTypes.DIRECT
+ )
+ ),
+ )
+
+ # Request the deactivation of our account
+ self._deactivate_my_account()
+
+ # Check that the account data does not persist.
+ self.assertIsNone(
+ self.get_success(
+ self._store.get_global_account_data_by_type_for_user(
+ self.user, AccountDataTypes.DIRECT
+ )
+ ),
+ )
+
+ def test_room_account_data_deleted_upon_deactivation(self) -> None:
+ """
+ Tests that room account data is removed upon deactivation.
+ """
+ room_id = "!room:test"
+
+ # Add some room account data
+ self.get_success(
+ self._store.add_account_data_to_room(
+ self.user,
+ room_id,
+ "m.fully_read",
+ {"event_id": "$aaaa:test"},
+ )
+ )
+
+ # Check that we actually added some.
+ self.assertIsNotNone(
+ self.get_success(
+ self._store.get_account_data_for_room_and_type(
+ self.user, room_id, "m.fully_read"
+ )
+ ),
+ )
+
+ # Request the deactivation of our account
+ self._deactivate_my_account()
+
+ # Check that the account data does not persist.
+ self.assertIsNone(
+ self.get_success(
+ self._store.get_account_data_for_room_and_type(
+ self.user, room_id, "m.fully_read"
+ )
+ ),
+ )
+
+ def _is_custom_rule(self, push_rule: Dict[str, Any]) -> bool:
+ """
+ Default rules start with a dot: such as .m.rule and .im.vector.
+ This function returns true iff a rule is custom (not default).
+ """
+ return "/." not in push_rule["rule_id"]
+
+ def test_push_rules_deleted_upon_account_deactivation(self) -> None:
+ """
+ Push rules are a special case of account data.
+ They are stored separately but get sent to the client as account data in /sync.
+ This tests that deactivating a user deletes push rules along with the rest
+ of their account data.
+ """
+
+ # Add a push rule
+ self.get_success(
+ self._store.add_push_rule(
+ self.user,
+ "personal.override.rule1",
+ PRIORITY_CLASS_MAP["override"],
+ [],
+ [],
+ )
+ )
+
+ # Test the rule exists
+ push_rules = self.get_success(self._store.get_push_rules_for_user(self.user))
+ # Filter out default rules; we don't care
+ push_rules = list(filter(self._is_custom_rule, push_rules))
+ # Check our rule made it
+ self.assertEqual(
+ push_rules,
+ [
+ {
+ "user_name": "@user:test",
+ "rule_id": "personal.override.rule1",
+ "priority_class": 5,
+ "priority": 0,
+ "conditions": [],
+ "actions": [],
+ "default": False,
+ }
+ ],
+ push_rules,
+ )
+
+ # Request the deactivation of our account
+ self._deactivate_my_account()
+
+ push_rules = self.get_success(self._store.get_push_rules_for_user(self.user))
+ # Filter out default rules; we don't care
+ push_rules = list(filter(self._is_custom_rule, push_rules))
+ # Check our rule no longer exists
+ self.assertEqual(push_rules, [], push_rules)
+
+ def test_ignored_users_deleted_upon_deactivation(self) -> None:
+ """
+ Ignored users are a special case of account data.
+ They get denormalised into the `ignored_users` table upon being stored as
+ account data.
+ Test that a user's list of ignored users is deleted upon deactivation.
+ """
+
+ # Add an ignored user
+ self.get_success(
+ self._store.add_account_data_for_user(
+ self.user,
+ AccountDataTypes.IGNORED_USER_LIST,
+ {"ignored_users": {"@sheltie:test": {}}},
+ )
+ )
+
+ # Test the user is ignored
+ self.assertEqual(
+ self.get_success(self._store.ignored_by("@sheltie:test")), {self.user}
+ )
+
+ # Request the deactivation of our account
+ self._deactivate_my_account()
+
+ # Test the user is no longer ignored by the user that was deactivated
+ self.assertEqual(
+ self.get_success(self._store.ignored_by("@sheltie:test")), set()
+ )
+
+ def _rerun_retroactive_account_data_deletion_update(self) -> None:
+ # Reset the 'all done' flag
+ self._store.db_pool.updates._all_done = False
+
+ self.get_success(
+ self._store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": "delete_account_data_for_deactivated_users",
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ self.wait_for_background_updates()
+
+ def test_account_data_deleted_retroactively_by_background_update_if_deactivated(
+ self,
+ ) -> None:
+ """
+ Tests that a user, who deactivated their account before account data was
+ deleted automatically upon deactivation, has their account data retroactively
+ scrubbed by the background update.
+ """
+
+ # Request the deactivation of our account
+ self._deactivate_my_account()
+
+ # Add some account data
+ # (we do this after the deactivation so that the act of deactivating doesn't
+ # clear it out. This emulates a user that was deactivated before this was cleared
+ # upon deactivation.)
+ self.get_success(
+ self._store.add_account_data_for_user(
+ self.user,
+ AccountDataTypes.DIRECT,
+ {"@someone:remote": ["!somewhere:remote"]},
+ )
+ )
+
+ # Check that the account data is there.
+ self.assertIsNotNone(
+ self.get_success(
+ self._store.get_global_account_data_by_type_for_user(
+ self.user,
+ AccountDataTypes.DIRECT,
+ )
+ ),
+ )
+
+ # Re-run the retroactive deletion update
+ self._rerun_retroactive_account_data_deletion_update()
+
+ # Check that the account data was cleared.
+ self.assertIsNone(
+ self.get_success(
+ self._store.get_global_account_data_by_type_for_user(
+ self.user,
+ AccountDataTypes.DIRECT,
+ )
+ ),
+ )
+
+ def test_account_data_preserved_by_background_update_if_not_deactivated(
+ self,
+ ) -> None:
+ """
+ Tests that the background update does not scrub account data for users that have
+ not been deactivated.
+ """
+
+ # Add some account data
+ # (we do this after the deactivation so that the act of deactivating doesn't
+ # clear it out. This emulates a user that was deactivated before this was cleared
+ # upon deactivation.)
+ self.get_success(
+ self._store.add_account_data_for_user(
+ self.user,
+ AccountDataTypes.DIRECT,
+ {"@someone:remote": ["!somewhere:remote"]},
+ )
+ )
+
+ # Check that the account data is there.
+ self.assertIsNotNone(
+ self.get_success(
+ self._store.get_global_account_data_by_type_for_user(
+ self.user,
+ AccountDataTypes.DIRECT,
+ )
+ ),
+ )
+
+ # Re-run the retroactive deletion update
+ self._rerun_retroactive_account_data_deletion_update()
+
+ # Check that the account data was NOT cleared.
+ self.assertIsNotNone(
+ self.get_success(
+ self._store.get_global_account_data_by_type_for_user(
+ self.user,
+ AccountDataTypes.DIRECT,
+ )
+ ),
+ )
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 2add72b28a..94809cb8be 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -20,10 +20,11 @@ from unittest.mock import Mock
from twisted.internet import defer
import synapse
+from synapse.api.constants import LoginType
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login, logout
-from synapse.types import JsonDict
+from synapse.rest.client import devices, login, logout, register
+from synapse.types import JsonDict, UserID
from tests import unittest
from tests.server import FakeChannel
@@ -156,6 +157,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
login.register_servlets,
devices.register_servlets,
logout.register_servlets,
+ register.register_servlets,
]
def setUp(self):
@@ -745,6 +747,79 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
on_logged_out.assert_called_once()
self.assertTrue(self.called)
+ def test_username(self):
+ """Tests that the get_username_for_registration callback can define the username
+ of a user when registering.
+ """
+ self._setup_get_username_for_registration()
+
+ username = "rin"
+ channel = self.make_request(
+ "POST",
+ "/register",
+ {
+ "username": username,
+ "password": "bar",
+ "auth": {"type": LoginType.DUMMY},
+ },
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Our callback takes the username and appends "-foo" to it, check that's what we
+ # have.
+ mxid = channel.json_body["user_id"]
+ self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
+
+ def test_username_uia(self):
+ """Tests that the get_username_for_registration callback is only called at the
+ end of the UIA flow.
+ """
+ m = self._setup_get_username_for_registration()
+
+ # Initiate the UIA flow.
+ username = "rin"
+ channel = self.make_request(
+ "POST",
+ "register",
+ {"username": username, "type": "m.login.password", "password": "bar"},
+ )
+ self.assertEqual(channel.code, 401)
+ self.assertIn("session", channel.json_body)
+
+ # Check that the callback hasn't been called yet.
+ m.assert_not_called()
+
+ # Finish the UIA flow.
+ session = channel.json_body["session"]
+ channel = self.make_request(
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": LoginType.DUMMY}},
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ mxid = channel.json_body["user_id"]
+ self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
+
+ # Check that the callback has been called.
+ m.assert_called_once()
+
+ def _setup_get_username_for_registration(self) -> Mock:
+ """Registers a get_username_for_registration callback that appends "-foo" to the
+ username the client is trying to register.
+ """
+
+ async def get_username_for_registration(uia_results, params):
+ self.assertIn(LoginType.DUMMY, uia_results)
+ username = params["username"]
+ return username + "-foo"
+
+ m = Mock(side_effect=get_username_for_registration)
+
+ password_auth_provider = self.hs.get_password_auth_provider()
+ password_auth_provider.get_username_for_registration_callbacks.append(m)
+
+ return m
+
def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index c153018fd8..60235e5699 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from typing import Any, Dict
from unittest.mock import Mock
import synapse.types
from synapse.api.errors import AuthError, SynapseError
from synapse.rest import admin
+from synapse.server import HomeServer
from synapse.types import UserID
from tests import unittest
@@ -46,7 +47,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor, clock, hs: HomeServer):
self.store = hs.get_datastore()
self.frank = UserID.from_string("@1234abcd:test")
@@ -248,3 +249,92 @@ class ProfileTestCase(unittest.HomeserverTestCase):
),
SynapseError,
)
+
+ def test_avatar_constraints_no_config(self):
+ """Tests that the method to check an avatar against configured constraints skips
+ all of its check if no constraint is configured.
+ """
+ # The first check that's done by this method is whether the file exists; if we
+ # don't get an error on a non-existing file then it means all of the checks were
+ # successfully skipped.
+ res = self.get_success(
+ self.handler.check_avatar_size_and_mime_type("mxc://test/unknown_file")
+ )
+ self.assertTrue(res)
+
+ @unittest.override_config({"max_avatar_size": 50})
+ def test_avatar_constraints_missing(self):
+ """Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
+ be found.
+ """
+ res = self.get_success(
+ self.handler.check_avatar_size_and_mime_type("mxc://test/unknown_file")
+ )
+ self.assertFalse(res)
+
+ @unittest.override_config({"max_avatar_size": 50})
+ def test_avatar_constraints_file_size(self):
+ """Tests that a file that's above the allowed file size is forbidden but one
+ that's below it is allowed.
+ """
+ self._setup_local_files(
+ {
+ "small": {"size": 40},
+ "big": {"size": 60},
+ }
+ )
+
+ res = self.get_success(
+ self.handler.check_avatar_size_and_mime_type("mxc://test/small")
+ )
+ self.assertTrue(res)
+
+ res = self.get_success(
+ self.handler.check_avatar_size_and_mime_type("mxc://test/big")
+ )
+ self.assertFalse(res)
+
+ @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
+ def test_avatar_constraint_mime_type(self):
+ """Tests that a file with an unauthorised MIME type is forbidden but one with
+ an authorised content type is allowed.
+ """
+ self._setup_local_files(
+ {
+ "good": {"mimetype": "image/png"},
+ "bad": {"mimetype": "application/octet-stream"},
+ }
+ )
+
+ res = self.get_success(
+ self.handler.check_avatar_size_and_mime_type("mxc://test/good")
+ )
+ self.assertTrue(res)
+
+ res = self.get_success(
+ self.handler.check_avatar_size_and_mime_type("mxc://test/bad")
+ )
+ self.assertFalse(res)
+
+ def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]):
+ """Stores metadata about files in the database.
+
+ Args:
+ names_and_props: A dictionary with one entry per file, with the key being the
+ file's name, and the value being a dictionary of properties. Supported
+ properties are "mimetype" (for the file's type) and "size" (for the
+ file's size).
+ """
+ store = self.hs.get_datastore()
+
+ for name, props in names_and_props.items():
+ self.get_success(
+ store.store_local_media(
+ media_id=name,
+ media_type=props.get("mimetype", "image/png"),
+ time_now_ms=self.clock.time_msec(),
+ upload_name=None,
+ media_length=props.get("size", 50),
+ user_id=UserID.from_string("@rin:test"),
+ )
+ )
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 70c621b825..482c90ef68 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -169,7 +169,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Register an AS user.
user = self.register_user("user", "pass")
token = self.login(user, "pass")
- as_user = self.register_appservice_user("as_user_potato", self.appservice.token)
+ as_user, _ = self.register_appservice_user(
+ "as_user_potato", self.appservice.token
+ )
# Join the AS user to rooms owned by the normal user.
public, private = self._create_rooms_and_inject_memberships(
@@ -388,7 +390,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_handle_local_profile_change_with_appservice_user(self) -> None:
# create user
- as_user_id = self.register_appservice_user(
+ as_user_id, _ = self.register_appservice_user(
"as_user_alice", self.appservice.token
)
diff --git a/tests/http/test_webclient.py b/tests/http/test_webclient.py
deleted file mode 100644
index ee5cf299f6..0000000000
--- a/tests/http/test_webclient.py
+++ /dev/null
@@ -1,108 +0,0 @@
-# Copyright 2022 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from http import HTTPStatus
-from typing import Dict
-
-from twisted.web.resource import Resource
-
-from synapse.app.homeserver import SynapseHomeServer
-from synapse.config.server import HttpListenerConfig, HttpResourceConfig, ListenerConfig
-from synapse.http.site import SynapseSite
-
-from tests.server import make_request
-from tests.unittest import HomeserverTestCase, create_resource_tree, override_config
-
-
-class WebClientTests(HomeserverTestCase):
- @override_config(
- {
- "web_client_location": "https://example.org",
- }
- )
- def test_webclient_resolves_with_client_resource(self):
- """
- Tests that both client and webclient resources can be accessed simultaneously.
-
- This is a regression test created in response to https://github.com/matrix-org/synapse/issues/11763.
- """
- for resource_name_order_list in [
- ["webclient", "client"],
- ["client", "webclient"],
- ]:
- # Create a dictionary from path regex -> resource
- resource_dict: Dict[str, Resource] = {}
-
- for resource_name in resource_name_order_list:
- resource_dict.update(
- SynapseHomeServer._configure_named_resource(self.hs, resource_name)
- )
-
- # Create a root resource which ties the above resources together into one
- root_resource = Resource()
- create_resource_tree(resource_dict, root_resource)
-
- # Create a site configured with this resource to make HTTP requests against
- listener_config = ListenerConfig(
- port=8008,
- bind_addresses=["127.0.0.1"],
- type="http",
- http_options=HttpListenerConfig(
- resources=[HttpResourceConfig(names=resource_name_order_list)]
- ),
- )
- test_site = SynapseSite(
- logger_name="synapse.access.http.fake",
- site_tag=self.hs.config.server.server_name,
- config=listener_config,
- resource=root_resource,
- server_version_string="1",
- max_request_body_size=1234,
- reactor=self.reactor,
- )
-
- # Attempt to make requests to endpoints on both the webclient and client resources
- # on test_site.
- self._request_client_and_webclient_resources(test_site)
-
- def _request_client_and_webclient_resources(self, test_site: SynapseSite) -> None:
- """Make a request to an endpoint on both the webclient and client-server resources
- of the given SynapseSite.
-
- Args:
- test_site: The SynapseSite object to make requests against.
- """
-
- # Ensure that the *webclient* resource is behaving as expected (we get redirected to
- # the configured web_client_location)
- channel = make_request(
- self.reactor,
- site=test_site,
- method="GET",
- path="/_matrix/client",
- )
- # Check that we are being redirected to the webclient location URI.
- self.assertEqual(channel.code, HTTPStatus.FOUND)
- self.assertEqual(
- channel.headers.getRawHeaders("Location"), ["https://example.org"]
- )
-
- # Ensure that a request to the *client* resource works.
- channel = make_request(
- self.reactor,
- site=test_site,
- method="GET",
- path="/_matrix/client/v3/login",
- )
- self.assertEqual(channel.code, HTTPStatus.OK)
- self.assertIn("flows", channel.json_body)
diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
new file mode 100644
index 0000000000..e430941d27
--- /dev/null
+++ b/tests/logging/test_opentracing.py
@@ -0,0 +1,184 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactorClock
+
+from synapse.logging.context import (
+ LoggingContext,
+ make_deferred_yieldable,
+ run_in_background,
+)
+from synapse.logging.opentracing import (
+ start_active_span,
+ start_active_span_follows_from,
+)
+from synapse.util import Clock
+
+try:
+ from synapse.logging.scopecontextmanager import LogContextScopeManager
+except ImportError:
+ LogContextScopeManager = None # type: ignore
+
+try:
+ import jaeger_client
+except ImportError:
+ jaeger_client = None # type: ignore
+
+from tests.unittest import TestCase
+
+
+class LogContextScopeManagerTestCase(TestCase):
+ if LogContextScopeManager is None:
+ skip = "Requires opentracing" # type: ignore[unreachable]
+ if jaeger_client is None:
+ skip = "Requires jaeger_client" # type: ignore[unreachable]
+
+ def setUp(self) -> None:
+ # since this is a unit test, we don't really want to mess around with the
+ # global variables that power opentracing. We create our own tracer instance
+ # and test with it.
+
+ scope_manager = LogContextScopeManager({})
+ config = jaeger_client.config.Config(
+ config={}, service_name="test", scope_manager=scope_manager
+ )
+
+ self._reporter = jaeger_client.reporter.InMemoryReporter()
+
+ self._tracer = config.create_tracer(
+ sampler=jaeger_client.ConstSampler(True),
+ reporter=self._reporter,
+ )
+
+ def test_start_active_span(self) -> None:
+ # the scope manager assumes a logging context of some sort.
+ with LoggingContext("root context"):
+ self.assertIsNone(self._tracer.active_span)
+
+ # start_active_span should start and activate a span.
+ scope = start_active_span("span", tracer=self._tracer)
+ span = scope.span
+ self.assertEqual(self._tracer.active_span, span)
+ self.assertIsNotNone(span.start_time)
+
+ # entering the context doesn't actually do a whole lot.
+ with scope as ctx:
+ self.assertIs(ctx, scope)
+ self.assertEqual(self._tracer.active_span, span)
+
+ # ... but leaving it unsets the active span, and finishes the span.
+ self.assertIsNone(self._tracer.active_span)
+ self.assertIsNotNone(span.end_time)
+
+ # the span should have been reported
+ self.assertEqual(self._reporter.get_spans(), [span])
+
+ def test_nested_spans(self) -> None:
+ """Starting two spans off inside each other should work"""
+
+ with LoggingContext("root context"):
+ with start_active_span("root span", tracer=self._tracer) as root_scope:
+ self.assertEqual(self._tracer.active_span, root_scope.span)
+
+ scope1 = start_active_span(
+ "child1",
+ tracer=self._tracer,
+ )
+ self.assertEqual(
+ self._tracer.active_span, scope1.span, "child1 was not activated"
+ )
+ self.assertEqual(
+ scope1.span.context.parent_id, root_scope.span.context.span_id
+ )
+
+ scope2 = start_active_span_follows_from(
+ "child2",
+ contexts=(scope1,),
+ tracer=self._tracer,
+ )
+ self.assertEqual(self._tracer.active_span, scope2.span)
+ self.assertEqual(
+ scope2.span.context.parent_id, scope1.span.context.span_id
+ )
+
+ with scope1, scope2:
+ pass
+
+ # the root scope should be restored
+ self.assertEqual(self._tracer.active_span, root_scope.span)
+ self.assertIsNotNone(scope2.span.end_time)
+ self.assertIsNotNone(scope1.span.end_time)
+
+ self.assertIsNone(self._tracer.active_span)
+
+ # the spans should be reported in order of their finishing.
+ self.assertEqual(
+ self._reporter.get_spans(), [scope2.span, scope1.span, root_scope.span]
+ )
+
+ def test_overlapping_spans(self) -> None:
+ """Overlapping spans which are not neatly nested should work"""
+ reactor = MemoryReactorClock()
+ clock = Clock(reactor)
+
+ scopes = []
+
+ async def task(i: int):
+ scope = start_active_span(
+ f"task{i}",
+ tracer=self._tracer,
+ )
+ scopes.append(scope)
+
+ self.assertEqual(self._tracer.active_span, scope.span)
+ await clock.sleep(4)
+ self.assertEqual(self._tracer.active_span, scope.span)
+ scope.close()
+
+ async def root():
+ with start_active_span("root span", tracer=self._tracer) as root_scope:
+ self.assertEqual(self._tracer.active_span, root_scope.span)
+ scopes.append(root_scope)
+
+ d1 = run_in_background(task, 1)
+ await clock.sleep(2)
+ d2 = run_in_background(task, 2)
+
+ # because we did run_in_background, the active span should still be the
+ # root.
+ self.assertEqual(self._tracer.active_span, root_scope.span)
+
+ await make_deferred_yieldable(
+ defer.gatherResults([d1, d2], consumeErrors=True)
+ )
+
+ self.assertEqual(self._tracer.active_span, root_scope.span)
+
+ with LoggingContext("root context"):
+ # start the test off
+ d1 = defer.ensureDeferred(root())
+
+ # let the tasks complete
+ reactor.pump((2,) * 8)
+
+ self.successResultOf(d1)
+ self.assertIsNone(self._tracer.active_span)
+
+ # the spans should be reported in order of their finishing: task 1, task 2,
+ # root.
+ self.assertEqual(
+ self._reporter.get_spans(),
+ [scopes[1].span, scopes[2].span, scopes[0].span],
+ )
diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py
index 43e3248703..1524087c43 100644
--- a/tests/replication/slave/storage/test_account_data.py
+++ b/tests/replication/slave/storage/test_account_data.py
@@ -30,7 +30,7 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
)
self.replicate()
self.check(
- "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1}
+ "get_global_account_data_by_type_for_user", [USER_ID, TYPE], {"a": 1}
)
self.get_success(
@@ -38,5 +38,5 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
)
self.replicate()
self.check(
- "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2}
+ "get_global_account_data_by_type_for_user", [USER_ID, TYPE], {"a": 2}
)
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 3adadcb46b..849d00ab4d 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -12,18 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import urllib.parse
from http import HTTPStatus
-from unittest.mock import Mock
+from typing import List
-from twisted.internet.defer import Deferred
+from parameterized import parameterized
+
+from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.http.server import JsonResource
-from synapse.logging.context import make_deferred_yieldable
from synapse.rest.admin import VersionServlet
from synapse.rest.client import groups, login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
from tests.server import FakeSite, make_request
@@ -33,12 +35,12 @@ from tests.test_utils import SMALL_PNG
class VersionTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/server_version"
- def create_test_resource(self):
+ def create_test_resource(self) -> JsonResource:
resource = JsonResource(self.hs)
VersionServlet(self.hs).register(resource)
return resource
- def test_version_string(self):
+ def test_version_string(self) -> None:
channel = self.make_request("GET", self.url, shorthand=False)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
@@ -54,14 +56,14 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
groups.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
self.other_user_token = self.login("user", "pass")
- def test_delete_group(self):
+ def test_delete_group(self) -> None:
# Create a new group
channel = self.make_request(
"POST",
@@ -112,7 +114,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
self.assertNotIn(group_id, self._get_groups_user_is_in(self.other_user_token))
- def _check_group(self, group_id, expect_code):
+ def _check_group(self, group_id: str, expect_code: int) -> None:
"""Assert that trying to fetch the given group results in the given
HTTP status code
"""
@@ -124,7 +126,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
self.assertEqual(expect_code, channel.code, msg=channel.json_body)
- def _get_groups_user_is_in(self, access_token):
+ def _get_groups_user_is_in(self, access_token: str) -> List[str]:
"""Returns the list of groups the user is in (given their access token)"""
channel = self.make_request("GET", b"/joined_groups", access_token=access_token)
@@ -143,59 +145,15 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Allow for uploading and downloading to/from the media repo
self.media_repo = hs.get_media_repository_resource()
self.download_resource = self.media_repo.children[b"download"]
self.upload_resource = self.media_repo.children[b"upload"]
- def make_homeserver(self, reactor, clock):
-
- self.fetches = []
-
- async def get_file(destination, path, output_stream, args=None, max_size=None):
- """
- Returns tuple[int,dict,str,int] of file length, response headers,
- absolute URI, and response code.
- """
-
- def write_to(r):
- data, response = r
- output_stream.write(data)
- return response
-
- d = Deferred()
- d.addCallback(write_to)
- self.fetches.append((d, destination, path, args))
- return await make_deferred_yieldable(d)
-
- client = Mock()
- client.get_file = get_file
-
- self.storage_path = self.mktemp()
- self.media_store_path = self.mktemp()
- os.mkdir(self.storage_path)
- os.mkdir(self.media_store_path)
-
- config = self.default_config()
- config["media_store_path"] = self.media_store_path
- config["thumbnail_requirements"] = {}
- config["max_image_pixels"] = 2000000
-
- provider_config = {
- "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
- "store_local": True,
- "store_synchronous": False,
- "store_remote": True,
- "config": {"directory": self.storage_path},
- }
- config["media_storage_providers"] = [provider_config]
-
- hs = self.setup_test_homeserver(config=config, federation_http_client=client)
-
- return hs
-
- def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
+ def _ensure_quarantined(
+ self, admin_user_tok: str, server_and_media_id: str
+ ) -> None:
"""Ensure a piece of media is quarantined when trying to access it."""
channel = make_request(
self.reactor,
@@ -216,12 +174,18 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
),
)
- def test_quarantine_media_requires_admin(self):
+ @parameterized.expand(
+ [
+ # Attempt quarantine media APIs as non-admin
+ "/_synapse/admin/v1/media/quarantine/example.org/abcde12345",
+ # And the roomID/userID endpoint
+ "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine",
+ ]
+ )
+ def test_quarantine_media_requires_admin(self, url: str) -> None:
self.register_user("nonadmin", "pass", admin=False)
non_admin_user_tok = self.login("nonadmin", "pass")
- # Attempt quarantine media APIs as non-admin
- url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
channel = self.make_request(
"POST",
url.encode("ascii"),
@@ -235,22 +199,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
msg="Expected forbidden on quarantining media as a non-admin",
)
- # And the roomID/userID endpoint
- url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
- channel = self.make_request(
- "POST",
- url.encode("ascii"),
- access_token=non_admin_user_tok,
- )
-
- # Expect a forbidden error
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg="Expected forbidden on quarantining media as a non-admin",
- )
-
- def test_quarantine_media_by_id(self):
+ def test_quarantine_media_by_id(self) -> None:
self.register_user("id_admin", "pass", admin=True)
admin_user_tok = self.login("id_admin", "pass")
@@ -295,7 +244,15 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt to access the media
self._ensure_quarantined(admin_user_tok, server_name_and_media_id)
- def test_quarantine_all_media_in_room(self, override_url_template=None):
+ @parameterized.expand(
+ [
+ # regular API path
+ "/_synapse/admin/v1/room/%s/media/quarantine",
+ # deprecated API path
+ "/_synapse/admin/v1/quarantine_media/%s",
+ ]
+ )
+ def test_quarantine_all_media_in_room(self, url: str) -> None:
self.register_user("room_admin", "pass", admin=True)
admin_user_tok = self.login("room_admin", "pass")
@@ -333,16 +290,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
tok=non_admin_user_tok,
)
- # Quarantine all media in the room
- if override_url_template:
- url = override_url_template % urllib.parse.quote(room_id)
- else:
- url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
- room_id
- )
channel = self.make_request(
"POST",
- url,
+ url % urllib.parse.quote(room_id),
access_token=admin_user_tok,
)
self.pump(1.0)
@@ -359,11 +309,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
- def test_quarantine_all_media_in_room_deprecated_api_path(self):
- # Perform the above test with the deprecated API path
- self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s")
-
- def test_quarantine_all_media_by_user(self):
+ def test_quarantine_all_media_by_user(self) -> None:
self.register_user("user_admin", "pass", admin=True)
admin_user_tok = self.login("user_admin", "pass")
@@ -401,7 +347,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
- def test_cannot_quarantine_safe_media(self):
+ def test_cannot_quarantine_safe_media(self) -> None:
self.register_user("user_admin", "pass", admin=True)
admin_user_tok = self.login("user_admin", "pass")
@@ -475,7 +421,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -488,7 +434,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
self.url = f"/_synapse/admin/v1/purge_history/{self.room_id}"
self.url_status = "/_synapse/admin/v1/purge_history_status/"
- def test_purge_history(self):
+ def test_purge_history(self) -> None:
"""
Simple test of purge history API.
Test only that is is possible to call, get status HTTPStatus.OK and purge_id.
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index b70350b6f1..71068d16cd 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -20,7 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes
-from synapse.rest.client import login
+from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -43,20 +43,22 @@ class FederationTestCase(unittest.HomeserverTestCase):
@parameterized.expand(
[
- ("/_synapse/admin/v1/federation/destinations",),
- ("/_synapse/admin/v1/federation/destinations/dummy",),
+ ("GET", "/_synapse/admin/v1/federation/destinations"),
+ ("GET", "/_synapse/admin/v1/federation/destinations/dummy"),
+ (
+ "POST",
+ "/_synapse/admin/v1/federation/destinations/dummy/reset_connection",
+ ),
]
)
- def test_requester_is_no_admin(self, url: str) -> None:
- """
- If the user is not a server admin, an error 403 is returned.
- """
+ def test_requester_is_no_admin(self, method: str, url: str) -> None:
+ """If the user is not a server admin, an error 403 is returned."""
self.register_user("user", "pass", admin=False)
other_user_tok = self.login("user", "pass")
channel = self.make_request(
- "GET",
+ method,
url,
content={},
access_token=other_user_tok,
@@ -66,9 +68,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
- """
- If parameters are invalid, an error is returned.
- """
+ """If parameters are invalid, an error is returned."""
# negative limit
channel = self.make_request(
@@ -120,10 +120,18 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+ # invalid destination
+ channel = self.make_request(
+ "POST",
+ self.url + "/dummy/reset_connection",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
def test_limit(self) -> None:
- """
- Testing list of destinations with limit
- """
+ """Testing list of destinations with limit"""
number_destinations = 20
self._create_destinations(number_destinations)
@@ -141,9 +149,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self._check_fields(channel.json_body["destinations"])
def test_from(self) -> None:
- """
- Testing list of destinations with a defined starting point (from)
- """
+ """Testing list of destinations with a defined starting point (from)"""
number_destinations = 20
self._create_destinations(number_destinations)
@@ -161,9 +167,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self._check_fields(channel.json_body["destinations"])
def test_limit_and_from(self) -> None:
- """
- Testing list of destinations with a defined starting point and limit
- """
+ """Testing list of destinations with a defined starting point and limit"""
number_destinations = 20
self._create_destinations(number_destinations)
@@ -181,9 +185,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self._check_fields(channel.json_body["destinations"])
def test_next_token(self) -> None:
- """
- Testing that `next_token` appears at the right place
- """
+ """Testing that `next_token` appears at the right place"""
number_destinations = 20
self._create_destinations(number_destinations)
@@ -242,9 +244,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertNotIn("next_token", channel.json_body)
def test_list_all_destinations(self) -> None:
- """
- List all destinations.
- """
+ """List all destinations."""
number_destinations = 5
self._create_destinations(number_destinations)
@@ -263,9 +263,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self._check_fields(channel.json_body["destinations"])
def test_order_by(self) -> None:
- """
- Testing order list with parameter `order_by`
- """
+ """Testing order list with parameter `order_by`"""
def _order_test(
expected_destination_list: List[str],
@@ -444,6 +442,39 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertIsNone(channel.json_body["failure_ts"])
self.assertIsNone(channel.json_body["last_successful_stream_ordering"])
+ def test_destination_reset_connection(self) -> None:
+ """Reset timeouts and wake up destination."""
+ self._create_destination("sub0.example.com", 100, 100, 100)
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/sub0.example.com/reset_connection",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+ retry_timings = self.get_success(
+ self.store.get_destination_retry_timings("sub0.example.com")
+ )
+ self.assertIsNone(retry_timings)
+
+ def test_destination_reset_connection_not_required(self) -> None:
+ """Try to reset timeouts of a destination with no timeouts and get an error."""
+ self._create_destination("sub0.example.com", None, 0, 0)
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/sub0.example.com/reset_connection",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ "The retry timing does not need to be reset for this destination.",
+ channel.json_body["error"],
+ )
+
def _create_destination(
self,
destination: str,
@@ -496,3 +527,271 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertIn("retry_interval", c)
self.assertIn("failure_ts", c)
self.assertIn("last_successful_stream_ordering", c)
+
+
+class DestinationMembershipTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastore()
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.dest = "sub0.example.com"
+ self.url = f"/_synapse/admin/v1/federation/destinations/{self.dest}/rooms"
+
+ # Record that we successfully contacted a destination in the DB.
+ self.get_success(
+ self.store.set_destination_retry_timings(self.dest, None, 0, 0)
+ )
+
+ def test_requester_is_no_admin(self) -> None:
+ """If the user is not a server admin, an error 403 is returned."""
+
+ self.register_user("user", "pass", admin=False)
+ other_user_tok = self.login("user", "pass")
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=other_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_invalid_parameter(self) -> None:
+ """If parameters are invalid, an error is returned."""
+
+ # negative limit
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=-5",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # negative from
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=-5",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # invalid search order
+ channel = self.make_request(
+ "GET",
+ self.url + "?dir=bar",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # invalid destination
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/federation/destinations/%s/rooms" % ("invalid",),
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_limit(self) -> None:
+ """Testing list of destinations with limit"""
+
+ number_rooms = 5
+ self._create_destination_rooms(number_rooms)
+
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=3",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(channel.json_body["total"], number_rooms)
+ self.assertEqual(len(channel.json_body["rooms"]), 3)
+ self.assertEqual(channel.json_body["next_token"], "3")
+ self._check_fields(channel.json_body["rooms"])
+
+ def test_from(self) -> None:
+ """Testing list of rooms with a defined starting point (from)"""
+
+ number_rooms = 10
+ self._create_destination_rooms(number_rooms)
+
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=5",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(channel.json_body["total"], number_rooms)
+ self.assertEqual(len(channel.json_body["rooms"]), 5)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["rooms"])
+
+ def test_limit_and_from(self) -> None:
+ """Testing list of rooms with a defined starting point and limit"""
+
+ number_rooms = 10
+ self._create_destination_rooms(number_rooms)
+
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=3&limit=5",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(channel.json_body["total"], number_rooms)
+ self.assertEqual(channel.json_body["next_token"], "8")
+ self.assertEqual(len(channel.json_body["rooms"]), 5)
+ self._check_fields(channel.json_body["rooms"])
+
+ def test_order_direction(self) -> None:
+ """Testing order list with parameter `dir`"""
+ number_rooms = 4
+ self._create_destination_rooms(number_rooms)
+
+ # get list in forward direction
+ channel_asc = self.make_request(
+ "GET",
+ self.url + "?dir=f",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel_asc.code, msg=channel_asc.json_body)
+ self.assertEqual(channel_asc.json_body["total"], number_rooms)
+ self.assertEqual(number_rooms, len(channel_asc.json_body["rooms"]))
+ self._check_fields(channel_asc.json_body["rooms"])
+
+ # get list in backward direction
+ channel_desc = self.make_request(
+ "GET",
+ self.url + "?dir=b",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel_desc.code, msg=channel_desc.json_body)
+ self.assertEqual(channel_desc.json_body["total"], number_rooms)
+ self.assertEqual(number_rooms, len(channel_desc.json_body["rooms"]))
+ self._check_fields(channel_desc.json_body["rooms"])
+
+ # test that both lists have different directions
+ for i in range(0, number_rooms):
+ self.assertEqual(
+ channel_asc.json_body["rooms"][i]["room_id"],
+ channel_desc.json_body["rooms"][number_rooms - 1 - i]["room_id"],
+ )
+
+ def test_next_token(self) -> None:
+ """Testing that `next_token` appears at the right place"""
+
+ number_rooms = 5
+ self._create_destination_rooms(number_rooms)
+
+ # `next_token` does not appear
+ # Number of results is the number of entries
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=5",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(channel.json_body["total"], number_rooms)
+ self.assertEqual(len(channel.json_body["rooms"]), number_rooms)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does not appear
+ # Number of max results is larger than the number of entries
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=6",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(channel.json_body["total"], number_rooms)
+ self.assertEqual(len(channel.json_body["rooms"]), number_rooms)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does appear
+ # Number of max results is smaller than the number of entries
+ channel = self.make_request(
+ "GET",
+ self.url + "?limit=4",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(channel.json_body["total"], number_rooms)
+ self.assertEqual(len(channel.json_body["rooms"]), 4)
+ self.assertEqual(channel.json_body["next_token"], "4")
+
+ # Check
+ # Set `from` to value of `next_token` for request remaining entries
+ # `next_token` does not appear
+ channel = self.make_request(
+ "GET",
+ self.url + "?from=4",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(channel.json_body["total"], number_rooms)
+ self.assertEqual(len(channel.json_body["rooms"]), 1)
+ self.assertNotIn("next_token", channel.json_body)
+
+ def test_destination_rooms(self) -> None:
+ """Testing that request the list of rooms is successfully."""
+ number_rooms = 3
+ self._create_destination_rooms(number_rooms)
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(channel.json_body["total"], number_rooms)
+ self.assertEqual(number_rooms, len(channel.json_body["rooms"]))
+ self._check_fields(channel.json_body["rooms"])
+
+ def _create_destination_rooms(self, number_rooms: int) -> None:
+ """Create a number rooms for destination
+
+ Args:
+ number_rooms: Number of rooms to be created
+ """
+ for _ in range(0, number_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ self.get_success(
+ self.store.store_destination_rooms_entries((self.dest,), room_id, 1234)
+ )
+
+ def _check_fields(self, content: List[JsonDict]) -> None:
+ """Checks that the expected room attributes are present in content
+
+ Args:
+ content: List that is checked for content
+ """
+ for c in content:
+ self.assertIn("room_id", c)
+ self.assertIn("stream_ordering", c)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 3495a0366a..23da0ad736 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -2468,7 +2468,6 @@ PURGE_TABLES = [
"event_search",
"events",
"group_rooms",
- "public_room_list_stream",
"receipts_graph",
"receipts_linearized",
"room_aliases",
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 9711405735..272637e965 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -23,13 +23,17 @@ from unittest.mock import Mock, patch
from parameterized import parameterized, parameterized_class
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.rest.client import devices, login, logout, profile, room, sync
from synapse.rest.media.v1.filepath import MediaFilePaths
+from synapse.server import HomeServer
from synapse.types import JsonDict, UserID
+from synapse.util import Clock
from tests import unittest
from tests.server import FakeSite, make_request
@@ -44,7 +48,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
profile.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.url = "/_synapse/admin/v1/register"
@@ -61,12 +65,12 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.hs.config.registration.registration_shared_secret = "shared"
- self.hs.get_media_repository = Mock()
- self.hs.get_deactivate_account_handler = Mock()
+ self.hs.get_media_repository = Mock() # type: ignore[assignment]
+ self.hs.get_deactivate_account_handler = Mock() # type: ignore[assignment]
return self.hs
- def test_disabled(self):
+ def test_disabled(self) -> None:
"""
If there is no shared secret, registration through this method will be
prevented.
@@ -80,7 +84,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"Shared secret registration is not enabled", channel.json_body["error"]
)
- def test_get_nonce(self):
+ def test_get_nonce(self) -> None:
"""
Calling GET on the endpoint will return a randomised nonce, using the
homeserver's secrets provider.
@@ -93,7 +97,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body, {"nonce": "abcd"})
- def test_expired_nonce(self):
+ def test_expired_nonce(self) -> None:
"""
Calling GET on the endpoint will return a randomised nonce, which will
only last for SALT_TIMEOUT (60s).
@@ -118,7 +122,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
- def test_register_incorrect_nonce(self):
+ def test_register_incorrect_nonce(self) -> None:
"""
Only the provided nonce can be used, as it's checked in the MAC.
"""
@@ -141,7 +145,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual("HMAC incorrect", channel.json_body["error"])
- def test_register_correct_nonce(self):
+ def test_register_correct_nonce(self) -> None:
"""
When the correct nonce is provided, and the right key is provided, the
user is registered.
@@ -168,7 +172,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
- def test_nonce_reuse(self):
+ def test_nonce_reuse(self) -> None:
"""
A valid unrecognised nonce.
"""
@@ -197,14 +201,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
- def test_missing_parts(self):
+ def test_missing_parts(self) -> None:
"""
Synapse will complain if you don't give nonce, username, password, and
mac. Admin and user_types are optional. Additional checks are done for length
and type.
"""
- def nonce():
+ def nonce() -> str:
channel = self.make_request("GET", self.url)
return channel.json_body["nonce"]
@@ -297,7 +301,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
- def test_displayname(self):
+ def test_displayname(self) -> None:
"""
Test that displayname of new user is set
"""
@@ -400,7 +404,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
)
- def test_register_mau_limit_reached(self):
+ def test_register_mau_limit_reached(self) -> None:
"""
Check we can register a user via the shared secret registration API
even if the MAU limit is reached.
@@ -450,13 +454,13 @@ class UsersListTestCase(unittest.HomeserverTestCase):
]
url = "/_synapse/admin/v2/users"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- def test_no_auth(self):
+ def test_no_auth(self) -> None:
"""
Try to list users without authentication.
"""
@@ -465,7 +469,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self):
+ def test_requester_is_no_admin(self) -> None:
"""
If the user is not a server admin, an error is returned.
"""
@@ -477,7 +481,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_all_users(self):
+ def test_all_users(self) -> None:
"""
List all users, including deactivated users.
"""
@@ -497,7 +501,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# Check that all fields are available
self._check_fields(channel.json_body["users"])
- def test_search_term(self):
+ def test_search_term(self) -> None:
"""Test that searching for a users works correctly"""
def _search_test(
@@ -505,7 +509,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
search_term: str,
search_field: Optional[str] = "name",
expected_http_code: Optional[int] = HTTPStatus.OK,
- ):
+ ) -> None:
"""Search for a user and check that the returned user's id is a match
Args:
@@ -575,7 +579,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo", "user_id")
_search_test(None, "bar", "user_id")
- def test_invalid_parameter(self):
+ def test_invalid_parameter(self) -> None:
"""
If parameters are invalid, an error is returned.
"""
@@ -640,7 +644,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
- def test_limit(self):
+ def test_limit(self) -> None:
"""
Testing list of users with limit
"""
@@ -661,7 +665,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["next_token"], "5")
self._check_fields(channel.json_body["users"])
- def test_from(self):
+ def test_from(self) -> None:
"""
Testing list of users with a defined starting point (from)
"""
@@ -682,7 +686,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["users"])
- def test_limit_and_from(self):
+ def test_limit_and_from(self) -> None:
"""
Testing list of users with a defined starting point and limit
"""
@@ -703,7 +707,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["users"]), 10)
self._check_fields(channel.json_body["users"])
- def test_next_token(self):
+ def test_next_token(self) -> None:
"""
Testing that `next_token` appears at the right place
"""
@@ -765,7 +769,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
- def test_order_by(self):
+ def test_order_by(self) -> None:
"""
Testing order list with parameter `order_by`
"""
@@ -843,7 +847,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
expected_user_list: List[str],
order_by: Optional[str],
dir: Optional[str] = None,
- ):
+ ) -> None:
"""Request the list of users in a certain order. Assert that order is what
we expect
Args:
@@ -870,7 +874,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(expected_user_list, returned_order)
self._check_fields(channel.json_body["users"])
- def _check_fields(self, content: List[JsonDict]):
+ def _check_fields(self, content: List[JsonDict]) -> None:
"""Checks that the expected user attributes are present in content
Args:
content: List that is checked for content
@@ -886,7 +890,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertIn("avatar_url", u)
self.assertIn("creation_ts", u)
- def _create_users(self, number_users: int):
+ def _create_users(self, number_users: int) -> None:
"""
Create a number of users
Args:
@@ -908,7 +912,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -931,7 +935,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
)
- def test_no_auth(self):
+ def test_no_auth(self) -> None:
"""
Try to deactivate users without authentication.
"""
@@ -940,7 +944,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_not_admin(self):
+ def test_requester_is_not_admin(self) -> None:
"""
If the user is not a server admin, an error is returned.
"""
@@ -961,7 +965,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
- def test_user_does_not_exist(self):
+ def test_user_does_not_exist(self) -> None:
"""
Tests that deactivation for a user that does not exist returns a HTTPStatus.NOT_FOUND
"""
@@ -975,7 +979,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- def test_erase_is_not_bool(self):
+ def test_erase_is_not_bool(self) -> None:
"""
If parameter `erase` is not boolean, return an error
"""
@@ -990,7 +994,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
- def test_user_is_not_local(self):
+ def test_user_is_not_local(self) -> None:
"""
Tests that deactivation for a user that is not a local returns a HTTPStatus.BAD_REQUEST
"""
@@ -1001,7 +1005,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only deactivate local users", channel.json_body["error"])
- def test_deactivate_user_erase_true(self):
+ def test_deactivate_user_erase_true(self) -> None:
"""
Test deactivating a user and set `erase` to `true`
"""
@@ -1046,7 +1050,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self._is_erased("@user:test", True)
- def test_deactivate_user_erase_false(self):
+ def test_deactivate_user_erase_false(self) -> None:
"""
Test deactivating a user and set `erase` to `false`
"""
@@ -1091,7 +1095,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self._is_erased("@user:test", False)
- def test_deactivate_user_erase_true_no_profile(self):
+ def test_deactivate_user_erase_true_no_profile(self) -> None:
"""
Test deactivating a user and set `erase` to `true`
if user has no profile information (stored in the database table `profiles`).
@@ -1162,7 +1166,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
self.auth_handler = hs.get_auth_handler()
@@ -1185,7 +1189,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.url_prefix = "/_synapse/admin/v2/users/%s"
self.url_other_user = self.url_prefix % self.other_user
- def test_requester_is_no_admin(self):
+ def test_requester_is_no_admin(self) -> None:
"""
If the user is not a server admin, an error is returned.
"""
@@ -1210,7 +1214,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
- def test_user_does_not_exist(self):
+ def test_user_does_not_exist(self) -> None:
"""
Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
"""
@@ -1224,7 +1228,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
- def test_invalid_parameter(self):
+ def test_invalid_parameter(self) -> None:
"""
If parameters are invalid, an error is returned.
"""
@@ -1319,7 +1323,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
- def test_get_user(self):
+ def test_get_user(self) -> None:
"""
Test a simple get of a user.
"""
@@ -1334,7 +1338,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("User", channel.json_body["displayname"])
self._check_fields(channel.json_body)
- def test_create_server_admin(self):
+ def test_create_server_admin(self) -> None:
"""
Check that a new admin user is created successfully.
"""
@@ -1383,7 +1387,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
self._check_fields(channel.json_body)
- def test_create_user(self):
+ def test_create_user(self) -> None:
"""
Check that a new regular user is created successfully.
"""
@@ -1450,7 +1454,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
)
- def test_create_user_mau_limit_reached_active_admin(self):
+ def test_create_user_mau_limit_reached_active_admin(self) -> None:
"""
Check that an admin can register a new user via the admin API
even if the MAU limit is reached.
@@ -1496,7 +1500,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
)
- def test_create_user_mau_limit_reached_passive_admin(self):
+ def test_create_user_mau_limit_reached_passive_admin(self) -> None:
"""
Check that an admin can register a new user via the admin API
even if the MAU limit is reached.
@@ -1541,7 +1545,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"public_baseurl": "https://example.com",
}
)
- def test_create_user_email_notif_for_new_users(self):
+ def test_create_user_email_notif_for_new_users(self) -> None:
"""
Check that a new regular user is created successfully and
got an email pusher.
@@ -1584,7 +1588,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"public_baseurl": "https://example.com",
}
)
- def test_create_user_email_no_notif_for_new_users(self):
+ def test_create_user_email_no_notif_for_new_users(self) -> None:
"""
Check that a new regular user is created successfully and
got not an email pusher.
@@ -1615,7 +1619,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
pushers = list(pushers)
self.assertEqual(len(pushers), 0)
- def test_set_password(self):
+ def test_set_password(self) -> None:
"""
Test setting a new password for another user.
"""
@@ -1631,7 +1635,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self._check_fields(channel.json_body)
- def test_set_displayname(self):
+ def test_set_displayname(self) -> None:
"""
Test setting the displayname of another user.
"""
@@ -1659,7 +1663,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
- def test_set_threepid(self):
+ def test_set_threepid(self) -> None:
"""
Test setting threepid for an other user.
"""
@@ -1740,7 +1744,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)
- def test_set_duplicate_threepid(self):
+ def test_set_duplicate_threepid(self) -> None:
"""
Test setting the same threepid for a second user.
First user loses and second user gets mapping of this threepid.
@@ -1827,7 +1831,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)
- def test_set_external_id(self):
+ def test_set_external_id(self) -> None:
"""
Test setting external id for an other user.
"""
@@ -1925,7 +1929,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["external_ids"]))
- def test_set_duplicate_external_id(self):
+ def test_set_duplicate_external_id(self) -> None:
"""
Test that setting the same external id for a second user fails and
external id from user must not be changed.
@@ -2048,7 +2052,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
self._check_fields(channel.json_body)
- def test_deactivate_user(self):
+ def test_deactivate_user(self) -> None:
"""
Test deactivating another user.
"""
@@ -2113,7 +2117,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertNotIn("password_hash", channel.json_body)
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
- def test_change_name_deactivate_user_user_directory(self):
+ def test_change_name_deactivate_user_user_directory(self) -> None:
"""
Test change profile information of a deactivated user and
check that it does not appear in user directory
@@ -2156,7 +2160,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
self.assertIsNone(profile)
- def test_reactivate_user(self):
+ def test_reactivate_user(self) -> None:
"""
Test reactivating another user.
"""
@@ -2189,7 +2193,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertNotIn("password_hash", channel.json_body)
@override_config({"password_config": {"localdb_enabled": False}})
- def test_reactivate_user_localdb_disabled(self):
+ def test_reactivate_user_localdb_disabled(self) -> None:
"""
Test reactivating another user when using SSO.
"""
@@ -2223,7 +2227,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertNotIn("password_hash", channel.json_body)
@override_config({"password_config": {"enabled": False}})
- def test_reactivate_user_password_disabled(self):
+ def test_reactivate_user_password_disabled(self) -> None:
"""
Test reactivating another user when using SSO.
"""
@@ -2256,7 +2260,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# This key was removed intentionally. Ensure it is not accidentally re-included.
self.assertNotIn("password_hash", channel.json_body)
- def test_set_user_as_admin(self):
+ def test_set_user_as_admin(self) -> None:
"""
Test setting the admin flag on a user.
"""
@@ -2284,7 +2288,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
- def test_set_user_type(self):
+ def test_set_user_type(self) -> None:
"""
Test changing user type.
"""
@@ -2335,7 +2339,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"])
- def test_accidental_deactivation_prevention(self):
+ def test_accidental_deactivation_prevention(self) -> None:
"""
Ensure an account can't accidentally be deactivated by using a str value
for the deactivated body parameter
@@ -2418,7 +2422,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# This key was removed intentionally. Ensure it is not accidentally re-included.
self.assertNotIn("password_hash", channel.json_body)
- def _check_fields(self, content: JsonDict):
+ def _check_fields(self, content: JsonDict) -> None:
"""Checks that the expected user attributes are present in content
Args:
@@ -2448,7 +2452,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -2457,7 +2461,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.other_user
)
- def test_no_auth(self):
+ def test_no_auth(self) -> None:
"""
Try to list rooms of an user without authentication.
"""
@@ -2466,7 +2470,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self):
+ def test_requester_is_no_admin(self) -> None:
"""
If the user is not a server admin, an error is returned.
"""
@@ -2481,7 +2485,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_user_does_not_exist(self):
+ def test_user_does_not_exist(self) -> None:
"""
Tests that a lookup for a user that does not exist returns an empty list
"""
@@ -2496,7 +2500,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
- def test_user_is_not_local(self):
+ def test_user_is_not_local(self) -> None:
"""
Tests that a lookup for a user that is not a local and participates in no conversation returns an empty list
"""
@@ -2512,7 +2516,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
- def test_no_memberships(self):
+ def test_no_memberships(self) -> None:
"""
Tests that a normal lookup for rooms is successfully
if user has no memberships
@@ -2528,7 +2532,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
- def test_get_rooms(self):
+ def test_get_rooms(self) -> None:
"""
Tests that a normal lookup for rooms is successfully
"""
@@ -2549,7 +2553,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(number_rooms, channel.json_body["total"])
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
- def test_get_rooms_with_nonlocal_user(self):
+ def test_get_rooms_with_nonlocal_user(self) -> None:
"""
Tests that a normal lookup for rooms is successful with a non-local user
"""
@@ -2604,7 +2608,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -2615,7 +2619,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
self.other_user
)
- def test_no_auth(self):
+ def test_no_auth(self) -> None:
"""
Try to list pushers of an user without authentication.
"""
@@ -2624,7 +2628,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self):
+ def test_requester_is_no_admin(self) -> None:
"""
If the user is not a server admin, an error is returned.
"""
@@ -2639,7 +2643,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_user_does_not_exist(self):
+ def test_user_does_not_exist(self) -> None:
"""
Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
"""
@@ -2653,7 +2657,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- def test_user_is_not_local(self):
+ def test_user_is_not_local(self) -> None:
"""
Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
"""
@@ -2668,7 +2672,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
- def test_get_pushers(self):
+ def test_get_pushers(self) -> None:
"""
Tests that a normal lookup for pushers is successfully
"""
@@ -2732,7 +2736,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
self.media_repo = hs.get_media_repository_resource()
self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
@@ -2746,7 +2750,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
)
@parameterized.expand(["GET", "DELETE"])
- def test_no_auth(self, method: str):
+ def test_no_auth(self, method: str) -> None:
"""Try to list media of an user without authentication."""
channel = self.make_request(method, self.url, {})
@@ -2754,7 +2758,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
- def test_requester_is_no_admin(self, method: str):
+ def test_requester_is_no_admin(self, method: str) -> None:
"""If the user is not a server admin, an error is returned."""
other_user_token = self.login("user", "pass")
@@ -2768,7 +2772,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
- def test_user_does_not_exist(self, method: str):
+ def test_user_does_not_exist(self, method: str) -> None:
"""Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND"""
url = "/_synapse/admin/v1/users/@unknown_person:test/media"
channel = self.make_request(
@@ -2781,7 +2785,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
- def test_user_is_not_local(self, method: str):
+ def test_user_is_not_local(self, method: str) -> None:
"""Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
@@ -2794,7 +2798,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
- def test_limit_GET(self):
+ def test_limit_GET(self) -> None:
"""Testing list of media with limit"""
number_media = 20
@@ -2813,7 +2817,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["next_token"], 5)
self._check_fields(channel.json_body["media"])
- def test_limit_DELETE(self):
+ def test_limit_DELETE(self) -> None:
"""Testing delete of media with limit"""
number_media = 20
@@ -2830,7 +2834,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["total"], 5)
self.assertEqual(len(channel.json_body["deleted_media"]), 5)
- def test_from_GET(self):
+ def test_from_GET(self) -> None:
"""Testing list of media with a defined starting point (from)"""
number_media = 20
@@ -2849,7 +2853,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["media"])
- def test_from_DELETE(self):
+ def test_from_DELETE(self) -> None:
"""Testing delete of media with a defined starting point (from)"""
number_media = 20
@@ -2866,7 +2870,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["total"], 15)
self.assertEqual(len(channel.json_body["deleted_media"]), 15)
- def test_limit_and_from_GET(self):
+ def test_limit_and_from_GET(self) -> None:
"""Testing list of media with a defined starting point and limit"""
number_media = 20
@@ -2885,7 +2889,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["media"]), 10)
self._check_fields(channel.json_body["media"])
- def test_limit_and_from_DELETE(self):
+ def test_limit_and_from_DELETE(self) -> None:
"""Testing delete of media with a defined starting point and limit"""
number_media = 20
@@ -2903,7 +2907,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["deleted_media"]), 10)
@parameterized.expand(["GET", "DELETE"])
- def test_invalid_parameter(self, method: str):
+ def test_invalid_parameter(self, method: str) -> None:
"""If parameters are invalid, an error is returned."""
# unkown order_by
channel = self.make_request(
@@ -2945,7 +2949,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
- def test_next_token(self):
+ def test_next_token(self) -> None:
"""
Testing that `next_token` appears at the right place
@@ -3010,7 +3014,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["media"]), 1)
self.assertNotIn("next_token", channel.json_body)
- def test_user_has_no_media_GET(self):
+ def test_user_has_no_media_GET(self) -> None:
"""
Tests that a normal lookup for media is successfully
if user has no media created
@@ -3026,7 +3030,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["media"]))
- def test_user_has_no_media_DELETE(self):
+ def test_user_has_no_media_DELETE(self) -> None:
"""
Tests that a delete is successful if user has no media
"""
@@ -3041,7 +3045,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["deleted_media"]))
- def test_get_media(self):
+ def test_get_media(self) -> None:
"""Tests that a normal lookup for media is successful"""
number_media = 5
@@ -3060,7 +3064,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["media"])
- def test_delete_media(self):
+ def test_delete_media(self) -> None:
"""Tests that a normal delete of media is successful"""
number_media = 5
@@ -3089,7 +3093,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
for local_path in local_paths:
self.assertFalse(os.path.exists(local_path))
- def test_order_by(self):
+ def test_order_by(self) -> None:
"""
Testing order list with parameter `order_by`
"""
@@ -3252,7 +3256,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
return media_id
- def _check_fields(self, content: List[JsonDict]):
+ def _check_fields(self, content: List[JsonDict]) -> None:
"""Checks that the expected user attributes are present in content
Args:
content: List that is checked for content
@@ -3272,7 +3276,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
expected_media_list: List[str],
order_by: Optional[str],
dir: Optional[str] = None,
- ):
+ ) -> None:
"""Request the list of media in a certain order. Assert that order is what
we expect
Args:
@@ -3312,7 +3316,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
logout.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -3331,14 +3335,14 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
return channel.json_body["access_token"]
- def test_no_auth(self):
+ def test_no_auth(self) -> None:
"""Try to login as a user without authentication."""
channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_not_admin(self):
+ def test_not_admin(self) -> None:
"""Try to login as a user as a non-admin user."""
channel = self.make_request(
"POST", self.url, b"{}", access_token=self.other_user_tok
@@ -3346,7 +3350,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
- def test_send_event(self):
+ def test_send_event(self) -> None:
"""Test that sending event as a user works."""
# Create a room.
room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
@@ -3360,7 +3364,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
event = self.get_success(self.store.get_event(event_id))
self.assertEqual(event.sender, self.other_user)
- def test_devices(self):
+ def test_devices(self) -> None:
"""Tests that logging in as a user doesn't create a new device for them."""
# Login in as the user
self._get_token()
@@ -3374,7 +3378,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# We should only see the one device (from the login in `prepare`)
self.assertEqual(len(channel.json_body["devices"]), 1)
- def test_logout(self):
+ def test_logout(self) -> None:
"""Test that calling `/logout` with the token works."""
# Login in as the user
puppet_token = self._get_token()
@@ -3397,7 +3401,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- def test_user_logout_all(self):
+ def test_user_logout_all(self) -> None:
"""Tests that the target user calling `/logout/all` does *not* expire
the token.
"""
@@ -3424,7 +3428,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
- def test_admin_logout_all(self):
+ def test_admin_logout_all(self) -> None:
"""Tests that the admin user calling `/logout/all` does expire the
token.
"""
@@ -3464,7 +3468,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
"form_secret": "123secret",
}
)
- def test_consent(self):
+ def test_consent(self) -> None:
"""Test that sending a message is not subject to the privacy policies."""
# Have the admin user accept the terms.
self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0"))
@@ -3492,7 +3496,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 1, "mau_trial_days": 0}
)
- def test_mau_limit(self):
+ def test_mau_limit(self) -> None:
# Create a room as the admin user. This will bump the monthly active users to 1.
room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
@@ -3524,14 +3528,14 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
- self.url = self.url_prefix % self.other_user
+ self.url = self.url_prefix % self.other_user # type: ignore[attr-defined]
- def test_no_auth(self):
+ def test_no_auth(self) -> None:
"""
Try to get information of an user without authentication.
"""
@@ -3539,7 +3543,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_not_admin(self):
+ def test_requester_is_not_admin(self) -> None:
"""
If the user is not a server admin, an error is returned.
"""
@@ -3554,11 +3558,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_user_is_not_local(self):
+ def test_user_is_not_local(self) -> None:
"""
Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
"""
- url = self.url_prefix % "@unknown_person:unknown_domain"
+ url = self.url_prefix % "@unknown_person:unknown_domain" # type: ignore[attr-defined]
channel = self.make_request(
"GET",
@@ -3568,7 +3572,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only whois a local user", channel.json_body["error"])
- def test_get_whois_admin(self):
+ def test_get_whois_admin(self) -> None:
"""
The lookup should succeed for an admin.
"""
@@ -3581,7 +3585,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
- def test_get_whois_user(self):
+ def test_get_whois_user(self) -> None:
"""
The lookup should succeed for a normal user looking up their own information.
"""
@@ -3604,7 +3608,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -3617,7 +3621,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
)
@parameterized.expand(["POST", "DELETE"])
- def test_no_auth(self, method: str):
+ def test_no_auth(self, method: str) -> None:
"""
Try to get information of an user without authentication.
"""
@@ -3626,7 +3630,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["POST", "DELETE"])
- def test_requester_is_not_admin(self, method: str):
+ def test_requester_is_not_admin(self, method: str) -> None:
"""
If the user is not a server admin, an error is returned.
"""
@@ -3637,7 +3641,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["POST", "DELETE"])
- def test_user_is_not_local(self, method: str):
+ def test_user_is_not_local(self, method: str) -> None:
"""
Tests that shadow-banning for a user that is not a local returns a HTTPStatus.BAD_REQUEST
"""
@@ -3646,7 +3650,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(method, url, access_token=self.admin_user_tok)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
- def test_success(self):
+ def test_success(self) -> None:
"""
Shadow-banning should succeed for an admin.
"""
@@ -3682,7 +3686,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -3695,7 +3699,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
)
@parameterized.expand(["GET", "POST", "DELETE"])
- def test_no_auth(self, method: str):
+ def test_no_auth(self, method: str) -> None:
"""
Try to get information of a user without authentication.
"""
@@ -3705,7 +3709,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "POST", "DELETE"])
- def test_requester_is_no_admin(self, method: str):
+ def test_requester_is_no_admin(self, method: str) -> None:
"""
If the user is not a server admin, an error is returned.
"""
@@ -3721,7 +3725,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "POST", "DELETE"])
- def test_user_does_not_exist(self, method: str):
+ def test_user_does_not_exist(self, method: str) -> None:
"""
Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
"""
@@ -3743,7 +3747,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
("DELETE", "Only local users can be ratelimited"),
]
)
- def test_user_is_not_local(self, method: str, error_msg: str):
+ def test_user_is_not_local(self, method: str, error_msg: str) -> None:
"""
Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
"""
@@ -3760,7 +3764,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(error_msg, channel.json_body["error"])
- def test_invalid_parameter(self):
+ def test_invalid_parameter(self) -> None:
"""
If parameters are invalid, an error is returned.
"""
@@ -3808,7 +3812,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
- def test_return_zero_when_null(self):
+ def test_return_zero_when_null(self) -> None:
"""
If values in database are `null` API should return an int `0`
"""
@@ -3834,7 +3838,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["messages_per_second"])
self.assertEqual(0, channel.json_body["burst_count"])
- def test_success(self):
+ def test_success(self) -> None:
"""
Rate-limiting (set/update/delete) should succeed for an admin.
"""
@@ -3908,7 +3912,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs) -> None:
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index 7978626e71..b21f6d4689 100644
--- a/tests/rest/admin/test_username_available.py
+++ b/tests/rest/admin/test_username_available.py
@@ -14,9 +14,13 @@
from http import HTTPStatus
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.api.errors import Codes, SynapseError
from synapse.rest.client import login
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
@@ -28,11 +32,11 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
]
url = "/_synapse/admin/v1/username_available"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- async def check_username(username):
+ async def check_username(username: str) -> bool:
if username == "allowed":
return True
raise SynapseError(
@@ -44,24 +48,24 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
handler = self.hs.get_registration_handler()
handler.check_username = check_username
- def test_username_available(self):
+ def test_username_available(self) -> None:
"""
The endpoint should return a HTTPStatus.OK response if the username does not exist
"""
url = "%s?username=%s" % (self.url, "allowed")
- channel = self.make_request("GET", url, None, self.admin_user_tok)
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["available"])
- def test_username_unavailable(self):
+ def test_username_unavailable(self) -> None:
"""
The endpoint should return a HTTPStatus.OK response if the username does not exist
"""
url = "%s?username=%s" % (self.url, "disallowed")
- channel = self.make_request("GET", url, None, self.admin_user_tok)
+ channel = self.make_request("GET", url, access_token=self.admin_user_tok)
self.assertEqual(
HTTPStatus.BAD_REQUEST,
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 2860579c2e..ead883ded8 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -13,8 +13,12 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
+from typing import Any, Dict
+
+from synapse.api.errors import Codes
from synapse.rest import admin
from synapse.rest.client import login, profile, room
+from synapse.types import UserID
from tests import unittest
@@ -25,6 +29,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
admin.register_servlets_for_client_rest_resource,
login.register_servlets,
profile.register_servlets,
+ room.register_servlets,
]
def make_homeserver(self, reactor, clock):
@@ -150,6 +155,157 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body.get("avatar_url")
+ @unittest.override_config({"max_avatar_size": 50})
+ def test_avatar_size_limit_global(self):
+ """Tests that the maximum size limit for avatars is enforced when updating a
+ global profile.
+ """
+ self._setup_local_files(
+ {
+ "small": {"size": 40},
+ "big": {"size": 60},
+ }
+ )
+
+ channel = self.make_request(
+ "PUT",
+ f"/profile/{self.owner}/avatar_url",
+ content={"avatar_url": "mxc://test/big"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 403, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
+ )
+
+ channel = self.make_request(
+ "PUT",
+ f"/profile/{self.owner}/avatar_url",
+ content={"avatar_url": "mxc://test/small"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ @unittest.override_config({"max_avatar_size": 50})
+ def test_avatar_size_limit_per_room(self):
+ """Tests that the maximum size limit for avatars is enforced when updating a
+ per-room profile.
+ """
+ self._setup_local_files(
+ {
+ "small": {"size": 40},
+ "big": {"size": 60},
+ }
+ )
+
+ room_id = self.helper.create_room_as(tok=self.owner_tok)
+
+ channel = self.make_request(
+ "PUT",
+ f"/rooms/{room_id}/state/m.room.member/{self.owner}",
+ content={"membership": "join", "avatar_url": "mxc://test/big"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 403, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
+ )
+
+ channel = self.make_request(
+ "PUT",
+ f"/rooms/{room_id}/state/m.room.member/{self.owner}",
+ content={"membership": "join", "avatar_url": "mxc://test/small"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
+ def test_avatar_allowed_mime_type_global(self):
+ """Tests that the MIME type whitelist for avatars is enforced when updating a
+ global profile.
+ """
+ self._setup_local_files(
+ {
+ "good": {"mimetype": "image/png"},
+ "bad": {"mimetype": "application/octet-stream"},
+ }
+ )
+
+ channel = self.make_request(
+ "PUT",
+ f"/profile/{self.owner}/avatar_url",
+ content={"avatar_url": "mxc://test/bad"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 403, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
+ )
+
+ channel = self.make_request(
+ "PUT",
+ f"/profile/{self.owner}/avatar_url",
+ content={"avatar_url": "mxc://test/good"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
+ def test_avatar_allowed_mime_type_per_room(self):
+ """Tests that the MIME type whitelist for avatars is enforced when updating a
+ per-room profile.
+ """
+ self._setup_local_files(
+ {
+ "good": {"mimetype": "image/png"},
+ "bad": {"mimetype": "application/octet-stream"},
+ }
+ )
+
+ room_id = self.helper.create_room_as(tok=self.owner_tok)
+
+ channel = self.make_request(
+ "PUT",
+ f"/rooms/{room_id}/state/m.room.member/{self.owner}",
+ content={"membership": "join", "avatar_url": "mxc://test/bad"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 403, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
+ )
+
+ channel = self.make_request(
+ "PUT",
+ f"/rooms/{room_id}/state/m.room.member/{self.owner}",
+ content={"membership": "join", "avatar_url": "mxc://test/good"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]):
+ """Stores metadata about files in the database.
+
+ Args:
+ names_and_props: A dictionary with one entry per file, with the key being the
+ file's name, and the value being a dictionary of properties. Supported
+ properties are "mimetype" (for the file's type) and "size" (for the
+ file's size).
+ """
+ store = self.hs.get_datastore()
+
+ for name, props in names_and_props.items():
+ self.get_success(
+ store.store_local_media(
+ media_id=name,
+ media_type=props.get("mimetype", "image/png"),
+ time_now_ms=self.clock.time_msec(),
+ upload_name=None,
+ media_length=props.get("size", 50),
+ user_id=UserID.from_string("@rin:test"),
+ )
+ )
+
class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 6e7c0f11df..407dd32a73 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -726,6 +726,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
{"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
)
+ @override_config(
+ {
+ "inhibit_user_in_use_error": True,
+ }
+ )
+ def test_inhibit_user_in_use_error(self):
+ """Tests that the 'inhibit_user_in_use_error' configuration flag behaves
+ correctly.
+ """
+ username = "arthur"
+
+ # Manually register the user, so we know the test isn't passing because of a lack
+ # of clashing.
+ reg_handler = self.hs.get_registration_handler()
+ self.get_success(reg_handler.register_user(username))
+
+ # Check that /available correctly ignores the username provided despite the
+ # username being already registered.
+ channel = self.make_request("GET", "register/available?username=" + username)
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Test that when starting a UIA registration flow the request doesn't fail because
+ # of a conflicting username
+ channel = self.make_request(
+ "POST",
+ "register",
+ {"username": username, "type": "m.login.password", "password": "foo"},
+ )
+ self.assertEqual(channel.code, 401)
+ self.assertIn("session", channel.json_body)
+
+ # Test that finishing the registration fails because of a conflicting username.
+ session = channel.json_body["session"]
+ channel = self.make_request(
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": LoginType.DUMMY}},
+ )
+ self.assertEqual(channel.code, 400, channel.json_body)
+ self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE)
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index c9b220e73d..96ae7790bb 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -577,7 +577,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
self.assertTrue(room_timeline["limited"])
- self._find_event_in_chunk(room_timeline["events"])
+ assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
def test_aggregation_get_event_for_annotation(self):
"""Test that annotations do not get bundled aggregations included
diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py
index 721454c187..e9f8704035 100644
--- a/tests/rest/client/test_room_batch.py
+++ b/tests/rest/client/test_room_batch.py
@@ -89,7 +89,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
self.clock = clock
self.storage = hs.get_storage()
- self.virtual_user_id = self.register_appservice_user(
+ self.virtual_user_id, _ = self.register_appservice_user(
"as_user_potato", self.appservice.token
)
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 8424383580..1c0cb0cf4f 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -31,6 +31,7 @@ from typing import (
overload,
)
from unittest.mock import patch
+from urllib.parse import urlencode
import attr
from typing_extensions import Literal
@@ -147,12 +148,20 @@ class RestHelper:
expect_code=expect_code,
)
- def join(self, room=None, user=None, expect_code=200, tok=None):
+ def join(
+ self,
+ room: str,
+ user: Optional[str] = None,
+ expect_code: int = 200,
+ tok: Optional[str] = None,
+ appservice_user_id: Optional[str] = None,
+ ) -> None:
self.change_membership(
room=room,
src=user,
targ=user,
tok=tok,
+ appservice_user_id=appservice_user_id,
membership=Membership.JOIN,
expect_code=expect_code,
)
@@ -209,11 +218,12 @@ class RestHelper:
def change_membership(
self,
room: str,
- src: str,
- targ: str,
+ src: Optional[str],
+ targ: Optional[str],
membership: str,
extra_data: Optional[dict] = None,
tok: Optional[str] = None,
+ appservice_user_id: Optional[str] = None,
expect_code: int = 200,
expect_errcode: Optional[str] = None,
) -> None:
@@ -227,15 +237,26 @@ class RestHelper:
membership: The type of membership event
extra_data: Extra information to include in the content of the event
tok: The user access token to use
+ appservice_user_id: The `user_id` URL parameter to pass.
+ This allows driving an application service user
+ using an application service access token in `tok`.
expect_code: The expected HTTP response code
expect_errcode: The expected Matrix error code
"""
temp_id = self.auth_user_id
self.auth_user_id = src
- path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ)
+ path = f"/_matrix/client/r0/rooms/{room}/state/m.room.member/{targ}"
+ url_params: Dict[str, str] = {}
+
if tok:
- path = path + "?access_token=%s" % tok
+ url_params["access_token"] = tok
+
+ if appservice_user_id:
+ url_params["user_id"] = appservice_user_id
+
+ if url_params:
+ path += "?" + urlencode(url_params)
data = {"membership": membership}
data.update(extra_data or {})
diff --git a/tests/test_preview.py b/tests/rest/media/v1/test_html_preview.py
index 46e02f483f..a4b57e3d1f 100644
--- a/tests/test_preview.py
+++ b/tests/rest/media/v1/test_html_preview.py
@@ -16,10 +16,11 @@ from synapse.rest.media.v1.preview_html import (
_get_html_media_encodings,
decode_body,
parse_html_to_open_graph,
+ rebase_url,
summarize_paragraphs,
)
-from . import unittest
+from tests import unittest
try:
import lxml
@@ -447,3 +448,34 @@ class MediaEncodingTestCase(unittest.TestCase):
'text/html; charset="invalid"',
)
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
+
+
+class RebaseUrlTestCase(unittest.TestCase):
+ def test_relative(self):
+ """Relative URLs should be resolved based on the context of the base URL."""
+ self.assertEqual(
+ rebase_url("subpage", "https://example.com/foo/"),
+ "https://example.com/foo/subpage",
+ )
+ self.assertEqual(
+ rebase_url("sibling", "https://example.com/foo"),
+ "https://example.com/sibling",
+ )
+ self.assertEqual(
+ rebase_url("/bar", "https://example.com/foo/"),
+ "https://example.com/bar",
+ )
+
+ def test_absolute(self):
+ """Absolute URLs should not be modified."""
+ self.assertEqual(
+ rebase_url("https://alice.com/a/", "https://example.com/foo/"),
+ "https://alice.com/a/",
+ )
+
+ def test_data(self):
+ """Data URLs should not be modified."""
+ self.assertEqual(
+ rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"),
+ "data:,Hello%2C%20World%21",
+ )
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 16e904f15b..53f6186213 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -12,9 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import base64
import json
import os
import re
+from urllib.parse import urlencode
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
@@ -23,6 +25,7 @@ from twisted.test.proto_helpers import AccumulatingProtocol
from synapse.config.oembed import OEmbedEndpointConfig
from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
+from synapse.types import JsonDict
from synapse.util.stringutils import parse_and_validate_mxc_uri
from tests import unittest
@@ -142,6 +145,14 @@ class URLPreviewTests(unittest.HomeserverTestCase):
def create_test_resource(self):
return self.hs.get_media_repository_resource()
+ def _assert_small_png(self, json_body: JsonDict) -> None:
+ """Assert properties from the SMALL_PNG test image."""
+ self.assertTrue(json_body["og:image"].startswith("mxc://"))
+ self.assertEqual(json_body["og:image:height"], 1)
+ self.assertEqual(json_body["og:image:width"], 1)
+ self.assertEqual(json_body["og:image:type"], "image/png")
+ self.assertEqual(json_body["matrix:image:size"], 67)
+
def test_cache_returns_correct_type(self):
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
@@ -569,6 +580,66 @@ class URLPreviewTests(unittest.HomeserverTestCase):
server.data,
)
+ def test_data_url(self):
+ """
+ Requesting to preview a data URL is not supported.
+ """
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ data = base64.b64encode(SMALL_PNG).decode()
+
+ query_params = urlencode(
+ {
+ "url": f'<html><head><img src="data:image/png;base64,{data}" /></head></html>'
+ }
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"preview_url?{query_params}",
+ shorthand=False,
+ )
+ self.pump()
+
+ self.assertEqual(channel.code, 500)
+
+ def test_inline_data_url(self):
+ """
+ An inline image (as a data URL) should be parsed properly.
+ """
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ data = base64.b64encode(SMALL_PNG)
+
+ end_content = (
+ b"<html><head>" b'<img src="data:image/png;base64,%s" />' b"</head></html>"
+ ) % (data,)
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self._assert_small_png(channel.json_body)
+
def test_oembed_photo(self):
"""Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
@@ -626,10 +697,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
body = channel.json_body
self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345")
- self.assertTrue(body["og:image"].startswith("mxc://"))
- self.assertEqual(body["og:image:height"], 1)
- self.assertEqual(body["og:image:width"], 1)
- self.assertEqual(body["og:image:type"], "image/png")
+ self._assert_small_png(body)
def test_oembed_rich(self):
"""Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
@@ -820,10 +888,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(
body["og:url"], "http://www.twitter.com/matrixdotorg/status/12345"
)
- self.assertTrue(body["og:image"].startswith("mxc://"))
- self.assertEqual(body["og:image:height"], 1)
- self.assertEqual(body["og:image:width"], 1)
- self.assertEqual(body["og:image:type"], "image/png")
+ self._assert_small_png(body)
def _download_image(self):
"""Downloads an image into the URL cache.
diff --git a/tests/server.py b/tests/server.py
index a0cd14ea45..82990c2eb9 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -313,7 +313,7 @@ def make_request(
req = request(channel, site)
req.content = BytesIO(content)
# Twisted expects to be at the end of the content when parsing the request.
- req.content.seek(SEEK_END)
+ req.content.seek(0, SEEK_END)
if access_token:
req.requestHeaders.addRawHeader(
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 329490caad..ddcb7f5549 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -266,7 +266,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
service = Mock(id=self.as_list[0]["id"])
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
txn = self.get_success(
- defer.ensureDeferred(self.store.create_appservice_txn(service, events, []))
+ defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events, [], [])
+ )
)
self.assertEquals(txn.id, 1)
self.assertEquals(txn.events, events)
@@ -280,7 +282,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind
self.get_success(self._insert_txn(service.id, 9644, events))
self.get_success(self._insert_txn(service.id, 9645, events))
- txn = self.get_success(self.store.create_appservice_txn(service, events, []))
+ txn = self.get_success(
+ self.store.create_appservice_txn(service, events, [], [])
+ )
self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -291,7 +295,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
service = Mock(id=self.as_list[0]["id"])
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
self.get_success(self._set_last_txn(service.id, 9643))
- txn = self.get_success(self.store.create_appservice_txn(service, events, []))
+ txn = self.get_success(
+ self.store.create_appservice_txn(service, events, [], [])
+ )
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -313,7 +319,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events))
self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
- txn = self.get_success(self.store.create_appservice_txn(service, events, []))
+ txn = self.get_success(
+ self.store.create_appservice_txn(service, events, [], [])
+ )
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@@ -481,10 +489,10 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
ValueError,
)
- def test_set_type_stream_id_for_appservice(self) -> None:
+ def test_set_appservice_stream_type_pos(self) -> None:
read_receipt_value = 1024
self.get_success(
- self.store.set_type_stream_id_for_appservice(
+ self.store.set_appservice_stream_type_pos(
self.service, "read_receipt", read_receipt_value
)
)
@@ -494,7 +502,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
self.assertEqual(result, read_receipt_value)
self.get_success(
- self.store.set_type_stream_id_for_appservice(
+ self.store.set_appservice_stream_type_pos(
self.service, "presence", read_receipt_value
)
)
@@ -503,9 +511,9 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
)
self.assertEqual(result, read_receipt_value)
- def test_set_type_stream_id_for_appservice_invalid_type(self) -> None:
+ def test_set_appservice_stream_type_pos_invalid_type(self) -> None:
self.get_failure(
- self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024),
+ self.store.set_appservice_stream_type_pos(self.service, "foobar", 1024),
ValueError,
)
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 7b7f6c349e..e3273a93f9 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -19,6 +19,7 @@ from twisted.trial import unittest
from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.storage.databases.main.events import _LinkMap
@@ -391,7 +392,9 @@ class EventChainStoreTestCase(HomeserverTestCase):
def _persist(txn):
# We need to persist the events to the events and state_events
# tables.
- persist_events_store._store_event_txn(txn, [(e, {}) for e in events])
+ persist_events_store._store_event_txn(
+ txn, [(e, EventContext()) for e in events]
+ )
# Actually call the function that calculates the auth chain stuff.
persist_events_store._persist_event_auth_chain_txn(txn, events)
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 7f5b28aed8..48f1e9d841 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -341,7 +341,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
# Register an AS user.
user = self.register_user("user", "pass")
token = self.login(user, "pass")
- as_user = self.register_appservice_user("as_user_potato", self.appservice.token)
+ as_user, _ = self.register_appservice_user(
+ "as_user_potato", self.appservice.token
+ )
# Join the AS user to rooms owned by the normal user.
public, private = self._create_rooms_and_inject_memberships(
diff --git a/tests/unittest.py b/tests/unittest.py
index 1431848367..6fc617601a 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -620,18 +620,19 @@ class HomeserverTestCase(TestCase):
self,
username: str,
appservice_token: str,
- ) -> str:
+ ) -> Tuple[str, str]:
"""Register an appservice user as an application service.
Requires the client-facing registration API be registered.
Args:
username: the user to be registered by an application service.
- Should be a full username, i.e. ""@localpart:hostname" as opposed to just "localpart"
+ Should NOT be a full username, i.e. just "localpart" as opposed to "@localpart:hostname"
appservice_token: the acccess token for that application service.
Raises: if the request to '/register' does not return 200 OK.
- Returns: the MXID of the new user.
+ Returns:
+ The MXID of the new user, the device ID of the new user's first device.
"""
channel = self.make_request(
"POST",
@@ -643,7 +644,7 @@ class HomeserverTestCase(TestCase):
access_token=appservice_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
- return channel.json_body["user_id"]
+ return channel.json_body["user_id"], channel.json_body["device_id"]
def login(
self,
|