diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 9bd6275e92..edc584d0cf 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -36,7 +36,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
hostname="matrix.org", # only used by get_groups_for_user
)
self.event = Mock(
- type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
+ event_id="$abc:xyz",
+ type="m.something",
+ room_id="!foo:bar",
+ sender="@someone:somewhere",
)
self.store = Mock()
@@ -50,7 +53,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -62,7 +67,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertFalse(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -76,7 +83,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -90,7 +99,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -104,7 +115,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertFalse(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -121,7 +134,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -174,7 +189,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertFalse(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -191,7 +208,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -207,7 +226,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -225,7 +246,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
- self.service.is_interested(event=self.event, store=self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
diff --git a/tests/config/test_background_update.py b/tests/config/test_background_update.py
new file mode 100644
index 0000000000..0c32c1ca29
--- /dev/null
+++ b/tests/config/test_background_update.py
@@ -0,0 +1,58 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import yaml
+
+from synapse.storage.background_updates import BackgroundUpdater
+
+from tests.unittest import HomeserverTestCase, override_config
+
+
+class BackgroundUpdateConfigTestCase(HomeserverTestCase):
+ # Tests that the default values in the config are correctly loaded. Note that the default
+ # values are loaded when the corresponding config options are commented out, which is why there isn't
+ # a config specified here.
+ def test_default_configuration(self):
+ background_updater = BackgroundUpdater(
+ self.hs, self.hs.get_datastores().main.db_pool
+ )
+
+ self.assertEqual(background_updater.minimum_background_batch_size, 1)
+ self.assertEqual(background_updater.default_background_batch_size, 100)
+ self.assertEqual(background_updater.sleep_enabled, True)
+ self.assertEqual(background_updater.sleep_duration_ms, 1000)
+ self.assertEqual(background_updater.update_duration_ms, 100)
+
+ # Tests that non-default values for the config options are properly picked up and passed on.
+ @override_config(
+ yaml.safe_load(
+ """
+ background_updates:
+ background_update_duration_ms: 1000
+ sleep_enabled: false
+ sleep_duration_ms: 600
+ min_batch_size: 5
+ default_batch_size: 50
+ """
+ )
+ )
+ def test_custom_configuration(self):
+ background_updater = BackgroundUpdater(
+ self.hs, self.hs.get_datastores().main.db_pool
+ )
+
+ self.assertEqual(background_updater.minimum_background_batch_size, 5)
+ self.assertEqual(background_updater.default_background_batch_size, 50)
+ self.assertEqual(background_updater.sleep_enabled, False)
+ self.assertEqual(background_updater.sleep_duration_ms, 600)
+ self.assertEqual(background_updater.update_duration_ms, 1000)
diff --git a/tests/config/test_registration_config.py b/tests/config/test_registration_config.py
index 17a84d20d8..2acdb6ac61 100644
--- a/tests/config/test_registration_config.py
+++ b/tests/config/test_registration_config.py
@@ -11,14 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+import synapse.app.homeserver
from synapse.config import ConfigError
from synapse.config.homeserver import HomeServerConfig
-from tests.unittest import TestCase
+from tests.config.utils import ConfigFileTestCase
from tests.utils import default_config
-class RegistrationConfigTestCase(TestCase):
+class RegistrationConfigTestCase(ConfigFileTestCase):
def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self):
"""
session_lifetime should logically be larger than, or at least as large as,
@@ -76,3 +78,19 @@ class RegistrationConfigTestCase(TestCase):
HomeServerConfig().parse_config_dict(
{"session_lifetime": "31m", "refresh_token_lifetime": "31m", **config_dict}
)
+
+ def test_refuse_to_start_if_open_registration_and_no_verification(self):
+ self.generate_config()
+ self.add_lines_to_config(
+ [
+ " ",
+ "enable_registration: true",
+ "registrations_require_3pid: []",
+ "enable_registration_captcha: false",
+ "registration_requires_token: false",
+ ]
+ )
+
+ # Test that allowing open registration without verification raises an error
+ with self.assertRaises(ConfigError):
+ synapse.app.homeserver.setup(["-c", self.config_file])
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index 45e3395b33..00ad19e446 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -16,6 +16,7 @@ from synapse.api.constants import EventContentFields
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.events.utils import (
+ SerializeEventConfig,
copy_power_levels_contents,
prune_event,
serialize_event,
@@ -392,7 +393,9 @@ class PruneEventTestCase(unittest.TestCase):
class SerializeEventTestCase(unittest.TestCase):
def serialize(self, ev, fields):
- return serialize_event(ev, 1479807801915, only_event_fields=fields)
+ return serialize_event(
+ ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
+ )
def test_event_fields_works_with_keys(self):
self.assertEqual(
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 60e0c31f43..e90592855a 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -201,9 +201,12 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.assertEqual(len(self.edus), 1)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
# a second call should produce no new device EDUs
self.hs.get_federation_sender().send_device_messages("host2")
- self.pump()
self.assertEqual(self.edus, [])
# a second device
@@ -232,6 +235,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1")
device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2")
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
# expect two more edus
self.assertEqual(len(self.edus), 2)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
@@ -265,6 +272,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys)
)
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
# expect signing key update edu
self.assertEqual(len(self.edus), 2)
self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update")
@@ -284,6 +295,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
)
self.assertEqual(ret["failures"], {})
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
# expect two edus, in one or two transactions. We don't know what order the
# devices will be updated.
self.assertEqual(len(self.edus), 2)
@@ -307,6 +322,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.login("user", "pass", device_id="D2")
self.login("user", "pass", device_id="D3")
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
# expect three edus
self.assertEqual(len(self.edus), 3)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
@@ -318,6 +337,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
# expect three edus, in an unknown order
self.assertEqual(len(self.edus), 3)
for edu in self.edus:
@@ -350,12 +373,19 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
self.assertGreaterEqual(mock_send_txn.call_count, 4)
# recover the server
mock_send_txn.side_effect = self.record_transaction
self.hs.get_federation_sender().send_device_messages("host2")
- self.pump()
+
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
# for each device, there should be a single update
self.assertEqual(len(self.edus), 3)
@@ -390,6 +420,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
self.assertGreaterEqual(mock_send_txn.call_count, 4)
# run the prune job
@@ -401,7 +435,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# recover the server
mock_send_txn.side_effect = self.record_transaction
self.hs.get_federation_sender().send_device_messages("host2")
- self.pump()
+
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
# there should be a single update for this user.
self.assertEqual(len(self.edus), 1)
@@ -435,6 +472,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.login("user", "pass", device_id="D2")
self.login("user", "pass", device_id="D3")
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
# delete them again
self.get_success(
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
@@ -451,7 +492,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# recover the server
mock_send_txn.side_effect = self.record_transaction
self.hs.get_federation_sender().send_device_messages("host2")
- self.pump()
+
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
# ... and we should get a single update for this user.
self.assertEqual(len(self.edus), 1)
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index abf2a0fe0d..c1579dac61 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -15,11 +15,15 @@
from collections import Counter
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
import synapse.storage
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.room_versions import RoomVersions
from synapse.rest.client import knock, login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
@@ -32,7 +36,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
knock.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_handler = hs.get_admin_handler()
self.user1 = self.register_user("user1", "password")
@@ -41,7 +45,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.user2 = self.register_user("user2", "password")
self.token2 = self.login("user2", "password")
- def test_single_public_joined_room(self):
+ def test_single_public_joined_room(self) -> None:
"""Test that we write *all* events for a public room"""
room_id = self.helper.create_room_as(
self.user1, tok=self.token1, is_public=True
@@ -74,7 +78,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
- def test_single_private_joined_room(self):
+ def test_single_private_joined_room(self) -> None:
"""Tests that we correctly write state when we can't see all events in
a room.
"""
@@ -112,7 +116,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
- def test_single_left_room(self):
+ def test_single_left_room(self) -> None:
"""Tests that we don't see events in the room after we leave."""
room_id = self.helper.create_room_as(self.user1, tok=self.token1)
self.helper.send(room_id, body="Hello!", tok=self.token1)
@@ -144,7 +148,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user2)], 2)
- def test_single_left_rejoined_private_room(self):
+ def test_single_left_rejoined_private_room(self) -> None:
"""Tests that see the correct events in private rooms when we
repeatedly join and leave.
"""
@@ -185,7 +189,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user2)], 3)
- def test_invite(self):
+ def test_invite(self) -> None:
"""Tests that pending invites get handled correctly."""
room_id = self.helper.create_room_as(self.user1, tok=self.token1)
self.helper.send(room_id, body="Hello!", tok=self.token1)
@@ -204,7 +208,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(args[1].content["membership"], "invite")
self.assertTrue(args[2]) # Assert there is at least one bit of state
- def test_knock(self):
+ def test_knock(self) -> None:
"""Tests that knock get handled correctly."""
# create a knockable v7 room
room_id = self.helper.create_room_as(
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 072e6bbcdd..cead9f90df 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -59,11 +59,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.event_source = hs.get_event_sources()
def test_notify_interested_services(self):
- interested_service = self._mkservice(is_interested=True)
+ interested_service = self._mkservice(is_interested_in_event=True)
services = [
- self._mkservice(is_interested=False),
+ self._mkservice(is_interested_in_event=False),
interested_service,
- self._mkservice(is_interested=False),
+ self._mkservice(is_interested_in_event=False),
]
self.mock_as_api.query_user.return_value = make_awaitable(True)
@@ -85,7 +85,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_query_user_exists_unknown_user(self):
user_id = "@someone:anywhere"
- services = [self._mkservice(is_interested=True)]
+ services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = make_awaitable(None)
@@ -102,7 +102,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_query_user_exists_known_user(self):
user_id = "@someone:anywhere"
- services = [self._mkservice(is_interested=True)]
+ services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
@@ -127,11 +127,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
room_id = "!alpha:bet"
servers = ["aperture"]
- interested_service = self._mkservice_alias(is_interested_in_alias=True)
+ interested_service = self._mkservice_alias(is_room_alias_in_namespace=True)
services = [
- self._mkservice_alias(is_interested_in_alias=False),
+ self._mkservice_alias(is_room_alias_in_namespace=False),
interested_service,
- self._mkservice_alias(is_interested_in_alias=False),
+ self._mkservice_alias(is_room_alias_in_namespace=False),
]
self.mock_as_api.query_alias.return_value = make_awaitable(True)
@@ -275,7 +275,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
to be pushed out to interested appservices, and that the stream ID is
updated accordingly.
"""
- interested_service = self._mkservice(is_interested=True)
+ interested_service = self._mkservice(is_interested_in_event=True)
services = [interested_service]
self.mock_store.get_app_services.return_value = services
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
@@ -304,7 +304,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
Test sending out of order ephemeral events to the appservice handler
are ignored.
"""
- interested_service = self._mkservice(is_interested=True)
+ interested_service = self._mkservice(is_interested_in_event=True)
services = [interested_service]
self.mock_store.get_app_services.return_value = services
@@ -325,17 +325,45 @@ class AppServiceHandlerTestCase(unittest.TestCase):
interested_service, ephemeral=[]
)
- def _mkservice(self, is_interested, protocols=None):
+ def _mkservice(
+ self, is_interested_in_event: bool, protocols: Optional[Iterable] = None
+ ) -> Mock:
+ """
+ Create a new mock representing an ApplicationService.
+
+ Args:
+ is_interested_in_event: Whether this application service will be considered
+ interested in all events.
+ protocols: The third-party protocols that this application service claims to
+ support.
+
+ Returns:
+ A mock representing the ApplicationService.
+ """
service = Mock()
- service.is_interested.return_value = make_awaitable(is_interested)
+ service.is_interested_in_event.return_value = make_awaitable(
+ is_interested_in_event
+ )
service.token = "mock_service_token"
service.url = "mock_service_url"
service.protocols = protocols
return service
- def _mkservice_alias(self, is_interested_in_alias):
+ def _mkservice_alias(self, is_room_alias_in_namespace: bool) -> Mock:
+ """
+ Create a new mock representing an ApplicationService that is or is not interested
+ any given room aliase.
+
+ Args:
+ is_room_alias_in_namespace: If true, the application service will be interested
+ in all room aliases that are queried against it. If false, the application
+ service will not be interested in any room aliases.
+
+ Returns:
+ A mock representing the ApplicationService.
+ """
service = Mock()
- service.is_interested_in_alias.return_value = is_interested_in_alias
+ service.is_room_alias_in_namespace.return_value = is_room_alias_in_namespace
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 0c6e55e725..67a7829769 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -15,8 +15,12 @@ from unittest.mock import Mock
import pymacaroons
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.errors import AuthError, ResourceLimitError
from synapse.rest import admin
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -27,7 +31,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
admin.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.auth_handler = hs.get_auth_handler()
self.macaroon_generator = hs.get_macaroon_generator()
@@ -42,23 +46,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.user1 = self.register_user("a_user", "pass")
- def test_macaroon_caveats(self):
+ def test_macaroon_caveats(self) -> None:
token = self.macaroon_generator.generate_guest_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
- def verify_gen(caveat):
+ def verify_gen(caveat: str) -> bool:
return caveat == "gen = 1"
- def verify_user(caveat):
+ def verify_user(caveat: str) -> bool:
return caveat == "user_id = a_user"
- def verify_type(caveat):
+ def verify_type(caveat: str) -> bool:
return caveat == "type = access"
- def verify_nonce(caveat):
+ def verify_nonce(caveat: str) -> bool:
return caveat.startswith("nonce =")
- def verify_guest(caveat):
+ def verify_guest(caveat: str) -> bool:
return caveat == "guest = true"
v = pymacaroons.Verifier()
@@ -69,7 +73,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
v.satisfy_general(verify_guest)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
- def test_short_term_login_token_gives_user_id(self):
+ def test_short_term_login_token_gives_user_id(self) -> None:
token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", duration_in_ms=5000
)
@@ -84,7 +88,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
AuthError,
)
- def test_short_term_login_token_gives_auth_provider(self):
+ def test_short_term_login_token_gives_auth_provider(self) -> None:
token = self.macaroon_generator.generate_short_term_login_token(
self.user1, auth_provider_id="my_idp"
)
@@ -92,7 +96,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.user1, res.user_id)
self.assertEqual("my_idp", res.auth_provider_id)
- def test_short_term_login_token_cannot_replace_user_id(self):
+ def test_short_term_login_token_cannot_replace_user_id(self) -> None:
token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", duration_in_ms=5000
)
@@ -112,7 +116,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
AuthError,
)
- def test_mau_limits_disabled(self):
+ def test_mau_limits_disabled(self) -> None:
self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception
self.get_success(
@@ -127,7 +131,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
- def test_mau_limits_exceeded_large(self):
+ def test_mau_limits_exceeded_large(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users)
@@ -150,7 +154,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
- def test_mau_limits_parity(self):
+ def test_mau_limits_parity(self) -> None:
# Ensure we're not at the unix epoch.
self.reactor.advance(1)
self.auth_blocking._limit_usage_by_mau = True
@@ -189,7 +193,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
- def test_mau_limits_not_exceeded(self):
+ def test_mau_limits_not_exceeded(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock(
@@ -211,7 +215,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
- def _get_macaroon(self):
+ def _get_macaroon(self) -> pymacaroons.Macaroon:
token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", duration_in_ms=5000
)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index a267228846..a54aa29cf1 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -11,9 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.handlers.cas import CasResponse
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
@@ -24,7 +29,7 @@ SERVER_URL = "https://issuer/"
class CasHandlerTestCase(HomeserverTestCase):
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
cas_config = {
@@ -40,7 +45,7 @@ class CasHandlerTestCase(HomeserverTestCase):
return config
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver()
self.handler = hs.get_cas_handler()
@@ -51,7 +56,7 @@ class CasHandlerTestCase(HomeserverTestCase):
return hs
- def test_map_cas_user_to_user(self):
+ def test_map_cas_user_to_user(self) -> None:
"""Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
# stub out the auth handler
@@ -75,7 +80,7 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
- def test_map_cas_user_to_existing_user(self):
+ def test_map_cas_user_to_existing_user(self) -> None:
"""Existing users can log in with CAS account."""
store = self.hs.get_datastores().main
self.get_success(
@@ -119,7 +124,7 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
- def test_map_cas_user_to_invalid_localpart(self):
+ def test_map_cas_user_to_invalid_localpart(self) -> None:
"""CAS automaps invalid characters to base-64 encoding."""
# stub out the auth handler
@@ -150,7 +155,7 @@ class CasHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_required_attributes(self):
+ def test_required_attributes(self) -> None:
"""The required attributes must be met from the CAS response."""
# stub out the auth handler
@@ -166,7 +171,7 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_handler.complete_sso_login.assert_not_called()
# The response doesn't have any department.
- cas_response = CasResponse("test_user", {"userGroup": "staff"})
+ cas_response = CasResponse("test_user", {"userGroup": ["staff"]})
request.reset_mock()
self.get_success(
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index ddda36c5a9..3a10791226 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -39,7 +39,7 @@ class DeactivateAccountTestCase(HomeserverTestCase):
self.user = self.register_user("user", "pass")
self.token = self.login("user", "pass")
- def _deactivate_my_account(self):
+ def _deactivate_my_account(self) -> None:
"""
Deactivates the account `self.user` using `self.token` and asserts
that it returns a 200 success code.
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 683677fd07..01ea7d2a42 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -14,9 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.api.errors
-import synapse.handlers.device
-import synapse.storage
+from typing import Optional
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.errors import NotFoundError, SynapseError
+from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
@@ -25,28 +30,27 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler()
self.store = hs.get_datastores().main
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# These tests assume that it starts 1000 seconds in.
self.reactor.advance(1000)
- def test_device_is_created_with_invalid_name(self):
+ def test_device_is_created_with_invalid_name(self) -> None:
self.get_failure(
self.handler.check_device_registered(
user_id="@boris:foo",
device_id="foo",
- initial_device_display_name="a"
- * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1),
+ initial_device_display_name="a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1),
),
- synapse.api.errors.SynapseError,
+ SynapseError,
)
- def test_device_is_created_if_doesnt_exist(self):
+ def test_device_is_created_if_doesnt_exist(self) -> None:
res = self.get_success(
self.handler.check_device_registered(
user_id="@boris:foo",
@@ -59,7 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name")
- def test_device_is_preserved_if_exists(self):
+ def test_device_is_preserved_if_exists(self) -> None:
res1 = self.get_success(
self.handler.check_device_registered(
user_id="@boris:foo",
@@ -81,7 +85,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name")
- def test_device_id_is_made_up_if_unspecified(self):
+ def test_device_id_is_made_up_if_unspecified(self) -> None:
device_id = self.get_success(
self.handler.check_device_registered(
user_id="@theresa:foo",
@@ -93,7 +97,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
self.assertEqual(dev["display_name"], "display")
- def test_get_devices_by_user(self):
+ def test_get_devices_by_user(self) -> None:
self._record_users()
res = self.get_success(self.handler.get_devices_by_user(user1))
@@ -131,7 +135,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
device_map["abc"],
)
- def test_get_device(self):
+ def test_get_device(self) -> None:
self._record_users()
res = self.get_success(self.handler.get_device(user1, "abc"))
@@ -146,21 +150,19 @@ class DeviceTestCase(unittest.HomeserverTestCase):
res,
)
- def test_delete_device(self):
+ def test_delete_device(self) -> None:
self._record_users()
# delete the device
self.get_success(self.handler.delete_device(user1, "abc"))
# check the device was deleted
- self.get_failure(
- self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
- )
+ self.get_failure(self.handler.get_device(user1, "abc"), NotFoundError)
# we'd like to check the access token was invalidated, but that's a
# bit of a PITA.
- def test_delete_device_and_device_inbox(self):
+ def test_delete_device_and_device_inbox(self) -> None:
self._record_users()
# add an device_inbox
@@ -191,7 +193,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(res)
- def test_update_device(self):
+ def test_update_device(self) -> None:
self._record_users()
update = {"display_name": "new display"}
@@ -200,32 +202,29 @@ class DeviceTestCase(unittest.HomeserverTestCase):
res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertEqual(res["display_name"], "new display")
- def test_update_device_too_long_display_name(self):
+ def test_update_device_too_long_display_name(self) -> None:
"""Update a device with a display name that is invalid (too long)."""
self._record_users()
# Request to update a device display name with a new value that is longer than allowed.
- update = {
- "display_name": "a"
- * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
- }
+ update = {"display_name": "a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1)}
self.get_failure(
self.handler.update_device(user1, "abc", update),
- synapse.api.errors.SynapseError,
+ SynapseError,
)
# Ensure the display name was not updated.
res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertEqual(res["display_name"], "display 2")
- def test_update_unknown_device(self):
+ def test_update_unknown_device(self) -> None:
update = {"display_name": "new_display"}
self.get_failure(
self.handler.update_device("user_id", "unknown_device_id", update),
- synapse.api.errors.NotFoundError,
+ NotFoundError,
)
- def _record_users(self):
+ def _record_users(self) -> None:
# check this works for both devices which have a recorded client_ip,
# and those which don't.
self._record_user(user1, "xyz", "display 0")
@@ -238,8 +237,13 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.reactor.advance(10000)
def _record_user(
- self, user_id, device_id, display_name, access_token=None, ip=None
- ):
+ self,
+ user_id: str,
+ device_id: str,
+ display_name: str,
+ access_token: Optional[str] = None,
+ ip: Optional[str] = None,
+ ) -> None:
device_id = self.get_success(
self.handler.check_device_registered(
user_id=user_id,
@@ -248,7 +252,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
)
- if ip is not None:
+ if access_token is not None and ip is not None:
self.get_success(
self.store.insert_client_ip(
user_id, access_token, ip, "user_agent", device_id
@@ -258,7 +262,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler()
self.registration = hs.get_registration_handler()
@@ -266,7 +270,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
return hs
- def test_dehydrate_and_rehydrate_device(self):
+ def test_dehydrate_and_rehydrate_device(self) -> None:
user_id = "@boris:dehydration"
self.get_success(self.store.register_user(user_id, "foobar"))
@@ -303,7 +307,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
access_token=access_token,
device_id="not the right device ID",
),
- synapse.api.errors.NotFoundError,
+ NotFoundError,
)
# dehydrating the right devices should succeed and change our device ID
@@ -331,7 +335,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
# make sure that the device ID that we were initially assigned no longer exists
self.get_failure(
self.handler.get_device(user_id, device_id),
- synapse.api.errors.NotFoundError,
+ NotFoundError,
)
# make sure that there's no device available for dehydrating now
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 6e403a87c5..11ad44223d 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -12,14 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.api.errors
import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client import directory, login, room
-from synapse.types import RoomAlias, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, RoomAlias, create_requester
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -28,13 +32,15 @@ from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase):
"""Tests the directory service."""
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock()
self.mock_registry = Mock()
- self.query_handlers = {}
+ self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
- def register_query_handler(query_type, handler):
+ def register_query_handler(
+ query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
+ ) -> None:
self.query_handlers[query_type] = handler
self.mock_registry.register_query_handler = register_query_handler
@@ -54,7 +60,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
return hs
- def test_get_local_association(self):
+ def test_get_local_association(self) -> None:
self.get_success(
self.store.create_room_alias_association(
self.my_room, "!8765qwer:test", ["test"]
@@ -65,7 +71,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
- def test_get_remote_association(self):
+ def test_get_remote_association(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
)
@@ -83,7 +89,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
ignore_backoff=True,
)
- def test_incoming_fed_query(self):
+ def test_incoming_fed_query(self) -> None:
self.get_success(
self.store.create_room_alias_association(
self.your_room, "!8765asdf:test", ["test"]
@@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
directory.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_directory_handler()
# Create user
@@ -125,7 +131,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
- def test_create_alias_joined_room(self):
+ def test_create_alias_joined_room(self) -> None:
"""A user can create an alias for a room they're in."""
self.get_success(
self.handler.create_association(
@@ -135,7 +141,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
)
)
- def test_create_alias_other_room(self):
+ def test_create_alias_other_room(self) -> None:
"""A user cannot create an alias for a room they're NOT in."""
other_room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok
@@ -150,7 +156,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError,
)
- def test_create_alias_admin(self):
+ def test_create_alias_admin(self) -> None:
"""An admin can create an alias for a room they're NOT in."""
other_room_id = self.helper.create_room_as(
self.test_user, tok=self.test_user_tok
@@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
directory.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
@@ -195,7 +201,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
- def _create_alias(self, user):
+ def _create_alias(self, user) -> None:
# Create a new alias to this room.
self.get_success(
self.store.create_room_alias_association(
@@ -203,7 +209,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
)
)
- def test_delete_alias_not_allowed(self):
+ def test_delete_alias_not_allowed(self) -> None:
"""A user that doesn't meet the expected guidelines cannot delete an alias."""
self._create_alias(self.admin_user)
self.get_failure(
@@ -213,7 +219,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.AuthError,
)
- def test_delete_alias_creator(self):
+ def test_delete_alias_creator(self) -> None:
"""An alias creator can delete their own alias."""
# Create an alias from a different user.
self._create_alias(self.test_user)
@@ -232,7 +238,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError,
)
- def test_delete_alias_admin(self):
+ def test_delete_alias_admin(self) -> None:
"""A server admin can delete an alias created by another user."""
# Create an alias from a different user.
self._create_alias(self.test_user)
@@ -251,7 +257,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError,
)
- def test_delete_alias_sufficient_power(self):
+ def test_delete_alias_sufficient_power(self) -> None:
"""A user with a sufficient power level should be able to delete an alias."""
self._create_alias(self.admin_user)
@@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
directory.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
@@ -317,7 +323,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
return room_alias
- def _set_canonical_alias(self, content):
+ def _set_canonical_alias(self, content) -> None:
"""Configure the canonical alias state on the room."""
self.helper.send_state(
self.room_id,
@@ -334,7 +340,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
)
- def test_remove_alias(self):
+ def test_remove_alias(self) -> None:
"""Removing an alias that is the canonical alias should remove it there too."""
# Set this new alias as the canonical alias for this room
self._set_canonical_alias(
@@ -356,7 +362,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
self.assertNotIn("alias", data["content"])
self.assertNotIn("alt_aliases", data["content"])
- def test_remove_other_alias(self):
+ def test_remove_other_alias(self) -> None:
"""Removing an alias listed as in alt_aliases should remove it there too."""
# Create a second alias.
other_test_alias = "#test2:test"
@@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets]
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
# Add custom alias creation rules to the config.
@@ -403,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
return config
- def test_denied(self):
+ def test_denied(self) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -413,7 +419,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
)
self.assertEqual(403, channel.code, channel.result)
- def test_allowed(self):
+ def test_allowed(self) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -423,7 +429,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
)
self.assertEqual(200, channel.code, channel.result)
- def test_denied_during_creation(self):
+ def test_denied_during_creation(self) -> None:
"""A room alias that is not allowed should be rejected during creation."""
# Invalid room alias.
self.helper.create_room_as(
@@ -432,7 +438,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
extra_content={"room_alias_name": "foo"},
)
- def test_allowed_during_creation(self):
+ def test_allowed_during_creation(self) -> None:
"""A valid room alias should be allowed during creation."""
room_id = self.helper.create_room_as(
self.user_id,
@@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
data = {"room_alias_name": "unofficial_test"}
allowed_localpart = "allowed"
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
# Add custom room list publication rules to the config.
@@ -474,7 +480,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return config
- def prepare(self, reactor, clock, hs):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
+ ) -> HomeServer:
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
@@ -483,7 +491,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return hs
- def test_denied_without_publication_permission(self):
+ def test_denied_without_publication_permission(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
as a user without permission to publish rooms.
@@ -497,7 +505,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=403,
)
- def test_allowed_when_creating_private_room(self):
+ def test_allowed_when_creating_private_room(self) -> None:
"""
Try to create a room, register an alias for it, and NOT publish it,
as a user without permission to publish rooms.
@@ -511,7 +519,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=200,
)
- def test_allowed_with_publication_permission(self):
+ def test_allowed_with_publication_permission(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
as a user WITH permission to publish rooms.
@@ -525,7 +533,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=200,
)
- def test_denied_publication_with_invalid_alias(self):
+ def test_denied_publication_with_invalid_alias(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
as a user WITH permission to publish rooms.
@@ -538,7 +546,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=403,
)
- def test_can_create_as_private_room_after_rejection(self):
+ def test_can_create_as_private_room_after_rejection(self) -> None:
"""
After failing to publish a room with an alias as a user without publish permission,
retry as the same user, but without publishing the room.
@@ -549,7 +557,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
self.test_denied_without_publication_permission()
self.test_allowed_when_creating_private_room()
- def test_can_create_with_permission_after_rejection(self):
+ def test_can_create_with_permission_after_rejection(self) -> None:
"""
After failing to publish a room with an alias as a user without publish permission,
retry as someone with permission, using the same alias.
@@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets]
- def prepare(self, reactor, clock, hs):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
+ ) -> HomeServer:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -579,7 +589,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
return hs
- def test_disabling_room_list(self):
+ def test_disabling_room_list(self) -> None:
self.room_list_handler.enable_room_list_search = True
self.directory_handler.enable_room_list_search = True
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 9338ab92e9..ac21a28c43 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -20,33 +20,37 @@ from parameterized import parameterized
from signedjson import key as key, sign as sign
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=mock.Mock())
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_keys_handler()
self.store = self.hs.get_datastores().main
- def test_query_local_devices_no_devices(self):
+ def test_query_local_devices_no_devices(self) -> None:
"""If the user has no devices, we expect an empty list."""
local_user = "@boris:" + self.hs.hostname
res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
- def test_reupload_one_time_keys(self):
+ def test_reupload_one_time_keys(self) -> None:
"""we should be able to re-upload the same keys"""
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
- keys = {
+ keys: JsonDict = {
"alg1:k1": "key1",
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"alg2:k3": {"key": "key3"},
@@ -74,7 +78,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
)
- def test_change_one_time_keys(self):
+ def test_change_one_time_keys(self) -> None:
"""attempts to change one-time-keys should be rejected"""
local_user = "@boris:" + self.hs.hostname
@@ -134,7 +138,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
SynapseError,
)
- def test_claim_one_time_key(self):
+ def test_claim_one_time_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
keys = {"alg1:k1": "key1"}
@@ -161,7 +165,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
- def test_fallback_key(self):
+ def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
fallback_key = {"alg1:k1": "fallback_key1"}
@@ -294,7 +298,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
)
- def test_replace_master_key(self):
+ def test_replace_master_key(self) -> None:
"""uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname
keys1 = {
@@ -328,7 +332,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
- def test_reupload_signatures(self):
+ def test_reupload_signatures(self) -> None:
"""re-uploading a signature should not fail"""
local_user = "@boris:" + self.hs.hostname
keys1 = {
@@ -433,7 +437,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
- def test_self_signing_key_doesnt_show_up_as_device(self):
+ def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
"""signing keys should be hidden when fetching a user's devices"""
local_user = "@boris:" + self.hs.hostname
keys1 = {
@@ -462,7 +466,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
- def test_upload_signatures(self):
+ def test_upload_signatures(self) -> None:
"""should check signatures that are uploaded"""
# set up a user with cross-signing keys and a device. This user will
# try uploading signatures
@@ -686,7 +690,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
)
- def test_query_devices_remote_no_sync(self):
+ def test_query_devices_remote_no_sync(self) -> None:
"""Tests that querying keys for a remote user that we don't share a room
with returns the cross signing keys correctly.
"""
@@ -759,7 +763,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
- def test_query_devices_remote_sync(self):
+ def test_query_devices_remote_sync(self) -> None:
"""Tests that querying keys for a remote user that we share a room with,
but haven't yet fetched the keys for, returns the cross signing keys
correctly.
@@ -845,7 +849,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
(["device_1", "device_2"],),
]
)
- def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
+ def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None:
"""Test that requests for all of a remote user's devices are cached.
We do this by asserting that only one call over federation was made, and that
@@ -853,7 +857,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
"""
local_user_id = "@test:test"
remote_user_id = "@test:other"
- request_body = {"device_keys": {remote_user_id: []}}
+ request_body: JsonDict = {"device_keys": {remote_user_id: []}}
response_devices = [
{
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index e8b4e39d1a..89078fc637 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List
+from typing import List, cast
from unittest import TestCase
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions
@@ -23,7 +25,9 @@ from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.types import create_requester
+from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
@@ -42,7 +46,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main
@@ -50,7 +54,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self._event_auth_handler = hs.get_event_auth_handler()
return hs
- def test_exchange_revoked_invite(self):
+ def test_exchange_revoked_invite(self) -> None:
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
@@ -96,7 +100,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
self.assertEqual(failure.msg, "You are not invited to this room.")
- def test_rejected_message_event_state(self):
+ def test_rejected_message_event_state(self) -> None:
"""
Check that we store the state group correctly for rejected non-state events.
@@ -126,7 +130,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
"content": {},
"room_id": room_id,
"sender": "@yetanotheruser:" + OTHER_SERVER,
- "depth": join_event["depth"] + 1,
+ "depth": cast(int, join_event["depth"]) + 1,
"prev_events": [join_event.event_id],
"auth_events": [],
"origin_server_ts": self.clock.time_msec(),
@@ -149,7 +153,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2)
- def test_rejected_state_event_state(self):
+ def test_rejected_state_event_state(self) -> None:
"""
Check that we store the state group correctly for rejected state events.
@@ -180,7 +184,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
"content": {},
"room_id": room_id,
"sender": "@yetanotheruser:" + OTHER_SERVER,
- "depth": join_event["depth"] + 1,
+ "depth": cast(int, join_event["depth"]) + 1,
"prev_events": [join_event.event_id],
"auth_events": [],
"origin_server_ts": self.clock.time_msec(),
@@ -203,7 +207,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2)
- def test_backfill_with_many_backward_extremities(self):
+ def test_backfill_with_many_backward_extremities(self) -> None:
"""
Check that we can backfill with many backward extremities.
The goal is to make sure that when we only use a portion
@@ -262,7 +266,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
)
self.get_success(d)
- def test_backfill_floating_outlier_membership_auth(self):
+ def test_backfill_floating_outlier_membership_auth(self) -> None:
"""
As the local homeserver, check that we can properly process a federated
event from the OTHER_SERVER with auth_events that include a floating
@@ -377,7 +381,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
for ae in auth_events
]
- self.handler.federation_client.get_event_auth = get_event_auth
+ self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment]
with LoggingContext("receive_pdu"):
# Fake the OTHER_SERVER federating the message event over to our local homeserver
@@ -397,7 +401,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
)
- def test_invite_by_user_ratelimit(self):
+ def test_invite_by_user_ratelimit(self) -> None:
"""Tests that invites from federation to a particular user are
actually rate-limited.
"""
@@ -446,7 +450,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
exc=LimitExceededError,
)
- def _build_and_send_join_event(self, other_server, other_user, room_id):
+ def _build_and_send_join_event(
+ self, other_server: str, other_user: str, room_id: str
+ ) -> EventBase:
join_event = self.get_success(
self.handler.on_make_join_request(other_server, room_id, other_user)
)
@@ -469,7 +475,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
class EventFromPduTestCase(TestCase):
- def test_valid_json(self):
+ def test_valid_json(self) -> None:
"""Valid JSON should be turned into an event."""
ev = event_from_pdu_json(
{
@@ -487,7 +493,7 @@ class EventFromPduTestCase(TestCase):
self.assertIsInstance(ev, EventBase)
- def test_invalid_numbers(self):
+ def test_invalid_numbers(self) -> None:
"""Invalid values for an integer should be rejected, all floats should be rejected."""
for value in [
-(2 ** 53),
@@ -512,7 +518,7 @@ class EventFromPduTestCase(TestCase):
RoomVersions.V6,
)
- def test_invalid_nested(self):
+ def test_invalid_nested(self) -> None:
"""List and dictionaries are recursively searched."""
with self.assertRaises(SynapseError):
event_from_pdu_json(
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e8418b6638..014815db6e 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,14 +13,18 @@
# limitations under the License.
import json
import os
+from typing import Any, Dict
from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse
import pymacaroons
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
@@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider):
}
-async def get_json(url):
+async def get_json(url: str) -> JsonDict:
# Mock get_json calls to handle jwks & oidc discovery endpoints
if url == WELL_KNOWN:
# Minimal discovery document, as defined in OpenID.Discovery
@@ -116,6 +120,8 @@ async def get_json(url):
elif url == JWKS_URI:
return {"keys": []}
+ return {}
+
def _key_file_path() -> str:
"""path to a file containing the private half of a test key"""
@@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
if not HAS_OIDC:
skip = "requires OIDC"
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
return config
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = b"Synapse Test"
@@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
- sso_handler.render_error = self.render_error
+ sso_handler.render_error = self.render_error # type: ignore[assignment]
# Reduce the number of attempts when generating MXIDs.
sso_handler._MAP_USERNAME_RETRIES = 3
@@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
return args
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_config(self):
+ def test_config(self) -> None:
"""Basic config correctly sets up the callback URL and client auth correctly."""
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
- def test_discovery(self):
+ def test_discovery(self) -> None:
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
metadata = self.get_success(self.provider.load_metadata())
@@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
- def test_no_discovery(self):
+ def test_no_discovery(self) -> None:
"""When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
- def test_load_jwks(self):
+ def test_load_jwks(self) -> None:
"""JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI)
@@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_validate_config(self):
+ def test_validate_config(self) -> None:
"""Provider metadatas are extensively validated."""
h = self.provider
@@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
force_load_metadata()
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
- def test_skip_verification(self):
+ def test_skip_verification(self) -> None:
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
get_awaitable_result(self.provider.load_metadata())
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_redirect_request(self):
+ def test_redirect_request(self) -> None:
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["cookies"])
req.cookies = []
@@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(redirect, "http://client/redirect")
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_callback_error(self):
+ def test_callback_error(self) -> None:
"""Errors from the provider returned in the callback are displayed."""
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
@@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("invalid_client", "some description")
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_callback(self):
+ def test_callback(self) -> None:
"""Code callback works and display errors if something went wrong.
A lot of scenarios are tested here:
@@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": username,
}
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
- self.provider._exchange_code = simple_async_mock(return_value=token)
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
- self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
+ self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("mapping_error")
# Handle ID token errors
- self.provider._parse_id_token = simple_async_mock(raises=Exception())
+ self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
@@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"type": "bearer",
"access_token": "access_token",
}
- self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
@@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
id_token = {
"sid": "abcdefgh",
}
- self.provider._parse_id_token = simple_async_mock(return_value=id_token)
- self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
auth_handler.complete_sso_login.reset_mock()
self.provider._fetch_userinfo.reset_mock()
self.get_success(self.handler.handle_oidc_callback(request))
@@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.render_error.assert_not_called()
# Handle userinfo fetching error
- self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
+ self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
# Handle code exchange failure
from synapse.handlers.oidc import OidcError
- self.provider._exchange_code = simple_async_mock(
+ self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
raises=OidcError("invalid_request")
)
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_callback_session(self):
+ def test_callback_session(self) -> None:
"""The callback verifies the session presence and validity"""
request = Mock(spec=["args", "getCookie", "cookies"])
@@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config(
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
)
- def test_exchange_code(self):
+ def test_exchange_code(self) -> None:
"""Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"}
token_json = json.dumps(token).encode("utf-8")
@@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_exchange_code_jwt_key(self):
+ def test_exchange_code_jwt_key(self) -> None:
"""Test that code exchange works with a JWK client secret."""
from authlib.jose import jwt
@@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_exchange_code_no_auth(self):
+ def test_exchange_code_no_auth(self) -> None:
"""Test that code exchange works with no client secret."""
token = {"type": "bearer"}
self.http_client.request = simple_async_mock(
@@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_extra_attributes(self):
+ def test_extra_attributes(self) -> None:
"""
Login while using a mapping provider that implements get_extra_attributes.
"""
@@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "foo",
"phone": "1234567",
}
- self.provider._exchange_code = simple_async_mock(return_value=token)
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_map_userinfo_to_user(self):
+ def test_map_userinfo_to_user(self) -> None:
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
- userinfo = {
+ userinfo: dict = {
"sub": "test_user",
"username": "test_user",
}
@@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
- def test_map_userinfo_to_existing_user(self):
+ def test_map_userinfo_to_existing_user(self) -> None:
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastores().main
user = UserID.from_string("@test_user:test")
@@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_map_userinfo_to_invalid_localpart(self):
+ def test_map_userinfo_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected."""
self.get_success(
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
@@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_map_userinfo_to_user_retries(self):
+ def test_map_userinfo_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_empty_localpart(self):
+ def test_empty_localpart(self) -> None:
"""Attempts to map onto an empty localpart should be rejected."""
userinfo = {
"sub": "tester",
@@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_null_localpart(self):
+ def test_null_localpart(self) -> None:
"""Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
userinfo = {
"sub": "tester",
@@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_attribute_requirements(self):
+ def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the OIDC userinfo response."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_attribute_requirements_contains(self):
+ def test_attribute_requirements_contains(self) -> None:
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_attribute_requirements_mismatch(self):
+ def test_attribute_requirements_mismatch(self) -> None:
"""
Test that auth fails if attributes exist but don't match,
or are non-string values.
@@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# userinfo with "test": "not_foobar" attribute should fail
- userinfo = {
+ userinfo: dict = {
"sub": "tester",
"username": "tester",
"test": "not_foobar",
@@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo(
handler = hs.get_oidc_handler()
provider = handler._providers["oidc"]
- provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
- provider._parse_id_token = simple_async_mock(return_value=userinfo)
- provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
+ provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
+ provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
state = "state"
session = handler._token_generator.generate_oidc_session_token(
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 49d832de81..d401fda938 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -124,7 +124,6 @@ class PasswordCustomAuthProvider:
("m.login.password", ("password",)): self.check_auth,
}
)
- pass
def check_auth(self, *args):
return mock_password_provider.check_auth(*args)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 6ddec9ecf1..b2ed9cbe37 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -331,11 +331,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
# Extract presence update user ID and state information into lists of tuples
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
- presence_states = [(ps.user_id, ps.state) for ps in presence_states]
+ presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
# Compare what we put into the storage with what we got out.
# They should be identical.
- self.assertEqual(presence_states, db_presence_states)
+ self.assertEqual(presence_states_compare, db_presence_states)
class PresenceTimeoutTestCase(unittest.TestCase):
@@ -357,6 +357,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
self.assertEqual(new_state.status_msg, status_msg)
@@ -380,6 +381,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.BUSY)
self.assertEqual(new_state.status_msg, status_msg)
@@ -399,6 +401,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)
@@ -420,6 +423,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.ONLINE)
self.assertEqual(new_state.status_msg, status_msg)
@@ -477,6 +481,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)
@@ -653,13 +658,13 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
def _set_presencestate_with_status_msg(
- self, user_id: str, state: PresenceState, status_msg: Optional[str]
+ self, user_id: str, state: str, status_msg: Optional[str]
):
"""Set a PresenceState and status_msg and check the result.
Args:
user_id: User for that the status is to be set.
- PresenceState: The new PresenceState.
+ state: The new PresenceState.
status_msg: Status message that is to be set.
"""
self.get_success(
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 972cbac6e4..1ec105c373 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
+from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.types
from synapse.api.errors import AuthError, SynapseError
from synapse.rest import admin
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
servlets = [admin.register_servlets]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock()
self.mock_registry = Mock()
- self.query_handlers = {}
+ self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
- def register_query_handler(query_type, handler):
+ def register_query_handler(
+ query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
+ ) -> None:
self.query_handlers[query_type] = handler
self.mock_registry.register_query_handler = register_query_handler
@@ -47,7 +52,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
return hs
- def prepare(self, reactor, clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.frank = UserID.from_string("@1234abcd:test")
@@ -58,7 +63,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.handler = hs.get_profile_handler()
- def test_get_my_name(self):
+ def test_get_my_name(self) -> None:
self.get_success(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)
@@ -67,7 +72,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual("Frank", displayname)
- def test_set_my_name(self):
+ def test_set_my_name(self) -> None:
self.get_success(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
@@ -110,7 +115,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
)
- def test_set_my_name_if_disabled(self):
+ def test_set_my_name_if_disabled(self) -> None:
self.hs.config.registration.enable_set_displayname = False
# Setting displayname for the first time is allowed
@@ -135,7 +140,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
SynapseError,
)
- def test_set_my_name_noauth(self):
+ def test_set_my_name_noauth(self) -> None:
self.get_failure(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
@@ -143,7 +148,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
AuthError,
)
- def test_get_other_name(self):
+ def test_get_other_name(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable(
{"displayname": "Alice"}
)
@@ -158,7 +163,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
ignore_backoff=True,
)
- def test_incoming_fed_query(self):
+ def test_incoming_fed_query(self) -> None:
self.get_success(self.store.create_profile("caroline"))
self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
@@ -174,7 +179,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual({"displayname": "Caroline"}, response)
- def test_get_my_avatar(self):
+ def test_get_my_avatar(self) -> None:
self.get_success(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
@@ -184,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual("http://my.server/me.png", avatar_url)
- def test_set_my_avatar(self):
+ def test_set_my_avatar(self) -> None:
self.get_success(
self.handler.set_avatar_url(
self.frank,
@@ -225,7 +230,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
)
- def test_set_my_avatar_if_disabled(self):
+ def test_set_my_avatar_if_disabled(self) -> None:
self.hs.config.registration.enable_set_avatar_url = False
# Setting displayname for the first time is allowed
@@ -250,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
SynapseError,
)
- def test_avatar_constraints_no_config(self):
+ def test_avatar_constraints_no_config(self) -> None:
"""Tests that the method to check an avatar against configured constraints skips
all of its check if no constraint is configured.
"""
@@ -263,7 +268,13 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertTrue(res)
@unittest.override_config({"max_avatar_size": 50})
- def test_avatar_constraints_missing(self):
+ def test_avatar_constraints_allow_empty_avatar_url(self) -> None:
+ """An empty avatar is always permitted."""
+ res = self.get_success(self.handler.check_avatar_size_and_mime_type(""))
+ self.assertTrue(res)
+
+ @unittest.override_config({"max_avatar_size": 50})
+ def test_avatar_constraints_missing(self) -> None:
"""Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
be found.
"""
@@ -273,7 +284,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertFalse(res)
@unittest.override_config({"max_avatar_size": 50})
- def test_avatar_constraints_file_size(self):
+ def test_avatar_constraints_file_size(self) -> None:
"""Tests that a file that's above the allowed file size is forbidden but one
that's below it is allowed.
"""
@@ -295,7 +306,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertFalse(res)
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
- def test_avatar_constraint_mime_type(self):
+ def test_avatar_constraint_mime_type(self) -> None:
"""Tests that a file with an unauthorised MIME type is forbidden but one with
an authorised content type is allowed.
"""
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index cff07a8973..d37292ce13 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -172,6 +172,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result_room_ids = []
result_children_ids = []
for result_room in result["rooms"]:
+ # Ensure federation results are not leaking over the client-server API.
+ self.assertNotIn("allowed_room_ids", result_room)
+
result_room_ids.append(result_room["room_id"])
result_children_ids.append(
[
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 23941abed8..8d4404eda1 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
+from typing import Any, Dict, Optional
from unittest.mock import Mock
import attr
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.errors import RedirectException
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
@@ -81,10 +85,10 @@ class TestRedirectMappingProvider(TestMappingProvider):
class SamlHandlerTestCase(HomeserverTestCase):
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
- saml_config = {
+ saml_config: Dict[str, Any] = {
"sp_config": {"metadata": {}},
# Disable grandfathering.
"grandfathered_mxid_source_attribute": None,
@@ -98,7 +102,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
return config
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver()
self.handler = hs.get_saml_handler()
@@ -114,7 +118,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
elif not has_xmlsec1:
skip = "Requires xmlsec1"
- def test_map_saml_response_to_user(self):
+ def test_map_saml_response_to_user(self) -> None:
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
# stub out the auth handler
@@ -140,7 +144,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
)
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
- def test_map_saml_response_to_existing_user(self):
+ def test_map_saml_response_to_existing_user(self) -> None:
"""Existing users can log in with SAML account."""
store = self.hs.get_datastores().main
self.get_success(
@@ -186,7 +190,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
- def test_map_saml_response_to_invalid_localpart(self):
+ def test_map_saml_response_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected."""
# stub out the auth handler
@@ -207,7 +211,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
)
auth_handler.complete_sso_login.assert_not_called()
- def test_map_saml_response_to_user_retries(self):
+ def test_map_saml_response_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
# stub out the auth handler and error renderer
@@ -271,7 +275,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_map_saml_response_redirect(self):
+ def test_map_saml_response_redirect(self) -> None:
"""Test a mapping provider that raises a RedirectException"""
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
@@ -292,7 +296,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
},
}
)
- def test_attribute_requirements(self):
+ def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the SAML response."""
# stub out the auth handler
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f91a80b9fa..ffd5c4cb93 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -18,11 +18,14 @@ from typing import Dict
from unittest.mock import ANY, Mock, call
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
-from synapse.types import UserID, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, create_requester
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -42,7 +45,9 @@ ROOM_ID = "a-room"
OTHER_ROOM_ID = "another-room"
-def _expect_edu_transaction(edu_type, content, origin="test"):
+def _expect_edu_transaction(
+ edu_type: str, content: JsonDict, origin: str = "test"
+) -> JsonDict:
return {
"origin": origin,
"origin_server_ts": 1000000,
@@ -51,12 +56,12 @@ def _expect_edu_transaction(edu_type, content, origin="test"):
}
-def _make_edu_transaction_json(edu_type, content):
+def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes:
return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
@@ -83,7 +88,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
d["/_matrix/federation"] = TransportLayerServer(self.hs)
return d
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
@@ -111,24 +116,24 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- async def check_user_in_room(room_id, user_id):
+ async def check_user_in_room(room_id: str, user_id: str) -> None:
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
return None
hs.get_auth().check_user_in_room = check_user_in_room
- async def check_host_in_room(room_id, server_name):
+ async def check_host_in_room(room_id: str, server_name: str) -> bool:
return room_id == ROOM_ID
hs.get_event_auth_handler().check_host_in_room = check_host_in_room
- def get_joined_hosts_for_room(room_id):
+ def get_joined_hosts_for_room(room_id: str):
return {member.domain for member in self.room_members}
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
- async def get_users_in_room(room_id):
+ async def get_users_in_room(room_id: str):
return {str(u) for u in self.room_members}
self.datastore.get_users_in_room = get_users_in_room
@@ -153,7 +158,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
lambda *args, **kwargs: make_awaitable(None)
)
- def test_started_typing_local(self):
+ def test_started_typing_local(self) -> None:
self.room_members = [U_APPLE, U_BANANA]
self.assertEqual(self.event_source.get_current_key(), 0)
@@ -187,7 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
@override_config({"send_federation": True})
- def test_started_typing_remote_send(self):
+ def test_started_typing_remote_send(self) -> None:
self.room_members = [U_APPLE, U_ONION]
self.get_success(
@@ -217,7 +222,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
try_trailing_slash_on_400=True,
)
- def test_started_typing_remote_recv(self):
+ def test_started_typing_remote_recv(self) -> None:
self.room_members = [U_APPLE, U_ONION]
self.assertEqual(self.event_source.get_current_key(), 0)
@@ -256,7 +261,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
- def test_started_typing_remote_recv_not_in_room(self):
+ def test_started_typing_remote_recv_not_in_room(self) -> None:
self.room_members = [U_APPLE, U_ONION]
self.assertEqual(self.event_source.get_current_key(), 0)
@@ -292,7 +297,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[1], 0)
@override_config({"send_federation": True})
- def test_stopped_typing(self):
+ def test_stopped_typing(self) -> None:
self.room_members = [U_APPLE, U_BANANA, U_ONION]
# Gut-wrenching
@@ -343,7 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
)
- def test_typing_timeout(self):
+ def test_typing_timeout(self) -> None:
self.room_members = [U_APPLE, U_BANANA]
self.assertEqual(self.event_source.get_current_key(), 0)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index c3f20f9692..10dd94b549 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -86,6 +86,16 @@ class ModuleApiTestCase(HomeserverTestCase):
displayname = self.get_success(self.store.get_profile_displayname("bob"))
self.assertEqual(displayname, "Bobberino")
+ def test_can_register_admin_user(self):
+ user_id = self.get_success(
+ self.register_user(
+ "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
+ )
+ )
+ found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+ self.assertEqual(found_user.user_id.to_string(), user_id)
+ self.assertIdentical(found_user.is_admin, True)
+
def test_get_userinfo_by_id(self):
user_id = self.register_user("alice", "1234")
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index c284beb37c..ba158f5d93 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -11,14 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import Mock
from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfigException
-from synapse.rest.client import login, receipts, room
+from synapse.rest.client import login, push_rule, receipts, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config
@@ -29,22 +34,23 @@ class HTTPPusherTests(HomeserverTestCase):
room.register_servlets,
login.register_servlets,
receipts.register_servlets,
+ push_rule.register_servlets,
]
user_id = True
hijack_auth = False
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["start_pushers"] = True
return config
- def make_homeserver(self, reactor, clock):
- self.push_attempts = []
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.push_attempts: List[Tuple[Deferred, str, dict]] = []
m = Mock()
def post_json_get_json(url, body):
- d = Deferred()
+ d: Deferred = Deferred()
self.push_attempts.append((d, url, body))
return make_deferred_yieldable(d)
@@ -54,7 +60,7 @@ class HTTPPusherTests(HomeserverTestCase):
return hs
- def test_invalid_configuration(self):
+ def test_invalid_configuration(self) -> None:
"""Invalid push configurations should be rejected."""
# Register the user who gets notified
user_id = self.register_user("user", "pass")
@@ -66,7 +72,7 @@ class HTTPPusherTests(HomeserverTestCase):
)
token_id = user_tuple.token_id
- def test_data(data):
+ def test_data(data: Optional[JsonDict]) -> None:
self.get_failure(
self.hs.get_pusherpool().add_pusher(
user_id=user_id,
@@ -93,7 +99,7 @@ class HTTPPusherTests(HomeserverTestCase):
# A url with an incorrect path isn't accepted.
test_data({"url": "http://example.com/foo"})
- def test_sends_http(self):
+ def test_sends_http(self) -> None:
"""
The HTTP pusher will send pushes for each message to a HTTP endpoint
when configured to do so.
@@ -198,7 +204,7 @@ class HTTPPusherTests(HomeserverTestCase):
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
- def test_sends_high_priority_for_encrypted(self):
+ def test_sends_high_priority_for_encrypted(self) -> None:
"""
The HTTP pusher will send pushes at high priority if they correspond
to an encrypted message.
@@ -319,7 +325,7 @@ class HTTPPusherTests(HomeserverTestCase):
)
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
- def test_sends_high_priority_for_one_to_one_only(self):
+ def test_sends_high_priority_for_one_to_one_only(self) -> None:
"""
The HTTP pusher will send pushes at high priority if they correspond
to a message in a one-to-one room.
@@ -402,7 +408,7 @@ class HTTPPusherTests(HomeserverTestCase):
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
- def test_sends_high_priority_for_mention(self):
+ def test_sends_high_priority_for_mention(self) -> None:
"""
The HTTP pusher will send pushes at high priority if they correspond
to a message containing the user's display name.
@@ -478,7 +484,7 @@ class HTTPPusherTests(HomeserverTestCase):
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
- def test_sends_high_priority_for_atroom(self):
+ def test_sends_high_priority_for_atroom(self) -> None:
"""
The HTTP pusher will send pushes at high priority if they correspond
to a message that contains @room.
@@ -561,7 +567,7 @@ class HTTPPusherTests(HomeserverTestCase):
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
- def test_push_unread_count_group_by_room(self):
+ def test_push_unread_count_group_by_room(self) -> None:
"""
The HTTP pusher will group unread count by number of unread rooms.
"""
@@ -574,7 +580,7 @@ class HTTPPusherTests(HomeserverTestCase):
self._check_push_attempt(6, 1)
@override_config({"push": {"group_unread_count_by_room": False}})
- def test_push_unread_count_message_count(self):
+ def test_push_unread_count_message_count(self) -> None:
"""
The HTTP pusher will send the total unread message count.
"""
@@ -587,7 +593,7 @@ class HTTPPusherTests(HomeserverTestCase):
# last read receipt
self._check_push_attempt(6, 3)
- def _test_push_unread_count(self):
+ def _test_push_unread_count(self) -> None:
"""
Tests that the correct unread count appears in sent push notifications
@@ -679,7 +685,7 @@ class HTTPPusherTests(HomeserverTestCase):
self.helper.send(room_id, body="HELLO???", tok=other_access_token)
- def _advance_time_and_make_push_succeed(self, expected_push_attempts):
+ def _advance_time_and_make_push_succeed(self, expected_push_attempts: int) -> None:
self.pump()
self.push_attempts[expected_push_attempts - 1][0].callback({})
@@ -706,7 +712,9 @@ class HTTPPusherTests(HomeserverTestCase):
expected_unread_count_last_push,
)
- def _send_read_request(self, access_token, message_event_id, room_id):
+ def _send_read_request(
+ self, access_token: str, message_event_id: str, room_id: str
+ ) -> None:
# Now set the user's read receipt position to the first event
#
# This will actually trigger a new notification to be sent out so that
@@ -719,3 +727,67 @@ class HTTPPusherTests(HomeserverTestCase):
access_token=access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
+
+ def _make_user_with_pusher(self, username: str) -> Tuple[str, str]:
+ user_id = self.register_user(username, "pass")
+ access_token = self.login(username, "pass")
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastores().main.get_user_by_access_token(access_token)
+ )
+ token_id = user_tuple.token_id
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
+ )
+ )
+
+ return user_id, access_token
+
+ def test_dont_notify_rule_overrides_message(self) -> None:
+ """
+ The override push rule will suppress notification
+ """
+
+ user_id, access_token = self._make_user_with_pusher("user")
+ other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # Disable user notifications for this room -> user
+ body = {
+ "conditions": [{"kind": "event_match", "key": "room_id", "pattern": room}],
+ "actions": ["dont_notify"],
+ }
+ channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend",
+ body,
+ access_token=access_token,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Check we start with no pushes
+ self.assertEqual(len(self.push_attempts), 0)
+
+ # The other user joins
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # The other user sends a message (ignored by dont_notify push rule set above)
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.assertEqual(len(self.push_attempts), 0)
+
+ # The user sends a message back (sends a notification)
+ self.helper.send(room, body="Hello", tok=access_token)
+ self.assertEqual(len(self.push_attempts), 1)
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 3849beb9d6..5dba187076 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
+from typing import Dict, Optional, Union
import frozendict
@@ -20,12 +20,13 @@ from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent
from synapse.push import push_rule_evaluator
from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
+from synapse.types import JsonDict
from tests import unittest
class PushRuleEvaluatorTestCase(unittest.TestCase):
- def _get_evaluator(self, content):
+ def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent:
event = FrozenEvent(
{
"event_id": "$event_id",
@@ -39,12 +40,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
)
room_member_count = 0
sender_power_level = 0
- power_levels = {}
+ power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
return PushRuleEvaluatorForEvent(
event, room_member_count, sender_power_level, power_levels
)
- def test_display_name(self):
+ def test_display_name(self) -> None:
"""Check for a matching display name in the body of the event."""
evaluator = self._get_evaluator({"body": "foo bar baz"})
@@ -71,20 +72,20 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
def _assert_matches(
- self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
+ self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
) -> None:
evaluator = self._get_evaluator(content)
self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
def _assert_not_matches(
- self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
+ self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
) -> None:
evaluator = self._get_evaluator(content)
self.assertFalse(
evaluator.matches(condition, "@user:test", "display_name"), msg
)
- def test_event_match_body(self):
+ def test_event_match_body(self) -> None:
"""Check that event_match conditions on content.body work as expected"""
# if the key is `content.body`, the pattern matches substrings.
@@ -165,7 +166,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
r"? after \ should match any character",
)
- def test_event_match_non_body(self):
+ def test_event_match_non_body(self) -> None:
"""Check that event_match conditions on other keys work as expected"""
# if the key is anything other than 'content.body', the pattern must match the
@@ -241,7 +242,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"pattern should not match before a newline",
)
- def test_no_body(self):
+ def test_no_body(self) -> None:
"""Not having a body shouldn't break the evaluator."""
evaluator = self._get_evaluator({})
@@ -250,7 +251,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
}
self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
- def test_invalid_body(self):
+ def test_invalid_body(self) -> None:
"""A non-string body should not break the evaluator."""
condition = {
"kind": "contains_display_name",
@@ -260,7 +261,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
evaluator = self._get_evaluator({"body": body})
self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
- def test_tweaks_for_actions(self):
+ def test_tweaks_for_actions(self) -> None:
"""
This tests the behaviour of tweaks_for_actions.
"""
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index a7a05a564f..9c5df266bd 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -251,7 +251,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.connect_any_redis_attempts,
)
- self.hs.get_tcp_replication().start_replication(self.hs)
+ self.hs.get_replication_command_handler().start_replication(self.hs)
# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
@@ -375,7 +375,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
)
if worker_hs.config.redis.redis_enabled:
- worker_hs.get_tcp_replication().start_replication(worker_hs)
+ worker_hs.get_replication_command_handler().start_replication(worker_hs)
return worker_hs
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index f9d5da723c..641a94133b 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -420,7 +420,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
# Manually send an old RDATA command, which should get dropped. This
# re-uses the row from above, but with an earlier stream token.
- self.hs.get_tcp_replication().send_command(
+ self.hs.get_replication_command_handler().send_command(
RdataCommand("events", "master", 1, row)
)
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 3ff5afc6e5..9a229dd23f 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -118,7 +118,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
# Reset the typing handler
self.hs.get_replication_streams()["typing"].last_token = 0
- self.hs.get_tcp_replication()._streams["typing"].last_token = 0
+ self.hs.get_replication_command_handler()._streams["typing"].last_token = 0
typing._latest_room_serial = 0
typing._typing_stream_change_cache = StreamChangeCache(
"TypingStreamChangeCache", typing._latest_room_serial
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 1b6a4bf4b0..26b8bd512a 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -48,7 +48,7 @@ class FederationAckTestCase(HomeserverTestCase):
transport, rather than assuming that the implementation has a
ReplicationCommandHandler.
"""
- rch = self.hs.get_tcp_replication()
+ rch = self.hs.get_replication_command_handler()
# wire up the ReplicationCommandHandler to a mock connection, which needs
# to implement IReplicationConnection. (Note that Mock doesn't understand
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index fb36aa9940..6cf56b1e35 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -39,6 +39,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
+ self.updater = BackgroundUpdater(hs, self.store.db_pool)
@parameterized.expand(
[
@@ -135,10 +136,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"""Test the status API works with a background update."""
# Create a new background update
-
self._register_bg_update()
self.store.db_pool.updates.start_doing_background_updates()
+
self.reactor.pump([1.0, 1.0, 1.0])
channel = self.make_request(
@@ -155,10 +156,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.001,
+ "average_items_per_ms": 0.1,
"total_duration_ms": 1000.0,
"total_item_count": (
- BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ self.updater.default_background_batch_size
),
}
},
@@ -210,10 +211,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.001,
+ "average_items_per_ms": 0.1,
"total_duration_ms": 1000.0,
"total_item_count": (
- BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ self.updater.default_background_batch_size
),
}
},
@@ -239,10 +240,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.001,
+ "average_items_per_ms": 0.1,
"total_duration_ms": 1000.0,
"total_item_count": (
- BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ self.updater.default_background_batch_size
),
}
},
@@ -278,11 +279,9 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.001,
+ "average_items_per_ms": 0.05263157894736842,
"total_duration_ms": 2000.0,
- "total_item_count": (
- 2 * BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
- ),
+ "total_item_count": (110),
}
},
"enabled": True,
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index a60ea0a563..bef911d5df 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1050,6 +1050,25 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self._is_erased("@user:test", True)
+ @override_config({"max_avatar_size": 1234})
+ def test_deactivate_user_erase_true_avatar_nonnull_but_empty(self) -> None:
+ """Check we can erase a user whose avatar is the empty string.
+
+ Reproduces #12257.
+ """
+ # Patch `self.other_user` to have an empty string as their avatar.
+ self.get_success(self.store.set_profile_avatar_url("user", ""))
+
+ # Check we can still erase them.
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"erase": True},
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self._is_erased("@user:test", True)
+
def test_deactivate_user_erase_false(self) -> None:
"""
Test deactivating a user and set `erase` to `false`
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index def836054d..27946febff 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -31,7 +31,7 @@ from synapse.rest import admin
from synapse.rest.client import account, login, register, room
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest
@@ -1222,6 +1222,62 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_failures=[users[2]],
)
+ @unittest.override_config(
+ {
+ "use_account_validity_in_account_status": True,
+ }
+ )
+ def test_no_account_validity(self) -> None:
+ """Tests that if we decide to include account validity in the response but no
+ account validity 'is_user_expired' callback is provided, we default to marking all
+ users as not expired.
+ """
+ user = self.register_user("someuser", "password")
+
+ self._test_status(
+ users=[user],
+ expected_statuses={
+ user: {
+ "exists": True,
+ "deactivated": False,
+ "org.matrix.expired": False,
+ },
+ },
+ expected_failures=[],
+ )
+
+ @unittest.override_config(
+ {
+ "use_account_validity_in_account_status": True,
+ }
+ )
+ def test_account_validity_expired(self) -> None:
+ """Test that if we decide to include account validity in the response and the user
+ is expired, we return the correct info.
+ """
+ user = self.register_user("someuser", "password")
+
+ async def is_expired(user_id: str) -> bool:
+ # We can't blindly say everyone is expired, otherwise the request to get the
+ # account status will fail.
+ return UserID.from_string(user_id).localpart == "someuser"
+
+ self.hs.get_account_validity_handler()._is_user_expired_callbacks.append(
+ is_expired
+ )
+
+ self._test_status(
+ users=[user],
+ expected_statuses={
+ user: {
+ "exists": True,
+ "deactivated": False,
+ "org.matrix.expired": True,
+ },
+ },
+ expected_failures=[],
+ )
+
def _test_status(
self,
users: Optional[List[str]],
diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_mutual_rooms.py
index 3818b7b14b..7b7d283bb6 100644
--- a/tests/rest/client/test_shared_rooms.py
+++ b/tests/rest/client/test_mutual_rooms.py
@@ -14,7 +14,7 @@
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.rest.client import login, room, shared_rooms
+from synapse.rest.client import login, mutual_rooms, room
from synapse.server import HomeServer
from synapse.util import Clock
@@ -22,16 +22,16 @@ from tests import unittest
from tests.server import FakeChannel
-class UserSharedRoomsTest(unittest.HomeserverTestCase):
+class UserMutualRoomsTest(unittest.HomeserverTestCase):
"""
- Tests the UserSharedRoomsServlet.
+ Tests the UserMutualRoomsServlet.
"""
servlets = [
login.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
- shared_rooms.register_servlets,
+ mutual_rooms.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -43,10 +43,10 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
self.handler = hs.get_user_directory_handler()
- def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel:
+ def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel:
return self.make_request(
"GET",
- "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
+ "/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms/%s"
% other_user,
access_token=token,
)
@@ -56,14 +56,14 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
A room should show up in the shared list of rooms between two users
if it is public.
"""
- self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True)
+ self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=True)
def test_shared_room_list_private(self) -> None:
"""
A room should show up in the shared list of rooms between two users
if it is private.
"""
- self._check_shared_rooms_with(
+ self._check_mutual_rooms_with(
room_one_is_public=False, room_two_is_public=False
)
@@ -72,9 +72,9 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
The shared room list between two users should contain both public and private
rooms.
"""
- self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False)
+ self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=False)
- def _check_shared_rooms_with(
+ def _check_mutual_rooms_with(
self, room_one_is_public: bool, room_two_is_public: bool
) -> None:
"""Checks that shared public or private rooms between two users appear in
@@ -94,7 +94,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
# Check shared rooms from user1's perspective.
# We should see the one room in common
- channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 1)
self.assertEqual(channel.json_body["joined"][0], room_id_one)
@@ -107,7 +107,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.join(room_id_two, user=u2, tok=u2_token)
# Check shared rooms again. We should now see both rooms.
- channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 2)
for room_id_id in channel.json_body["joined"]:
@@ -128,7 +128,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.join(room, user=u2, tok=u2_token)
# Assert user directory is not empty
- channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 1)
self.assertEqual(channel.json_body["joined"][0], room)
@@ -136,11 +136,11 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.leave(room, user=u1, tok=u1_token)
# Check user1's view of shared rooms with user2
- channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 0)
# Check user2's view of shared rooms with user1
- channel = self._get_shared_rooms(u2_token, u1)
+ channel = self._get_mutual_rooms(u2_token, u1)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 709f851a38..fe97a0b3dd 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -15,17 +15,16 @@
import itertools
import urllib.parse
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import AccountDataTypes, EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room, sync
from synapse.server import HomeServer
-from synapse.storage.relations import RelationPaginationToken
-from synapse.types import JsonDict, StreamToken
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
@@ -80,6 +79,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
content: Optional[dict] = None,
access_token: Optional[str] = None,
parent_id: Optional[str] = None,
+ expected_response_code: int = 200,
) -> FakeChannel:
"""Helper function to send a relation pointing at `self.parent_id`
@@ -116,16 +116,60 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
content,
access_token=access_token,
)
+ self.assertEqual(expected_response_code, channel.code, channel.json_body)
return channel
+ def _get_related_events(self) -> List[str]:
+ """
+ Requests /relations on the parent ID and returns a list of event IDs.
+ """
+ # Request the relations of the event.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ return [ev["event_id"] for ev in channel.json_body["chunk"]]
+
+ def _get_bundled_aggregations(self) -> JsonDict:
+ """
+ Requests /event on the parent ID and returns the m.relations field (from unsigned), if it exists.
+ """
+ # Fetch the bundled aggregations of the event.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ return channel.json_body["unsigned"].get("m.relations", {})
+
+ def _get_aggregations(self) -> List[JsonDict]:
+ """Request /aggregations on the parent ID and includes the returned chunk."""
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ return channel.json_body["chunk"]
+
+ def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
+ """
+ Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
+ """
+ for event in events:
+ if event["event_id"] == self.parent_id:
+ return event
+
+ raise AssertionError(f"Event {self.parent_id} not found in chunk")
+
class RelationsTestCase(BaseRelationsTestCase):
def test_send_relation(self) -> None:
"""Tests that sending a relation works."""
-
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
- self.assertEqual(200, channel.code, channel.json_body)
-
event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -152,13 +196,13 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_deny_invalid_event(self) -> None:
"""Test that we deny relations on non-existant events"""
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.ANNOTATION,
EventTypes.Message,
parent_id="foo",
content={"body": "foo", "msgtype": "m.text"},
+ expected_response_code=400,
)
- self.assertEqual(400, channel.code, channel.json_body)
# Unless that event is referenced from another event!
self.get_success(
@@ -172,13 +216,12 @@ class RelationsTestCase(BaseRelationsTestCase):
desc="test_deny_invalid_event",
)
)
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.THREAD,
EventTypes.Message,
parent_id="foo",
content={"body": "foo", "msgtype": "m.text"},
)
- self.assertEqual(200, channel.code, channel.json_body)
def test_deny_invalid_room(self) -> None:
"""Test that we deny relations on non-existant events"""
@@ -188,18 +231,20 @@ class RelationsTestCase(BaseRelationsTestCase):
parent_id = res["event_id"]
# Attempt to send an annotation to that event.
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A"
+ self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ parent_id=parent_id,
+ key="A",
+ expected_response_code=400,
)
- self.assertEqual(400, channel.code, channel.json_body)
def test_deny_double_react(self) -> None:
"""Test that we deny relations on membership events"""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(400, channel.code, channel.json_body)
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", expected_response_code=400
+ )
def test_deny_forked_thread(self) -> None:
"""It is invalid to start a thread off a thread."""
@@ -209,386 +254,24 @@ class RelationsTestCase(BaseRelationsTestCase):
content={"msgtype": "m.text", "body": "foo"},
parent_id=self.parent_id,
)
- self.assertEqual(200, channel.code, channel.json_body)
parent_id = channel.json_body["event_id"]
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.THREAD,
"m.room.message",
content={"msgtype": "m.text", "body": "foo"},
parent_id=parent_id,
- )
- self.assertEqual(400, channel.code, channel.json_body)
-
- def test_basic_paginate_relations(self) -> None:
- """Tests that calling pagination API correctly the latest relations."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
- first_annotation_id = channel.json_body["event_id"]
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
- self.assertEqual(200, channel.code, channel.json_body)
- second_annotation_id = channel.json_body["event_id"]
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # We expect to get back a single pagination result, which is the latest
- # full relation event we sent above.
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
- self.assert_dict(
- {
- "event_id": second_annotation_id,
- "sender": self.user_id,
- "type": "m.reaction",
- },
- channel.json_body["chunk"][0],
- )
-
- # We also expect to get the original event (the id of which is self.parent_id)
- self.assertEqual(
- channel.json_body["original_event"]["event_id"], self.parent_id
- )
-
- # Make sure next_batch has something in it that looks like it could be a
- # valid token.
- self.assertIsInstance(
- channel.json_body.get("next_batch"), str, channel.json_body
- )
-
- # Request the relations again, but with a different direction.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations"
- f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # We expect to get back a single pagination result, which is the earliest
- # full relation event we sent above.
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
- self.assert_dict(
- {
- "event_id": first_annotation_id,
- "sender": self.user_id,
- "type": "m.reaction",
- },
- channel.json_body["chunk"][0],
- )
-
- def _stream_token_to_relation_token(self, token: str) -> str:
- """Convert a StreamToken into a legacy token (RelationPaginationToken)."""
- room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key
- return self.get_success(
- RelationPaginationToken(
- topological=room_key.topological, stream=room_key.stream
- ).to_string(self.store)
- )
-
- def test_repeated_paginate_relations(self) -> None:
- """Test that if we paginate using a limit and tokens then we get the
- expected events.
- """
-
- expected_event_ids = []
- for idx in range(10):
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
- )
- self.assertEqual(200, channel.code, channel.json_body)
- expected_event_ids.append(channel.json_body["event_id"])
-
- prev_token = ""
- found_event_ids: List[str] = []
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
- # Reset and try again, but convert the tokens to the legacy format.
- prev_token = ""
- found_event_ids = []
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
- def test_pagination_from_sync_and_messages(self) -> None:
- """Pagination tokens from /sync and /messages can be used to paginate /relations."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
- self.assertEqual(200, channel.code, channel.json_body)
- annotation_id = channel.json_body["event_id"]
- # Send an event after the relation events.
- self.helper.send(self.room, body="Latest event", tok=self.user_token)
-
- # Request /sync, limiting it such that only the latest event is returned
- # (and not the relation).
- filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
- channel = self.make_request(
- "GET", f"/sync?filter={filter}", access_token=self.user_token
- )
- self.assertEqual(200, channel.code, channel.json_body)
- room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- sync_prev_batch = room_timeline["prev_batch"]
- self.assertIsNotNone(sync_prev_batch)
- # Ensure the relation event is not in the batch returned from /sync.
- self.assertNotIn(
- annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
- )
-
- # Request /messages, limiting it such that only the latest event is
- # returned (and not the relation).
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/messages?dir=b&limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- messages_end = channel.json_body["end"]
- self.assertIsNotNone(messages_end)
- # Ensure the relation event is not in the chunk returned from /messages.
- self.assertNotIn(
- annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ expected_response_code=400,
)
- # Request /relations with the pagination tokens received from both the
- # /sync and /messages responses above, in turn.
- #
- # This is a tiny bit silly since the client wouldn't know the parent ID
- # from the requests above; consider the parent ID to be known from a
- # previous /sync.
- for from_token in (sync_prev_batch, messages_end):
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # The relation should be in the returned chunk.
- self.assertIn(
- annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
- )
-
- def test_aggregation_pagination_groups(self) -> None:
- """Test that we can paginate annotation groups correctly."""
-
- # We need to create ten separate users to send each reaction.
- access_tokens = [self.user_token, self.user2_token]
- idx = 0
- while len(access_tokens) < 10:
- user_id, token = self._create_user("test" + str(idx))
- idx += 1
-
- self.helper.join(self.room, user=user_id, tok=token)
- access_tokens.append(token)
-
- idx = 0
- sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
- for key in itertools.chain.from_iterable(
- itertools.repeat(key, num) for key, num in sent_groups.items()
- ):
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key=key,
- access_token=access_tokens[idx],
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- idx += 1
- idx %= len(access_tokens)
-
- prev_token: Optional[str] = None
- found_groups: Dict[str, int] = {}
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- for groups in channel.json_body["chunk"]:
- # We only expect reactions
- self.assertEqual(groups["type"], "m.reaction", channel.json_body)
-
- # We should only see each key once
- self.assertNotIn(groups["key"], found_groups, channel.json_body)
-
- found_groups[groups["key"]] = groups["count"]
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- self.assertEqual(sent_groups, found_groups)
-
- def test_aggregation_pagination_within_group(self) -> None:
- """Test that we can paginate within an annotation group."""
-
- # We need to create ten separate users to send each reaction.
- access_tokens = [self.user_token, self.user2_token]
- idx = 0
- while len(access_tokens) < 10:
- user_id, token = self._create_user("test" + str(idx))
- idx += 1
-
- self.helper.join(self.room, user=user_id, tok=token)
- access_tokens.append(token)
-
- idx = 0
- expected_event_ids = []
- for _ in range(10):
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key="👍",
- access_token=access_tokens[idx],
- )
- self.assertEqual(200, channel.code, channel.json_body)
- expected_event_ids.append(channel.json_body["event_id"])
-
- idx += 1
-
- # Also send a different type of reaction so that we test we don't see it
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- prev_token = ""
- found_event_ids: List[str] = []
- encoded_key = urllib.parse.quote_plus("👍".encode())
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}"
- f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
- f"/m.reaction/{encoded_key}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
- # Reset and try again, but convert the tokens to the legacy format.
- prev_token = ""
- found_event_ids = []
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}"
- f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
- f"/m.reaction/{encoded_key}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
def test_aggregation(self) -> None:
"""Test that annotations get correctly aggregated."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
- self.assertEqual(200, channel.code, channel.json_body)
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
channel = self.make_request(
"GET",
@@ -618,220 +301,6 @@ class RelationsTestCase(BaseRelationsTestCase):
)
self.assertEqual(400, channel.code, channel.json_body)
- @unittest.override_config(
- {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}}
- )
- def test_bundled_aggregations(self) -> None:
- """
- Test that annotations, references, and threads get correctly bundled.
-
- Note that this doesn't test against /relations since only thread relations
- get bundled via that API. See test_aggregation_get_event_for_thread.
-
- See test_edit for a similar test for edits.
- """
- # Setup by sending a variety of relations.
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
- reply_1 = channel.json_body["event_id"]
-
- channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
- reply_2 = channel.json_body["event_id"]
-
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
- thread_2 = channel.json_body["event_id"]
-
- def assert_bundle(event_json: JsonDict) -> None:
- """Assert the expected values of the bundled aggregations."""
- relations_dict = event_json["unsigned"].get("m.relations")
-
- # Ensure the fields are as expected.
- self.assertCountEqual(
- relations_dict.keys(),
- (
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- RelationTypes.THREAD,
- ),
- )
-
- # Check the values of each field.
- self.assertEqual(
- {
- "chunk": [
- {"type": "m.reaction", "key": "a", "count": 2},
- {"type": "m.reaction", "key": "b", "count": 1},
- ]
- },
- relations_dict[RelationTypes.ANNOTATION],
- )
-
- self.assertEqual(
- {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
- relations_dict[RelationTypes.REFERENCE],
- )
-
- self.assertEqual(
- 2,
- relations_dict[RelationTypes.THREAD].get("count"),
- )
- self.assertTrue(
- relations_dict[RelationTypes.THREAD].get("current_user_participated")
- )
- # The latest thread event has some fields that don't matter.
- self.assert_dict(
- {
- "content": {
- "m.relates_to": {
- "event_id": self.parent_id,
- "rel_type": RelationTypes.THREAD,
- }
- },
- "event_id": thread_2,
- "room_id": self.room,
- "sender": self.user_id,
- "type": "m.room.test",
- "user_id": self.user_id,
- },
- relations_dict[RelationTypes.THREAD].get("latest_event"),
- )
-
- # Request the event directly.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body)
-
- # Request the room messages.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/messages?dir=b",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
-
- # Request the room context.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/context/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body["event"])
-
- # Request sync.
- channel = self.make_request("GET", "/sync", access_token=self.user_token)
- self.assertEqual(200, channel.code, channel.json_body)
- room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- self.assertTrue(room_timeline["limited"])
- assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
-
- # Request search.
- channel = self.make_request(
- "POST",
- "/search",
- # Search term matches the parent message.
- content={"search_categories": {"room_events": {"search_term": "Hi"}}},
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- chunk = [
- result["result"]
- for result in channel.json_body["search_categories"]["room_events"][
- "results"
- ]
- ]
- assert_bundle(self._find_event_in_chunk(chunk))
-
- def test_aggregation_get_event_for_annotation(self) -> None:
- """Test that annotations do not get bundled aggregations included
- when directly requested.
- """
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
- annotation_id = channel.json_body["event_id"]
-
- # Annotate the annotation.
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{annotation_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
-
- def test_aggregation_get_event_for_thread(self) -> None:
- """Test that threads get bundled aggregations included when directly requested."""
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
- thread_id = channel.json_body["event_id"]
-
- # Annotate the annotation.
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{thread_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(
- channel.json_body["unsigned"].get("m.relations"),
- {
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
- },
- )
-
- # It should also be included when the entire thread is requested.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(len(channel.json_body["chunk"]), 1)
-
- thread_message = channel.json_body["chunk"][0]
- self.assertEqual(
- thread_message["unsigned"].get("m.relations"),
- {
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
- },
- )
-
- @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_ignore_invalid_room(self) -> None:
"""Test that we ignore invalid relations over federation."""
# Create another room and send a message in it.
@@ -953,8 +422,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
)
- self.assertEqual(200, channel.code, channel.json_body)
-
edit_event_id = channel.json_body["event_id"]
def assert_bundle(event_json: JsonDict) -> None:
@@ -1030,7 +497,7 @@ class RelationsTestCase(BaseRelationsTestCase):
shouldn't be allowed, are correctly handled.
"""
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={
@@ -1039,7 +506,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.new_content": {"msgtype": "m.text", "body": "First edit"},
},
)
- self.assertEqual(200, channel.code, channel.json_body)
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
channel = self._send_relation(
@@ -1047,11 +513,9 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
)
- self.assertEqual(200, channel.code, channel.json_body)
-
edit_event_id = channel.json_body["event_id"]
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message.WRONG_TYPE",
content={
@@ -1060,7 +524,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
},
)
- self.assertEqual(200, channel.code, channel.json_body)
channel = self.make_request(
"GET",
@@ -1091,7 +554,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "A reply!"},
)
- self.assertEqual(200, channel.code, channel.json_body)
reply = channel.json_body["event_id"]
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
@@ -1101,8 +563,6 @@ class RelationsTestCase(BaseRelationsTestCase):
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
parent_id=reply,
)
- self.assertEqual(200, channel.code, channel.json_body)
-
edit_event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -1138,7 +598,6 @@ class RelationsTestCase(BaseRelationsTestCase):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
- @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_edit_thread(self) -> None:
"""Test that editing a thread works."""
@@ -1148,17 +607,15 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "A threaded reply!"},
)
- self.assertEqual(200, channel.code, channel.json_body)
threaded_event_id = channel.json_body["event_id"]
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
parent_id=threaded_event_id,
)
- self.assertEqual(200, channel.code, channel.json_body)
# Fetch the thread root, to get the bundled aggregation for the thread.
channel = self.make_request(
@@ -1190,11 +647,10 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.new_content": new_body,
},
)
- self.assertEqual(200, channel.code, channel.json_body)
edit_event_id = channel.json_body["event_id"]
# Edit the edit event.
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={
@@ -1204,7 +660,6 @@ class RelationsTestCase(BaseRelationsTestCase):
},
parent_id=edit_event_id,
)
- self.assertEqual(200, channel.code, channel.json_body)
# Request the original event.
channel = self.make_request(
@@ -1231,7 +686,6 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_unknown_relations(self) -> None:
"""Unknown relations should be accepted."""
channel = self._send_relation("m.relation.test", "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -1272,28 +726,15 @@ class RelationsTestCase(BaseRelationsTestCase):
self.assertEqual(200, channel.code, channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
- def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
- """
- Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
- """
- for event in events:
- if event["event_id"] == self.parent_id:
- return event
-
- raise AssertionError(f"Event {self.parent_id} not found in chunk")
-
def test_background_update(self) -> None:
"""Test the event_arbitrary_relations background update."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
- self.assertEqual(200, channel.code, channel.json_body)
annotation_event_id_good = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
- self.assertEqual(200, channel.code, channel.json_body)
annotation_event_id_bad = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
thread_event_id = channel.json_body["event_id"]
# Clean-up the table as if the inserts did not happen during event creation.
@@ -1345,8 +786,638 @@ class RelationsTestCase(BaseRelationsTestCase):
)
+class RelationPaginationTestCase(BaseRelationsTestCase):
+ def test_basic_paginate_relations(self) -> None:
+ """Tests that calling pagination API correctly the latest relations."""
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ first_annotation_id = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+ second_annotation_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # We expect to get back a single pagination result, which is the latest
+ # full relation event we sent above.
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {
+ "event_id": second_annotation_id,
+ "sender": self.user_id,
+ "type": "m.reaction",
+ },
+ channel.json_body["chunk"][0],
+ )
+
+ # We also expect to get the original event (the id of which is self.parent_id)
+ self.assertEqual(
+ channel.json_body["original_event"]["event_id"], self.parent_id
+ )
+
+ # Make sure next_batch has something in it that looks like it could be a
+ # valid token.
+ self.assertIsInstance(
+ channel.json_body.get("next_batch"), str, channel.json_body
+ )
+
+ # Request the relations again, but with a different direction.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations"
+ f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # We expect to get back a single pagination result, which is the earliest
+ # full relation event we sent above.
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {
+ "event_id": first_annotation_id,
+ "sender": self.user_id,
+ "type": "m.reaction",
+ },
+ channel.json_body["chunk"][0],
+ )
+
+ def test_repeated_paginate_relations(self) -> None:
+ """Test that if we paginate using a limit and tokens then we get the
+ expected events.
+ """
+
+ expected_event_ids = []
+ for idx in range(10):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
+ )
+ expected_event_ids.append(channel.json_body["event_id"])
+
+ prev_token = ""
+ found_event_ids: List[str] = []
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEqual(found_event_ids, expected_event_ids)
+
+ def test_pagination_from_sync_and_messages(self) -> None:
+ """Pagination tokens from /sync and /messages can be used to paginate /relations."""
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
+ annotation_id = channel.json_body["event_id"]
+ # Send an event after the relation events.
+ self.helper.send(self.room, body="Latest event", tok=self.user_token)
+
+ # Request /sync, limiting it such that only the latest event is returned
+ # (and not the relation).
+ filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ sync_prev_batch = room_timeline["prev_batch"]
+ self.assertIsNotNone(sync_prev_batch)
+ # Ensure the relation event is not in the batch returned from /sync.
+ self.assertNotIn(
+ annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
+ )
+
+ # Request /messages, limiting it such that only the latest event is
+ # returned (and not the relation).
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/messages?dir=b&limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ messages_end = channel.json_body["end"]
+ self.assertIsNotNone(messages_end)
+ # Ensure the relation event is not in the chunk returned from /messages.
+ self.assertNotIn(
+ annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ )
+
+ # Request /relations with the pagination tokens received from both the
+ # /sync and /messages responses above, in turn.
+ #
+ # This is a tiny bit silly since the client wouldn't know the parent ID
+ # from the requests above; consider the parent ID to be known from a
+ # previous /sync.
+ for from_token in (sync_prev_batch, messages_end):
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # The relation should be in the returned chunk.
+ self.assertIn(
+ annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ )
+
+ def test_aggregation_pagination_groups(self) -> None:
+ """Test that we can paginate annotation groups correctly."""
+
+ # We need to create ten separate users to send each reaction.
+ access_tokens = [self.user_token, self.user2_token]
+ idx = 0
+ while len(access_tokens) < 10:
+ user_id, token = self._create_user("test" + str(idx))
+ idx += 1
+
+ self.helper.join(self.room, user=user_id, tok=token)
+ access_tokens.append(token)
+
+ idx = 0
+ sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
+ for key in itertools.chain.from_iterable(
+ itertools.repeat(key, num) for key, num in sent_groups.items()
+ ):
+ self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key=key,
+ access_token=access_tokens[idx],
+ )
+
+ idx += 1
+ idx %= len(access_tokens)
+
+ prev_token: Optional[str] = None
+ found_groups: Dict[str, int] = {}
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ for groups in channel.json_body["chunk"]:
+ # We only expect reactions
+ self.assertEqual(groups["type"], "m.reaction", channel.json_body)
+
+ # We should only see each key once
+ self.assertNotIn(groups["key"], found_groups, channel.json_body)
+
+ found_groups[groups["key"]] = groups["count"]
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ self.assertEqual(sent_groups, found_groups)
+
+ def test_aggregation_pagination_within_group(self) -> None:
+ """Test that we can paginate within an annotation group."""
+
+ # We need to create ten separate users to send each reaction.
+ access_tokens = [self.user_token, self.user2_token]
+ idx = 0
+ while len(access_tokens) < 10:
+ user_id, token = self._create_user("test" + str(idx))
+ idx += 1
+
+ self.helper.join(self.room, user=user_id, tok=token)
+ access_tokens.append(token)
+
+ idx = 0
+ expected_event_ids = []
+ for _ in range(10):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key="👍",
+ access_token=access_tokens[idx],
+ )
+ expected_event_ids.append(channel.json_body["event_id"])
+
+ idx += 1
+
+ # Also send a different type of reaction so that we test we don't see it
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+
+ prev_token = ""
+ found_event_ids: List[str] = []
+ encoded_key = urllib.parse.quote_plus("👍".encode())
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}"
+ f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+ f"/m.reaction/{encoded_key}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEqual(found_event_ids, expected_event_ids)
+
+
+class BundledAggregationsTestCase(BaseRelationsTestCase):
+ """
+ See RelationsTestCase.test_edit for a similar test for edits.
+
+ Note that this doesn't test against /relations since only thread relations
+ get bundled via that API. See test_aggregation_get_event_for_thread.
+ """
+
+ def _test_bundled_aggregations(
+ self,
+ relation_type: str,
+ assertion_callable: Callable[[JsonDict], None],
+ expected_db_txn_for_event: int,
+ ) -> None:
+ """
+ Makes requests to various endpoints which should include bundled aggregations
+ and then calls an assertion function on the bundled aggregations.
+
+ Args:
+ relation_type: The field to search for in the `m.relations` field in unsigned.
+ assertion_callable: Called with the contents of unsigned["m.relations"][relation_type]
+ for relation-specific assertions.
+ expected_db_txn_for_event: The number of database transactions which
+ are expected for a call to /event/.
+ """
+
+ def assert_bundle(event_json: JsonDict) -> None:
+ """Assert the expected values of the bundled aggregations."""
+ relations_dict = event_json["unsigned"].get("m.relations")
+
+ # Ensure the fields are as expected.
+ self.assertCountEqual(relations_dict.keys(), (relation_type,))
+ assertion_callable(relations_dict[relation_type])
+
+ # Request the event directly.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(channel.json_body)
+ assert channel.resource_usage is not None
+ self.assertEqual(channel.resource_usage.db_txn_count, expected_db_txn_for_event)
+
+ # Request the room messages.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/messages?dir=b",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
+
+ # Request the room context.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/context/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(channel.json_body["event"])
+
+ # Request sync.
+ filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
+
+ # Request search.
+ channel = self.make_request(
+ "POST",
+ "/search",
+ # Search term matches the parent message.
+ content={"search_categories": {"room_events": {"search_term": "Hi"}}},
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ chunk = [
+ result["result"]
+ for result in channel.json_body["search_categories"]["room_events"][
+ "results"
+ ]
+ ]
+ assert_bundle(self._find_event_in_chunk(chunk))
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_annotation(self) -> None:
+ """
+ Test that annotations get correctly bundled.
+ """
+ # Setup by sending a variety of relations.
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+ )
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(
+ {
+ "chunk": [
+ {"type": "m.reaction", "key": "a", "count": 2},
+ {"type": "m.reaction", "key": "b", "count": 1},
+ ]
+ },
+ bundled_aggregations,
+ )
+
+ self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_reference(self) -> None:
+ """
+ Test that references get correctly bundled.
+ """
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ reply_1 = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ reply_2 = channel.json_body["event_id"]
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(
+ {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
+ bundled_aggregations,
+ )
+
+ self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_thread(self) -> None:
+ """
+ Test that threads get correctly bundled.
+ """
+ self._send_relation(RelationTypes.THREAD, "m.room.test")
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ thread_2 = channel.json_body["event_id"]
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(2, bundled_aggregations.get("count"))
+ self.assertTrue(bundled_aggregations.get("current_user_participated"))
+ # The latest thread event has some fields that don't matter.
+ self.assert_dict(
+ {
+ "content": {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "rel_type": RelationTypes.THREAD,
+ }
+ },
+ "event_id": thread_2,
+ "sender": self.user_id,
+ "type": "m.room.test",
+ },
+ bundled_aggregations.get("latest_event"),
+ )
+
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9)
+
+ def test_aggregation_get_event_for_annotation(self) -> None:
+ """Test that annotations do not get bundled aggregations included
+ when directly requested.
+ """
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ annotation_id = channel.json_body["event_id"]
+
+ # Annotate the annotation.
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{annotation_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
+
+ def test_aggregation_get_event_for_thread(self) -> None:
+ """Test that threads get bundled aggregations included when directly requested."""
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ thread_id = channel.json_body["event_id"]
+
+ # Annotate the annotation.
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{thread_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertEqual(
+ channel.json_body["unsigned"].get("m.relations"),
+ {
+ RelationTypes.ANNOTATION: {
+ "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+ },
+ },
+ )
+
+ # It should also be included when the entire thread is requested.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+
+ thread_message = channel.json_body["chunk"][0]
+ self.assertEqual(
+ thread_message["unsigned"].get("m.relations"),
+ {
+ RelationTypes.ANNOTATION: {
+ "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+ },
+ },
+ )
+
+ def test_bundled_aggregations_with_filter(self) -> None:
+ """
+ If "unsigned" is an omitted field (due to filtering), adding the bundled
+ aggregations should not break.
+
+ Note that the spec allows for a server to return additional fields beyond
+ what is specified.
+ """
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+
+ # Note that the sync filter does not include "unsigned" as a field.
+ filter = urllib.parse.quote_plus(
+ b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}'
+ )
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # Ensure the timeline is limited, find the parent event.
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ parent_event = self._find_event_in_chunk(room_timeline["events"])
+
+ # Ensure there's bundled aggregations on it.
+ self.assertIn("unsigned", parent_event)
+ self.assertIn("m.relations", parent_event["unsigned"])
+
+
+class RelationIgnoredUserTestCase(BaseRelationsTestCase):
+ """Relations sent from an ignored user should be ignored."""
+
+ def _test_ignored_user(
+ self, allowed_event_ids: List[str], ignored_event_ids: List[str]
+ ) -> None:
+ """
+ Fetch the relations and ensure they're all there, then ignore user2, and
+ repeat.
+ """
+ # Get the relations.
+ event_ids = self._get_related_events()
+ self.assertCountEqual(event_ids, allowed_event_ids + ignored_event_ids)
+
+ # Ignore user2 and re-do the requests.
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.user_id,
+ AccountDataTypes.IGNORED_USER_LIST,
+ {"ignored_users": {self.user2_id: {}}},
+ )
+ )
+
+ # Get the relations.
+ event_ids = self._get_related_events()
+ self.assertCountEqual(event_ids, allowed_event_ids)
+
+ def test_annotation(self) -> None:
+ """Annotations should ignore"""
+ # Send 2 from us, 2 from the to be ignored user.
+ allowed_event_ids = []
+ ignored_event_ids = []
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+ allowed_event_ids.append(channel.json_body["event_id"])
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="b")
+ allowed_event_ids.append(channel.json_body["event_id"])
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key="a",
+ access_token=self.user2_token,
+ )
+ ignored_event_ids.append(channel.json_body["event_id"])
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key="c",
+ access_token=self.user2_token,
+ )
+ ignored_event_ids.append(channel.json_body["event_id"])
+
+ self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+ def test_reference(self) -> None:
+ """Annotations should ignore"""
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ allowed_event_ids = [channel.json_body["event_id"]]
+
+ channel = self._send_relation(
+ RelationTypes.REFERENCE, "m.room.test", access_token=self.user2_token
+ )
+ ignored_event_ids = [channel.json_body["event_id"]]
+
+ self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+ def test_thread(self) -> None:
+ """Annotations should ignore"""
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ allowed_event_ids = [channel.json_body["event_id"]]
+
+ channel = self._send_relation(
+ RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+ )
+ ignored_event_ids = [channel.json_body["event_id"]]
+
+ self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+
class RelationRedactionTestCase(BaseRelationsTestCase):
- """Test the behaviour of relations when the parent or child event is redacted."""
+ """
+ Test the behaviour of relations when the parent or child event is redacted.
+
+ The behaviour of each relation type is subtly different which causes the tests
+ to be a bit repetitive, they follow a naming scheme of:
+
+ test_redact_(relation|parent)_{relation_type}
+
+ The first bit of "relation" means that the event with the relation defined
+ on it (the child event) is to be redacted. A "parent" means that the target
+ of the relation (the parent event) is to be redacted.
+
+ The relation_type describes which type of relation is under test (i.e. it is
+ related to the value of rel_type in the event content).
+ """
def _redact(self, event_id: str) -> None:
channel = self.make_request(
@@ -1358,40 +1429,116 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self.assertEqual(200, channel.code, channel.json_body)
def test_redact_relation_annotation(self) -> None:
- """Test that annotations of an event are properly handled after the
+ """
+ Test that annotations of an event are properly handled after the
annotation is redacted.
+
+ The redacted relation should not be included in bundled aggregations or
+ the response to relations.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
to_redact_event_id = channel.json_body["event_id"]
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
- self.assertEqual(200, channel.code, channel.json_body)
+ unredacted_event_id = channel.json_body["event_id"]
+
+ # Both relations should exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id])
+ self.assertEquals(
+ relations["m.annotation"],
+ {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]},
+ )
+
+ # Both relations appear in the aggregation.
+ chunk = self._get_aggregations()
+ self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}])
# Redact one of the reactions.
self._redact(to_redact_event_id)
- # Ensure that the aggregations are correct.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
- access_token=self.user_token,
+ # The unredacted relation should still exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEquals(event_ids, [unredacted_event_id])
+ self.assertEquals(
+ relations["m.annotation"],
+ {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
)
- self.assertEqual(200, channel.code, channel.json_body)
+ # The unredacted aggregation should still exist.
+ chunk = self._get_aggregations()
+ self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}])
+
+ def test_redact_relation_thread(self) -> None:
+ """
+ Test that thread replies are properly handled after the thread reply redacted.
+
+ The redacted event should not be included in bundled aggregations or
+ the response to relations.
+ """
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ content={"body": "reply 1", "msgtype": "m.text"},
+ )
+ unredacted_event_id = channel.json_body["event_id"]
+
+ # Note that the *last* event in the thread is redacted, as that gets
+ # included in the bundled aggregation.
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ content={"body": "reply 2", "msgtype": "m.text"},
+ )
+ to_redact_event_id = channel.json_body["event_id"]
+
+ # Both relations exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id])
+ self.assertDictContainsSubset(
+ {
+ "count": 2,
+ "current_user_participated": True,
+ },
+ relations[RelationTypes.THREAD],
+ )
+ # And the latest event returned is the event that will be redacted.
+ self.assertEqual(
+ relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+ to_redact_event_id,
+ )
+
+ # Redact one of the reactions.
+ self._redact(to_redact_event_id)
+
+ # The unredacted relation should still exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEquals(event_ids, [unredacted_event_id])
+ self.assertDictContainsSubset(
+ {
+ "count": 1,
+ "current_user_participated": True,
+ },
+ relations[RelationTypes.THREAD],
+ )
+ # And the latest event is now the unredacted event.
self.assertEqual(
- channel.json_body,
- {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
+ relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+ unredacted_event_id,
)
- def test_redact_relation_edit(self) -> None:
+ def test_redact_parent_edit(self) -> None:
"""Test that edits of an event are redacted when the original event
is redacted.
"""
# Add a relation
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
parent_id=self.parent_id,
@@ -1401,54 +1548,83 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
"m.new_content": {"msgtype": "m.text", "body": "First edit"},
},
)
- self.assertEqual(200, channel.code, channel.json_body)
# Check the relation is returned
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations"
- f"/{self.parent_id}/m.replace/m.room.message",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertIn("chunk", channel.json_body)
- self.assertEqual(len(channel.json_body["chunk"]), 1)
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEqual(len(event_ids), 1)
+ self.assertIn(RelationTypes.REPLACE, relations)
# Redact the original event
self._redact(self.parent_id)
- # Try to check for remaining m.replace relations
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations"
- f"/{self.parent_id}/m.replace/m.room.message",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
+ # The relations are not returned.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEqual(len(event_ids), 0)
+ self.assertEqual(relations, {})
- # Check that no relations are returned
- self.assertIn("chunk", channel.json_body)
- self.assertEqual(channel.json_body["chunk"], [])
-
- def test_redact_parent(self) -> None:
- """Test that annotations of an event are redacted when the original event
+ def test_redact_parent_annotation(self) -> None:
+ """Test that annotations of an event are viewable when the original event
is redacted.
"""
# Add a relation
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
- self.assertEqual(200, channel.code, channel.json_body)
+ related_event_id = channel.json_body["event_id"]
+
+ # The relations should exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEqual(len(event_ids), 1)
+ self.assertIn(RelationTypes.ANNOTATION, relations)
+
+ # The aggregation should exist.
+ chunk = self._get_aggregations()
+ self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}])
# Redact the original event.
self._redact(self.parent_id)
- # Check that aggregations returns zero
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}/m.annotation/m.reaction",
- access_token=self.user_token,
+ # The relations are returned.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEquals(event_ids, [related_event_id])
+ self.assertEquals(
+ relations["m.annotation"],
+ {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
)
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertIn("chunk", channel.json_body)
- self.assertEqual(channel.json_body["chunk"], [])
+ # There's nothing to aggregate.
+ chunk = self._get_aggregations()
+ self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}])
+
+ @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+ def test_redact_parent_thread(self) -> None:
+ """
+ Test that thread replies are still available when the root event is redacted.
+ """
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ content={"body": "reply 1", "msgtype": "m.text"},
+ )
+ related_event_id = channel.json_body["event_id"]
+
+ # Redact one of the reactions.
+ self._redact(self.parent_id)
+
+ # The unredacted relation should still exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEquals(len(event_ids), 1)
+ self.assertDictContainsSubset(
+ {
+ "count": 1,
+ "current_user_participated": True,
+ },
+ relations[RelationTypes.THREAD],
+ )
+ self.assertEqual(
+ relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+ related_event_id,
+ )
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index f3bf8d0934..7b8fe6d025 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -24,6 +24,7 @@ from synapse.util import Clock
from synapse.visibility import filter_events_for_client
from tests import unittest
+from tests.unittest import override_config
one_hour_ms = 3600000
one_day_ms = one_hour_ms * 24
@@ -38,7 +39,10 @@ class RetentionTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
- config["retention"] = {
+
+ # merge this default retention config with anything that was specified in
+ # @override_config
+ retention_config = {
"enabled": True,
"default_policy": {
"min_lifetime": one_day_ms,
@@ -47,6 +51,8 @@ class RetentionTestCase(unittest.HomeserverTestCase):
"allowed_lifetime_min": one_day_ms,
"allowed_lifetime_max": one_day_ms * 3,
}
+ retention_config.update(config.get("retention", {}))
+ config["retention"] = retention_config
self.hs = self.setup_test_homeserver(config=config)
@@ -115,22 +121,20 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self._test_retention_event_purged(room_id, one_day_ms * 2)
+ @override_config({"retention": {"purge_jobs": [{"interval": "5d"}]}})
def test_visibility(self) -> None:
"""Tests that synapse.visibility.filter_events_for_client correctly filters out
- outdated events
+ outdated events, even if the purge job hasn't got to them yet.
+
+ We do this by setting a very long time between purge jobs.
"""
store = self.hs.get_datastores().main
storage = self.hs.get_storage()
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
- events = []
# Send a first event, which should be filtered out at the end of the test.
resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
-
- # Get the event from the store so that we end up with a FrozenEvent that we can
- # give to filter_events_for_client. We need to do this now because the event won't
- # be in the database anymore after it has expired.
- events.append(self.get_success(store.get_event(resp.get("event_id"))))
+ first_event_id = resp.get("event_id")
# Advance the time by 2 days. We're using the default retention policy, therefore
# after this the first event will still be valid.
@@ -138,16 +142,17 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Send another event, which shouldn't get filtered out.
resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
-
valid_event_id = resp.get("event_id")
- events.append(self.get_success(store.get_event(valid_event_id)))
-
# Advance the time by another 2 days. After this, the first event should be
# outdated but not the second one.
self.reactor.advance(one_day_ms * 2 / 1000)
- # Run filter_events_for_client with our list of FrozenEvents.
+ # Fetch the events, and run filter_events_for_client on them
+ events = self.get_success(
+ store.get_events_as_list([first_event_id, valid_event_id])
+ )
+ self.assertEqual(2, len(events), "events retrieved from database")
filtered_events = self.get_success(
filter_events_for_client(storage, self.user_id, events)
)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 37866ee330..3a9617d6da 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -2141,21 +2141,19 @@ class RelationsTestCase(unittest.HomeserverTestCase):
def test_filter_relation_senders(self) -> None:
# Messages which second user reacted to.
- filter = {"io.element.relation_senders": [self.second_user_id]}
+ filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_1)
# Messages which third user reacted to.
- filter = {"io.element.relation_senders": [self.third_user_id]}
+ filter = {"related_by_senders": [self.third_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_2)
# Messages which either user reacted to.
- filter = {
- "io.element.relation_senders": [self.second_user_id, self.third_user_id]
- }
+ filter = {"related_by_senders": [self.second_user_id, self.third_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 2, chunk)
self.assertCountEqual(
@@ -2164,20 +2162,20 @@ class RelationsTestCase(unittest.HomeserverTestCase):
def test_filter_relation_type(self) -> None:
# Messages which have annotations.
- filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
+ filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_1)
# Messages which have references.
- filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
+ filter = {"related_by_rel_types": [RelationTypes.REFERENCE]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_2)
# Messages which have either annotations or references.
filter = {
- "io.element.relation_types": [
+ "related_by_rel_types": [
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
]
@@ -2191,8 +2189,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
def test_filter_relation_senders_and_type(self) -> None:
# Messages which second user reacted to.
filter = {
- "io.element.relation_senders": [self.second_user_id],
- "io.element.relation_types": [RelationTypes.ANNOTATION],
+ "related_by_senders": [self.second_user_id],
+ "related_by_rel_types": [RelationTypes.ANNOTATION],
}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 58f1ea11b7..e7de67e3a3 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -775,3 +775,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(args[0], user_id)
self.assertFalse(args[1])
self.assertTrue(args[2])
+
+ def test_check_can_deactivate_user(self) -> None:
+ """Tests that the on_user_deactivation_status_changed module callback is called
+ correctly when processing a user's deactivation.
+ """
+ # Register a mocked callback.
+ deactivation_mock = Mock(return_value=make_awaitable(False))
+ third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules._check_can_deactivate_user_callbacks.append(
+ deactivation_mock,
+ )
+
+ # Register a user that we'll deactivate.
+ user_id = self.register_user("altan", "password")
+ tok = self.login("altan", "password")
+
+ # Deactivate that user.
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/account/deactivate",
+ {
+ "auth": {
+ "type": LoginType.PASSWORD,
+ "password": "password",
+ "identifier": {
+ "type": "m.id.user",
+ "user": user_id,
+ },
+ },
+ "erase": True,
+ },
+ access_token=tok,
+ )
+
+ # Check that the deactivation was blocked
+ self.assertEqual(channel.code, 403, channel.json_body)
+
+ # Check that the mock was called once.
+ deactivation_mock.assert_called_once()
+ args = deactivation_mock.call_args[0]
+
+ # Check that the mock was called with the right user ID
+ self.assertEqual(args[0], user_id)
+
+ # Check that the request was not made by an admin
+ self.assertEqual(args[1], False)
+
+ def test_check_can_deactivate_user_admin(self) -> None:
+ """Tests that the on_user_deactivation_status_changed module callback is called
+ correctly when processing a user's deactivation triggered by a server admin.
+ """
+ # Register a mocked callback.
+ deactivation_mock = Mock(return_value=make_awaitable(False))
+ third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules._check_can_deactivate_user_callbacks.append(
+ deactivation_mock,
+ )
+
+ # Register an admin user.
+ self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Register a user that we'll deactivate.
+ user_id = self.register_user("altan", "password")
+
+ # Deactivate the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {"deactivated": True},
+ access_token=admin_tok,
+ )
+
+ # Check that the deactivation was blocked
+ self.assertEqual(channel.code, 403, channel.json_body)
+
+ # Check that the mock was called once.
+ deactivation_mock.assert_called_once()
+ args = deactivation_mock.call_args[0]
+
+ # Check that the mock was called with the right user ID
+ self.assertEqual(args[0], user_id)
+
+ # Check that the mock was made by an admin
+ self.assertEqual(args[1], True)
+
+ def test_check_can_shutdown_room(self) -> None:
+ """Tests that the check_can_shutdown_room module callback is called
+ correctly when processing an admin's shutdown room request.
+ """
+ # Register a mocked callback.
+ shutdown_mock = Mock(return_value=make_awaitable(False))
+ third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules._check_can_shutdown_room_callbacks.append(
+ shutdown_mock,
+ )
+
+ # Register an admin user.
+ admin_user_id = self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Shutdown the room.
+ channel = self.make_request(
+ "DELETE",
+ "/_synapse/admin/v2/rooms/%s" % self.room_id,
+ {},
+ access_token=admin_tok,
+ )
+
+ # Check that the shutdown was blocked
+ self.assertEqual(channel.code, 403, channel.json_body)
+
+ # Check that the mock was called once.
+ shutdown_mock.assert_called_once()
+ args = shutdown_mock.call_args[0]
+
+ # Check that the mock was called with the right user ID
+ self.assertEqual(args[0], admin_user_id)
+
+ # Check that the mock was called with the right room ID
+ self.assertEqual(args[1], self.room_id)
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 3b5747cb12..8d8251b2ac 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -1,3 +1,18 @@
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
from unittest.mock import Mock, call
from twisted.internet import defer, reactor
@@ -11,14 +26,14 @@ from tests.utils import MockClock
class HttpTransactionCacheTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.clock = MockClock()
self.hs = Mock()
self.hs.get_clock = Mock(return_value=self.clock)
self.hs.get_auth = Mock()
self.cache = HttpTransactionCache(self.hs)
- self.mock_http_response = (200, "GOOD JOB!")
+ self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
self.mock_key = "foo"
@defer.inlineCallbacks
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 4672a68596..978c252f84 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -13,19 +13,24 @@
# limitations under the License.
import urllib.parse
from io import BytesIO, StringIO
+from typing import Any, Dict, Optional, Union
from unittest.mock import Mock
import signedjson.key
from canonicaljson import encode_canonical_json
-from nacl.signing import SigningKey
from signedjson.sign import sign_json
+from signedjson.types import SigningKey
-from twisted.web.resource import NoResource
+from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import NoResource, Resource
from synapse.crypto.keyring import PerspectivesKeyFetcher
from synapse.http.site import SynapseRequest
from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult
+from synapse.types import JsonDict
+from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.stringutils import random_string
@@ -35,11 +40,11 @@ from tests.utils import default_config
class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock()
return self.setup_test_homeserver(federation_http_client=self.http_client)
- def create_test_resource(self):
+ def create_test_resource(self) -> Resource:
return create_resource_tree(
{"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
)
@@ -51,7 +56,12 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
Tell the mock http client to expect an outgoing GET request for the given key
"""
- async def get_json(destination, path, ignore_backoff=False, **kwargs):
+ async def get_json(
+ destination: str,
+ path: str,
+ ignore_backoff: bool = False,
+ **kwargs: Any,
+ ) -> Union[JsonDict, list]:
self.assertTrue(ignore_backoff)
self.assertEqual(destination, server_name)
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
@@ -84,7 +94,8 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
Checks that the response is a 200 and returns the decoded json body.
"""
channel = FakeChannel(self.site, self.reactor)
- req = SynapseRequest(channel, self.site)
+ # channel is a `FakeChannel` but `HTTPChannel` is expected
+ req = SynapseRequest(channel, self.site) # type: ignore[arg-type]
req.content = BytesIO(b"")
req.requestReceived(
b"GET",
@@ -97,7 +108,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
resp = channel.json_body
return resp
- def test_get_key(self):
+ def test_get_key(self) -> None:
"""Fetch a remote key"""
SERVER_NAME = "remote.server"
testkey = signedjson.key.generate_signing_key("ver1")
@@ -114,7 +125,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
self.assertIn(SERVER_NAME, keys[0]["signatures"])
self.assertIn(self.hs.hostname, keys[0]["signatures"])
- def test_get_own_key(self):
+ def test_get_own_key(self) -> None:
"""Fetch our own key"""
testkey = signedjson.key.generate_signing_key("ver1")
self.expect_outgoing_key_request(self.hs.hostname, testkey)
@@ -141,7 +152,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
endpoint, to check that the two implementations are compatible.
"""
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
# replace the signing key with our own
@@ -152,7 +163,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
return config
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# make a second homeserver, configured to use the first one as a key notary
self.http_client2 = Mock()
config = default_config(name="keyclient")
@@ -175,7 +186,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
# wire up outbound POST /key/v2/query requests from hs2 so that they
# will be forwarded to hs1
- async def post_json(destination, path, data):
+ async def post_json(
+ destination: str, path: str, data: Optional[JsonDict] = None
+ ) -> Union[JsonDict, list]:
self.assertEqual(destination, self.hs.hostname)
self.assertEqual(
path,
@@ -183,7 +196,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
)
channel = FakeChannel(self.site, self.reactor)
- req = SynapseRequest(channel, self.site)
+ # channel is a `FakeChannel` but `HTTPChannel` is expected
+ req = SynapseRequest(channel, self.site) # type: ignore[arg-type]
req.content = BytesIO(encode_canonical_json(data))
req.requestReceived(
@@ -198,7 +212,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
self.http_client2.post_json.side_effect = post_json
- def test_get_key(self):
+ def test_get_key(self) -> None:
"""Fetch a key belonging to a random server"""
# make up a key to be fetched.
testkey = signedjson.key.generate_signing_key("abc")
@@ -218,7 +232,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
signedjson.key.encode_verify_key_base64(testkey.verify_key),
)
- def test_get_notary_key(self):
+ def test_get_notary_key(self) -> None:
"""Fetch a key belonging to the notary server"""
# make up a key to be fetched. We randomise the keyid to try to get it to
# appear before the key server signing key sometimes (otherwise we bail out
@@ -240,7 +254,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
signedjson.key.encode_verify_key_base64(testkey.verify_key),
)
- def test_get_notary_keyserver_key(self):
+ def test_get_notary_keyserver_key(self) -> None:
"""Fetch the notary's keyserver key"""
# we expect hs1 to make a regular key request to itself
self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key)
diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py
index f761e23f1b..c73179151a 100644
--- a/tests/rest/media/v1/test_base.py
+++ b/tests/rest/media/v1/test_base.py
@@ -28,11 +28,11 @@ class GetFileNameFromHeadersTests(unittest.TestCase):
b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar",
}
- def tests(self):
+ def tests(self) -> None:
for hdr, expected in self.TEST_CASES.items():
res = get_filename_from_headers({b"Content-Disposition": [hdr]})
self.assertEqual(
res,
expected,
- "expected output for %s to be %s but was %s" % (hdr, expected, res),
+ f"expected output for {hdr!r} to be {expected} but was {res}",
)
diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py
index 913bc530aa..43e6f0f70a 100644
--- a/tests/rest/media/v1/test_filepath.py
+++ b/tests/rest/media/v1/test_filepath.py
@@ -21,12 +21,12 @@ from tests import unittest
class MediaFilePathsTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
super().setUp()
self.filepaths = MediaFilePaths("/media_store")
- def test_local_media_filepath(self):
+ def test_local_media_filepath(self) -> None:
"""Test local media paths"""
self.assertEqual(
self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
@@ -37,7 +37,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_local_media_thumbnail(self):
+ def test_local_media_thumbnail(self) -> None:
"""Test local media thumbnail paths"""
self.assertEqual(
self.filepaths.local_media_thumbnail_rel(
@@ -52,14 +52,14 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
)
- def test_local_media_thumbnail_dir(self):
+ def test_local_media_thumbnail_dir(self) -> None:
"""Test local media thumbnail directory paths"""
self.assertEqual(
self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"),
"/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_remote_media_filepath(self):
+ def test_remote_media_filepath(self) -> None:
"""Test remote media paths"""
self.assertEqual(
self.filepaths.remote_media_filepath_rel(
@@ -74,7 +74,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_remote_media_thumbnail(self):
+ def test_remote_media_thumbnail(self) -> None:
"""Test remote media thumbnail paths"""
self.assertEqual(
self.filepaths.remote_media_thumbnail_rel(
@@ -99,7 +99,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
)
- def test_remote_media_thumbnail_legacy(self):
+ def test_remote_media_thumbnail_legacy(self) -> None:
"""Test old-style remote media thumbnail paths"""
self.assertEqual(
self.filepaths.remote_media_thumbnail_rel_legacy(
@@ -108,7 +108,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg",
)
- def test_remote_media_thumbnail_dir(self):
+ def test_remote_media_thumbnail_dir(self) -> None:
"""Test remote media thumbnail directory paths"""
self.assertEqual(
self.filepaths.remote_media_thumbnail_dir(
@@ -117,7 +117,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_url_cache_filepath(self):
+ def test_url_cache_filepath(self) -> None:
"""Test URL cache paths"""
self.assertEqual(
self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"),
@@ -128,7 +128,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar",
)
- def test_url_cache_filepath_legacy(self):
+ def test_url_cache_filepath_legacy(self) -> None:
"""Test old-style URL cache paths"""
self.assertEqual(
self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
@@ -139,7 +139,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_url_cache_filepath_dirs_to_delete(self):
+ def test_url_cache_filepath_dirs_to_delete(self) -> None:
"""Test URL cache cleanup paths"""
self.assertEqual(
self.filepaths.url_cache_filepath_dirs_to_delete(
@@ -148,7 +148,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
["/media_store/url_cache/2020-01-02"],
)
- def test_url_cache_filepath_dirs_to_delete_legacy(self):
+ def test_url_cache_filepath_dirs_to_delete_legacy(self) -> None:
"""Test old-style URL cache cleanup paths"""
self.assertEqual(
self.filepaths.url_cache_filepath_dirs_to_delete(
@@ -160,7 +160,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_url_cache_thumbnail(self):
+ def test_url_cache_thumbnail(self) -> None:
"""Test URL cache thumbnail paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_rel(
@@ -175,7 +175,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale",
)
- def test_url_cache_thumbnail_legacy(self):
+ def test_url_cache_thumbnail_legacy(self) -> None:
"""Test old-style URL cache thumbnail paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_rel(
@@ -190,7 +190,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
)
- def test_url_cache_thumbnail_directory(self):
+ def test_url_cache_thumbnail_directory(self) -> None:
"""Test URL cache thumbnail directory paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_directory_rel(
@@ -203,7 +203,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
)
- def test_url_cache_thumbnail_directory_legacy(self):
+ def test_url_cache_thumbnail_directory_legacy(self) -> None:
"""Test old-style URL cache thumbnail directory paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_directory_rel(
@@ -216,7 +216,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_url_cache_thumbnail_dirs_to_delete(self):
+ def test_url_cache_thumbnail_dirs_to_delete(self) -> None:
"""Test URL cache thumbnail cleanup paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_dirs_to_delete(
@@ -228,7 +228,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_url_cache_thumbnail_dirs_to_delete_legacy(self):
+ def test_url_cache_thumbnail_dirs_to_delete_legacy(self) -> None:
"""Test old-style URL cache thumbnail cleanup paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_dirs_to_delete(
@@ -241,7 +241,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_server_name_validation(self):
+ def test_server_name_validation(self) -> None:
"""Test validation of server names"""
self._test_path_validation(
[
@@ -274,7 +274,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_file_id_validation(self):
+ def test_file_id_validation(self) -> None:
"""Test validation of local, remote and legacy URL cache file / media IDs"""
# File / media IDs get split into three parts to form paths, consisting of the
# first two characters, next two characters and rest of the ID.
@@ -357,7 +357,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
invalid_values=invalid_file_ids,
)
- def test_url_cache_media_id_validation(self):
+ def test_url_cache_media_id_validation(self) -> None:
"""Test validation of URL cache media IDs"""
self._test_path_validation(
[
@@ -387,7 +387,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_content_type_validation(self):
+ def test_content_type_validation(self) -> None:
"""Test validation of thumbnail content types"""
self._test_path_validation(
[
@@ -410,7 +410,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_thumbnail_method_validation(self):
+ def test_thumbnail_method_validation(self) -> None:
"""Test validation of thumbnail methods"""
self._test_path_validation(
[
@@ -440,7 +440,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
parameter: str,
valid_values: Iterable[str],
invalid_values: Iterable[str],
- ):
+ ) -> None:
"""Test that the specified methods validate the named parameter as expected
Args:
diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py
index a4b57e3d1f..62e308814d 100644
--- a/tests/rest/media/v1/test_html_preview.py
+++ b/tests/rest/media/v1/test_html_preview.py
@@ -16,7 +16,6 @@ from synapse.rest.media.v1.preview_html import (
_get_html_media_encodings,
decode_body,
parse_html_to_open_graph,
- rebase_url,
summarize_paragraphs,
)
@@ -32,7 +31,7 @@ class SummarizeTestCase(unittest.TestCase):
if not lxml:
skip = "url preview feature requires lxml"
- def test_long_summarize(self):
+ def test_long_summarize(self) -> None:
example_paras = [
"""Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in
@@ -90,7 +89,7 @@ class SummarizeTestCase(unittest.TestCase):
" Tromsøya had a population of 36,088. Substantial parts of the urban…",
)
- def test_short_summarize(self):
+ def test_short_summarize(self) -> None:
example_paras = [
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -117,7 +116,7 @@ class SummarizeTestCase(unittest.TestCase):
" most of the year.",
)
- def test_small_then_large_summarize(self):
+ def test_small_then_large_summarize(self) -> None:
example_paras = [
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -150,7 +149,7 @@ class CalcOgTestCase(unittest.TestCase):
if not lxml:
skip = "url preview feature requires lxml"
- def test_simple(self):
+ def test_simple(self) -> None:
html = b"""
<html>
<head><title>Foo</title></head>
@@ -161,11 +160,11 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
- def test_comment(self):
+ def test_comment(self) -> None:
html = b"""
<html>
<head><title>Foo</title></head>
@@ -177,11 +176,11 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
- def test_comment2(self):
+ def test_comment2(self) -> None:
html = b"""
<html>
<head><title>Foo</title></head>
@@ -196,7 +195,7 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(
og,
@@ -206,7 +205,7 @@ class CalcOgTestCase(unittest.TestCase):
},
)
- def test_script(self):
+ def test_script(self) -> None:
html = b"""
<html>
<head><title>Foo</title></head>
@@ -218,11 +217,11 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
- def test_missing_title(self):
+ def test_missing_title(self) -> None:
html = b"""
<html>
<body>
@@ -232,11 +231,11 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
- def test_h1_as_title(self):
+ def test_h1_as_title(self) -> None:
html = b"""
<html>
<meta property="og:description" content="Some text."/>
@@ -247,11 +246,11 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
- def test_missing_title_and_broken_h1(self):
+ def test_missing_title_and_broken_h1(self) -> None:
html = b"""
<html>
<body>
@@ -262,23 +261,23 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
- def test_empty(self):
+ def test_empty(self) -> None:
"""Test a body with no data in it."""
html = b""
tree = decode_body(html, "http://example.com/test.html")
self.assertIsNone(tree)
- def test_no_tree(self):
+ def test_no_tree(self) -> None:
"""A valid body with no tree in it."""
html = b"\x00"
tree = decode_body(html, "http://example.com/test.html")
self.assertIsNone(tree)
- def test_xml(self):
+ def test_xml(self) -> None:
"""Test decoding XML and ensure it works properly."""
# Note that the strip() call is important to ensure the xml tag starts
# at the initial byte.
@@ -290,10 +289,10 @@ class CalcOgTestCase(unittest.TestCase):
<head><title>Foo</title></head><body>Some text.</body></html>
""".strip()
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
- def test_invalid_encoding(self):
+ def test_invalid_encoding(self) -> None:
"""An invalid character encoding should be ignored and treated as UTF-8, if possible."""
html = b"""
<html>
@@ -304,10 +303,10 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
- def test_invalid_encoding2(self):
+ def test_invalid_encoding2(self) -> None:
"""A body which doesn't match the sent character encoding."""
# Note that this contains an invalid UTF-8 sequence in the title.
html = b"""
@@ -319,10 +318,10 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
- def test_windows_1252(self):
+ def test_windows_1252(self) -> None:
"""A body which uses cp1252, but doesn't declare that."""
html = b"""
<html>
@@ -333,12 +332,12 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
class MediaEncodingTestCase(unittest.TestCase):
- def test_meta_charset(self):
+ def test_meta_charset(self) -> None:
"""A character encoding is found via the meta tag."""
encodings = _get_html_media_encodings(
b"""
@@ -363,7 +362,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
- def test_meta_charset_underscores(self):
+ def test_meta_charset_underscores(self) -> None:
"""A character encoding contains underscore."""
encodings = _get_html_media_encodings(
b"""
@@ -376,7 +375,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"])
- def test_xml_encoding(self):
+ def test_xml_encoding(self) -> None:
"""A character encoding is found via the meta tag."""
encodings = _get_html_media_encodings(
b"""
@@ -388,7 +387,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
- def test_meta_xml_encoding(self):
+ def test_meta_xml_encoding(self) -> None:
"""Meta tags take precedence over XML encoding."""
encodings = _get_html_media_encodings(
b"""
@@ -402,7 +401,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"])
- def test_content_type(self):
+ def test_content_type(self) -> None:
"""A character encoding is found via the Content-Type header."""
# Test a few variations of the header.
headers = (
@@ -417,12 +416,12 @@ class MediaEncodingTestCase(unittest.TestCase):
encodings = _get_html_media_encodings(b"", header)
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
- def test_fallback(self):
+ def test_fallback(self) -> None:
"""A character encoding cannot be found in the body or header."""
encodings = _get_html_media_encodings(b"", "text/html")
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
- def test_duplicates(self):
+ def test_duplicates(self) -> None:
"""Ensure each encoding is only attempted once."""
encodings = _get_html_media_encodings(
b"""
@@ -436,7 +435,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
- def test_unknown_invalid(self):
+ def test_unknown_invalid(self) -> None:
"""A character encoding should be ignored if it is unknown or invalid."""
encodings = _get_html_media_encodings(
b"""
@@ -448,34 +447,3 @@ class MediaEncodingTestCase(unittest.TestCase):
'text/html; charset="invalid"',
)
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
-
-
-class RebaseUrlTestCase(unittest.TestCase):
- def test_relative(self):
- """Relative URLs should be resolved based on the context of the base URL."""
- self.assertEqual(
- rebase_url("subpage", "https://example.com/foo/"),
- "https://example.com/foo/subpage",
- )
- self.assertEqual(
- rebase_url("sibling", "https://example.com/foo"),
- "https://example.com/sibling",
- )
- self.assertEqual(
- rebase_url("/bar", "https://example.com/foo/"),
- "https://example.com/bar",
- )
-
- def test_absolute(self):
- """Absolute URLs should not be modified."""
- self.assertEqual(
- rebase_url("https://alice.com/a/", "https://example.com/foo/"),
- "https://alice.com/a/",
- )
-
- def test_data(self):
- """Data URLs should not be modified."""
- self.assertEqual(
- rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"),
- "data:,Hello%2C%20World%21",
- )
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index cba9be17c4..7204b2dfe0 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -16,7 +16,7 @@ import shutil
import tempfile
from binascii import unhexlify
from io import BytesIO
-from typing import Optional
+from typing import Any, BinaryIO, Dict, List, Optional, Union
from unittest.mock import Mock
from urllib import parse
@@ -26,18 +26,24 @@ from PIL import Image as Image
from twisted.internet import defer
from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactor
+from synapse.events import EventBase
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.logging.context import make_deferred_yieldable
+from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
-from synapse.rest.media.v1.media_storage import MediaStorage
+from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
+from synapse.server import HomeServer
+from synapse.types import RoomAlias
+from synapse.util import Clock
from tests import unittest
-from tests.server import FakeSite, make_request
+from tests.server import FakeChannel, FakeSite, make_request
from tests.test_utils import SMALL_PNG
from tests.utils import default_config
@@ -46,7 +52,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
needs_threadpool = True
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
self.addCleanup(shutil.rmtree, self.test_dir)
@@ -62,7 +68,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
hs, self.primary_base_path, self.filepaths, storage_providers
)
- def test_ensure_media_is_in_local_cache(self):
+ def test_ensure_media_is_in_local_cache(self) -> None:
media_id = "some_media_id"
test_body = "Test\n"
@@ -105,7 +111,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
self.assertEqual(test_body, body)
-@attr.s(slots=True, frozen=True)
+@attr.s(auto_attribs=True, slots=True, frozen=True)
class _TestImage:
"""An image for testing thumbnailing with the expected results
@@ -121,18 +127,18 @@ class _TestImage:
a 404 is expected.
"""
- data = attr.ib(type=bytes)
- content_type = attr.ib(type=bytes)
- extension = attr.ib(type=bytes)
- expected_cropped = attr.ib(type=Optional[bytes], default=None)
- expected_scaled = attr.ib(type=Optional[bytes], default=None)
- expected_found = attr.ib(default=True, type=bool)
+ data: bytes
+ content_type: bytes
+ extension: bytes
+ expected_cropped: Optional[bytes] = None
+ expected_scaled: Optional[bytes] = None
+ expected_found: bool = True
@parameterized_class(
("test_image",),
[
- # smoll png
+ # small png
(
_TestImage(
SMALL_PNG,
@@ -193,11 +199,17 @@ class MediaRepoTests(unittest.HomeserverTestCase):
hijack_auth = True
user_id = "@test:user"
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.fetches = []
- def get_file(destination, path, output_stream, args=None, max_size=None):
+ def get_file(
+ destination: str,
+ path: str,
+ output_stream: BinaryIO,
+ args: Optional[Dict[str, Union[str, List[str]]]] = None,
+ max_size: Optional[int] = None,
+ ) -> Deferred:
"""
Returns tuple[int,dict,str,int] of file length, response headers,
absolute URI, and response code.
@@ -238,7 +250,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
media_resource = hs.get_media_repository_resource()
self.download_resource = media_resource.children[b"download"]
@@ -248,8 +260,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.media_id = "example.com/12345"
- def _req(self, content_disposition, include_content_type=True):
-
+ def _req(
+ self, content_disposition: Optional[bytes], include_content_type: bool = True
+ ) -> FakeChannel:
channel = make_request(
self.reactor,
FakeSite(self.download_resource, self.reactor),
@@ -288,7 +301,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
return channel
- def test_handle_missing_content_type(self):
+ def test_handle_missing_content_type(self) -> None:
channel = self._req(
b"inline; filename=out" + self.test_image.extension,
include_content_type=False,
@@ -299,7 +312,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
)
- def test_disposition_filename_ascii(self):
+ def test_disposition_filename_ascii(self) -> None:
"""
If the filename is filename=<ascii> then Synapse will decode it as an
ASCII string, and use filename= in the response.
@@ -315,7 +328,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
[b"inline; filename=out" + self.test_image.extension],
)
- def test_disposition_filenamestar_utf8escaped(self):
+ def test_disposition_filenamestar_utf8escaped(self) -> None:
"""
If the filename is filename=*utf8''<utf8 escaped> then Synapse will
correctly decode it as the UTF-8 string, and use filename* in the
@@ -335,7 +348,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
[b"inline; filename*=utf-8''" + filename + self.test_image.extension],
)
- def test_disposition_none(self):
+ def test_disposition_none(self) -> None:
"""
If there is no filename, one isn't passed on in the Content-Disposition
of the request.
@@ -348,26 +361,26 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
- def test_thumbnail_crop(self):
+ def test_thumbnail_crop(self) -> None:
"""Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
"crop", self.test_image.expected_cropped, self.test_image.expected_found
)
- def test_thumbnail_scale(self):
+ def test_thumbnail_scale(self) -> None:
"""Test that a scaled remote thumbnail is available."""
self._test_thumbnail(
"scale", self.test_image.expected_scaled, self.test_image.expected_found
)
- def test_invalid_type(self):
+ def test_invalid_type(self) -> None:
"""An invalid thumbnail type is never available."""
self._test_thumbnail("invalid", None, False)
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
)
- def test_no_thumbnail_crop(self):
+ def test_no_thumbnail_crop(self) -> None:
"""
Override the config to generate only scaled thumbnails, but request a cropped one.
"""
@@ -376,13 +389,13 @@ class MediaRepoTests(unittest.HomeserverTestCase):
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
)
- def test_no_thumbnail_scale(self):
+ def test_no_thumbnail_scale(self) -> None:
"""
Override the config to generate only cropped thumbnails, but request a scaled one.
"""
self._test_thumbnail("scale", None, False)
- def test_thumbnail_repeated_thumbnail(self):
+ def test_thumbnail_repeated_thumbnail(self) -> None:
"""Test that fetching the same thumbnail works, and deleting the on disk
thumbnail regenerates it.
"""
@@ -443,7 +456,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.result["body"],
)
- def _test_thumbnail(self, method, expected_body, expected_found):
+ def _test_thumbnail(
+ self, method: str, expected_body: Optional[bytes], expected_found: bool
+ ) -> None:
params = "?width=32&height=32&method=" + method
channel = make_request(
self.reactor,
@@ -485,7 +500,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
@parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
- def test_same_quality(self, method, desired_size):
+ def test_same_quality(self, method: str, desired_size: int) -> None:
"""Test that choosing between thumbnails with the same quality rating succeeds.
We are not particular about which thumbnail is chosen."""
@@ -521,7 +536,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
)
- def test_x_robots_tag_header(self):
+ def test_x_robots_tag_header(self) -> None:
"""
Tests that the `X-Robots-Tag` header is present, which informs web crawlers
to not index, archive, or follow links in media.
@@ -540,29 +555,38 @@ class TestSpamChecker:
`evil`.
"""
- def __init__(self, config, api):
+ def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None:
self.config = config
self.api = api
- def parse_config(config):
+ def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
return config
- async def check_event_for_spam(self, foo):
+ async def check_event_for_spam(self, event: EventBase) -> Union[bool, str]:
return False # allow all events
- async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+ async def user_may_invite(
+ self,
+ inviter_userid: str,
+ invitee_userid: str,
+ room_id: str,
+ ) -> bool:
return True # allow all invites
- async def user_may_create_room(self, userid):
+ async def user_may_create_room(self, userid: str) -> bool:
return True # allow all room creations
- async def user_may_create_room_alias(self, userid, room_alias):
+ async def user_may_create_room_alias(
+ self, userid: str, room_alias: RoomAlias
+ ) -> bool:
return True # allow all room aliases
- async def user_may_publish_room(self, userid, room_id):
+ async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
return True # allow publishing of all rooms
- async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
+ async def check_media_file_for_spam(
+ self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
+ ) -> bool:
buf = BytesIO()
await file_wrapper.write_chunks_to(buf.write)
@@ -575,7 +599,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
admin.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user = self.register_user("user", "pass")
self.tok = self.login("user", "pass")
@@ -586,7 +610,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
load_legacy_spam_checkers(hs)
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = default_config("test")
config.update(
@@ -602,13 +626,13 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
return config
- def test_upload_innocent(self):
+ def test_upload_innocent(self) -> None:
"""Attempt to upload some innocent data that should be allowed."""
self.helper.upload_media(
self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
)
- def test_upload_ban(self):
+ def test_upload_ban(self) -> None:
"""Attempt to upload some data that includes bytes "evil", which should
get rejected by the spam checker.
"""
diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py
index 048d0ca44a..f38d7225f8 100644
--- a/tests/rest/media/v1/test_oembed.py
+++ b/tests/rest/media/v1/test_oembed.py
@@ -16,7 +16,7 @@ import json
from twisted.test.proto_helpers import MemoryReactor
-from synapse.rest.media.v1.oembed import OEmbedProvider
+from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -25,15 +25,15 @@ from tests.unittest import HomeserverTestCase
class OEmbedTests(HomeserverTestCase):
- def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
- self.oembed = OEmbedProvider(homeserver)
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.oembed = OEmbedProvider(hs)
- def parse_response(self, response: JsonDict):
+ def parse_response(self, response: JsonDict) -> OEmbedResult:
return self.oembed.parse_oembed_response(
"https://test", json.dumps(response).encode("utf-8")
)
- def test_version(self):
+ def test_version(self) -> None:
"""Accept versions that are similar to 1.0 as a string or int (or missing)."""
for version in ("1.0", 1.0, 1):
result = self.parse_response({"version": version, "type": "link"})
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index da2c533260..5148c39874 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -16,16 +16,21 @@ import base64
import json
import os
import re
+from typing import Any, Dict, Optional, Sequence, Tuple, Type
from urllib.parse import urlencode
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError
-from twisted.test.proto_helpers import AccumulatingProtocol
+from twisted.internet.interfaces import IAddress, IResolutionReceiver
+from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
from synapse.config.oembed import OEmbedEndpointConfig
+from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
+from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util import Clock
from synapse.util.stringutils import parse_and_validate_mxc_uri
from tests import unittest
@@ -52,7 +57,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>"
)
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["url_preview_enabled"] = True
@@ -113,22 +118,22 @@ class URLPreviewTests(unittest.HomeserverTestCase):
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.media_repo = hs.get_media_repository_resource()
self.preview_url = self.media_repo.children[b"preview_url"]
- self.lookups = {}
+ self.lookups: Dict[str, Any] = {}
class Resolver:
def resolveHostName(
_self,
- resolutionReceiver,
- hostName,
- portNumber=0,
- addressTypes=None,
- transportSemantics="TCP",
- ):
+ resolutionReceiver: IResolutionReceiver,
+ hostName: str,
+ portNumber: int = 0,
+ addressTypes: Optional[Sequence[Type[IAddress]]] = None,
+ transportSemantics: str = "TCP",
+ ) -> IResolutionReceiver:
resolution = HostResolution(hostName)
resolutionReceiver.resolutionBegan(resolution)
@@ -140,9 +145,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
resolutionReceiver.resolutionComplete()
return resolutionReceiver
- self.reactor.nameResolver = Resolver()
+ self.reactor.nameResolver = Resolver() # type: ignore[assignment]
- def create_test_resource(self):
+ def create_test_resource(self) -> MediaRepositoryResource:
return self.hs.get_media_repository_resource()
def _assert_small_png(self, json_body: JsonDict) -> None:
@@ -153,7 +158,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(json_body["og:image:type"], "image/png")
self.assertEqual(json_body["matrix:image:size"], 67)
- def test_cache_returns_correct_type(self):
+ def test_cache_returns_correct_type(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
channel = self.make_request(
@@ -207,7 +212,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
- def test_non_ascii_preview_httpequiv(self):
+ def test_non_ascii_preview_httpequiv(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
@@ -243,7 +248,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
- def test_video_rejected(self):
+ def test_video_rejected(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = b"anything"
@@ -279,7 +284,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_audio_rejected(self):
+ def test_audio_rejected(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = b"anything"
@@ -315,7 +320,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_non_ascii_preview_content_type(self):
+ def test_non_ascii_preview_content_type(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
@@ -350,7 +355,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
- def test_overlong_title(self):
+ def test_overlong_title(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
@@ -387,7 +392,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# We should only see the `og:description` field, as `title` is too long and should be stripped out
self.assertCountEqual(["og:description"], res.keys())
- def test_ipaddr(self):
+ def test_ipaddr(self) -> None:
"""
IP addresses can be previewed directly.
"""
@@ -417,7 +422,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
- def test_blacklisted_ip_specific(self):
+ def test_blacklisted_ip_specific(self) -> None:
"""
Blacklisted IP addresses, found via DNS, are not spidered.
"""
@@ -438,7 +443,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_blacklisted_ip_range(self):
+ def test_blacklisted_ip_range(self) -> None:
"""
Blacklisted IP ranges, IPs found over DNS, are not spidered.
"""
@@ -457,7 +462,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_blacklisted_ip_specific_direct(self):
+ def test_blacklisted_ip_specific_direct(self) -> None:
"""
Blacklisted IP addresses, accessed directly, are not spidered.
"""
@@ -476,7 +481,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 403)
- def test_blacklisted_ip_range_direct(self):
+ def test_blacklisted_ip_range_direct(self) -> None:
"""
Blacklisted IP ranges, accessed directly, are not spidered.
"""
@@ -493,7 +498,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_blacklisted_ip_range_whitelisted_ip(self):
+ def test_blacklisted_ip_range_whitelisted_ip(self) -> None:
"""
Blacklisted but then subsequently whitelisted IP addresses can be
spidered.
@@ -526,7 +531,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
- def test_blacklisted_ip_with_external_ip(self):
+ def test_blacklisted_ip_with_external_ip(self) -> None:
"""
If a hostname resolves a blacklisted IP, even if there's a
non-blacklisted one, it will be rejected.
@@ -549,7 +554,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_blacklisted_ipv6_specific(self):
+ def test_blacklisted_ipv6_specific(self) -> None:
"""
Blacklisted IP addresses, found via DNS, are not spidered.
"""
@@ -572,7 +577,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_blacklisted_ipv6_range(self):
+ def test_blacklisted_ipv6_range(self) -> None:
"""
Blacklisted IP ranges, IPs found over DNS, are not spidered.
"""
@@ -591,7 +596,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_OPTIONS(self):
+ def test_OPTIONS(self) -> None:
"""
OPTIONS returns the OPTIONS.
"""
@@ -601,7 +606,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {})
- def test_accept_language_config_option(self):
+ def test_accept_language_config_option(self) -> None:
"""
Accept-Language header is sent to the remote server
"""
@@ -652,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
server.data,
)
- def test_data_url(self):
+ def test_data_url(self) -> None:
"""
Requesting to preview a data URL is not supported.
"""
@@ -675,7 +680,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 500)
- def test_inline_data_url(self):
+ def test_inline_data_url(self) -> None:
"""
An inline image (as a data URL) should be parsed properly.
"""
@@ -712,7 +717,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self._assert_small_png(channel.json_body)
- def test_oembed_photo(self):
+ def test_oembed_photo(self) -> None:
"""Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
@@ -771,7 +776,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345")
self._assert_small_png(body)
- def test_oembed_rich(self):
+ def test_oembed_rich(self) -> None:
"""Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
@@ -817,7 +822,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_oembed_format(self):
+ def test_oembed_format(self) -> None:
"""Test an oEmbed endpoint which requires the format in the URL."""
self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")]
@@ -866,7 +871,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_oembed_autodiscovery(self):
+ def test_oembed_autodiscovery(self) -> None:
"""
Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL.
1. Request a preview of a URL which is not known to the oEmbed code.
@@ -962,7 +967,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
self._assert_small_png(body)
- def _download_image(self):
+ def _download_image(self) -> Tuple[str, str]:
"""Downloads an image into the URL cache.
Returns:
A (host, media_id) tuple representing the MXC URI of the image.
@@ -995,7 +1000,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertIsNone(_port)
return host, media_id
- def test_storage_providers_exclude_files(self):
+ def test_storage_providers_exclude_files(self) -> None:
"""Test that files are not stored in or fetched from storage providers."""
host, media_id = self._download_image()
@@ -1037,7 +1042,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"URL cache file was unexpectedly retrieved from a storage provider",
)
- def test_storage_providers_exclude_thumbnails(self):
+ def test_storage_providers_exclude_thumbnails(self) -> None:
"""Test that thumbnails are not stored in or fetched from storage providers."""
host, media_id = self._download_image()
@@ -1090,7 +1095,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"URL cache thumbnail was unexpectedly retrieved from a storage provider",
)
- def test_cache_expiry(self):
+ def test_cache_expiry(self) -> None:
"""Test that URL cache files and thumbnails are cleaned up properly on expiry."""
self.preview_url.clock = MockClock()
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index 01d48c3860..da325955f8 100644
--- a/tests/rest/test_health.py
+++ b/tests/rest/test_health.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from http import HTTPStatus
from synapse.rest.health import HealthResource
@@ -19,12 +19,12 @@ from tests import unittest
class HealthCheckTests(unittest.HomeserverTestCase):
- def create_test_resource(self):
+ def create_test_resource(self) -> HealthResource:
# replace the JsonResource with a HealthResource.
return HealthResource()
- def test_health(self):
+ def test_health(self) -> None:
channel = self.make_request("GET", "/health", shorthand=False)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index 118aa93a32..11f78f52b8 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from http import HTTPStatus
+
from twisted.web.resource import Resource
from synapse.rest.well_known import well_known_resource
@@ -19,7 +21,7 @@ from tests import unittest
class WellKnownTests(unittest.HomeserverTestCase):
- def create_test_resource(self):
+ def create_test_resource(self) -> Resource:
# replace the JsonResource with a Resource wrapping the WellKnownResource
res = Resource()
res.putChild(b".well-known", well_known_resource(self.hs))
@@ -31,12 +33,12 @@ class WellKnownTests(unittest.HomeserverTestCase):
"default_identity_server": "https://testis",
}
)
- def test_client_well_known(self):
+ def test_client_well_known(self) -> None:
channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(
channel.json_body,
{
@@ -50,27 +52,27 @@ class WellKnownTests(unittest.HomeserverTestCase):
"public_baseurl": None,
}
)
- def test_client_well_known_no_public_baseurl(self):
+ def test_client_well_known_no_public_baseurl(self) -> None:
channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
@unittest.override_config({"serve_server_wellknown": True})
- def test_server_well_known(self):
+ def test_server_well_known(self) -> None:
channel = self.make_request(
"GET", "/.well-known/matrix/server", shorthand=False
)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(
channel.json_body,
{"m.server": "test:443"},
)
- def test_server_well_known_disabled(self):
+ def test_server_well_known_disabled(self) -> None:
channel = self.make_request(
"GET", "/.well-known/matrix/server", shorthand=False
)
- self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
diff --git a/tests/server.py b/tests/server.py
index 82990c2eb9..6ce2a17bf4 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -54,13 +54,18 @@ from twisted.internet.interfaces import (
ITransport,
)
from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
+from twisted.test.proto_helpers import (
+ AccumulatingProtocol,
+ MemoryReactor,
+ MemoryReactorClock,
+)
from twisted.web.http_headers import Headers
from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
from synapse.http.site import SynapseRequest
+from synapse.logging.context import ContextResourceUsage
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
@@ -88,18 +93,19 @@ class TimedOutException(Exception):
"""
-@attr.s
+@attr.s(auto_attribs=True)
class FakeChannel:
"""
A fake Twisted Web Channel (the part that interfaces with the
wire).
"""
- site = attr.ib(type=Union[Site, "FakeSite"])
- _reactor = attr.ib()
- result = attr.ib(type=dict, default=attr.Factory(dict))
- _ip = attr.ib(type=str, default="127.0.0.1")
+ site: Union[Site, "FakeSite"]
+ _reactor: MemoryReactor
+ result: dict = attr.Factory(dict)
+ _ip: str = "127.0.0.1"
_producer: Optional[Union[IPullProducer, IPushProducer]] = None
+ resource_usage: Optional[ContextResourceUsage] = None
@property
def json_body(self):
@@ -168,6 +174,8 @@ class FakeChannel:
def requestDone(self, _self):
self.result["done"] = True
+ if isinstance(_self, SynapseRequest):
+ self.resource_usage = _self.logcontext.get_resource_usage()
def getPeer(self):
# We give an address so that getClientIP returns a non null entry,
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 272cd35402..72bf5b3d31 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -47,9 +47,18 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
expected_ignorer_user_ids,
)
+ def assert_ignored(
+ self, ignorer_user_id: str, expected_ignored_user_ids: Set[str]
+ ) -> None:
+ self.assertEqual(
+ self.get_success(self.store.ignored_users(ignorer_user_id)),
+ expected_ignored_user_ids,
+ )
+
def test_ignoring_users(self):
"""Basic adding/removing of users from the ignore list."""
self._update_ignore_list("@other:test", "@another:remote")
+ self.assert_ignored(self.user, {"@other:test", "@another:remote"})
# Check a user which no one ignores.
self.assert_ignorers("@user:test", set())
@@ -62,6 +71,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
# Add one user, remove one user, and leave one user.
self._update_ignore_list("@foo:test", "@another:remote")
+ self.assert_ignored(self.user, {"@foo:test", "@another:remote"})
# Check the removed user.
self.assert_ignorers("@other:test", set())
@@ -76,20 +86,24 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
"""Ensure that caching works properly between different users."""
# The first user ignores a user.
self._update_ignore_list("@other:test")
+ self.assert_ignored(self.user, {"@other:test"})
self.assert_ignorers("@other:test", {self.user})
# The second user ignores them.
self._update_ignore_list("@other:test", ignorer_user_id="@second:test")
+ self.assert_ignored("@second:test", {"@other:test"})
self.assert_ignorers("@other:test", {self.user, "@second:test"})
# The first user un-ignores them.
self._update_ignore_list()
+ self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", {"@second:test"})
def test_invalid_data(self):
"""Invalid data ends up clearing out the ignored users list."""
# Add some data and ensure it is there.
self._update_ignore_list("@other:test")
+ self.assert_ignored(self.user, {"@other:test"})
self.assert_ignorers("@other:test", {self.user})
# No ignored_users key.
@@ -102,10 +116,12 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
)
# No one ignores the user now.
+ self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", set())
# Add some data and ensure it is there.
self._update_ignore_list("@other:test")
+ self.assert_ignored(self.user, {"@other:test"})
self.assert_ignorers("@other:test", {self.user})
# Invalid data.
@@ -118,4 +134,5 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
)
# No one ignores the user now.
+ self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", set())
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 39dcc094bd..fd619b64d4 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -14,16 +14,23 @@
from unittest.mock import Mock
+import yaml
+
from twisted.internet.defer import Deferred, ensureDeferred
+from twisted.test.proto_helpers import MemoryReactor
+from synapse.server import HomeServer
from synapse.storage.background_updates import BackgroundUpdater
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable, simple_async_mock
+from tests.unittest import override_config
class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
# the base test class should have run the real bg updates for us
self.assertTrue(
@@ -34,50 +41,50 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.updates.register_background_update_handler(
"test_update", self.update_handler
)
+ self.store = self.hs.get_datastores().main
+
+ async def update(self, progress: JsonDict, count: int) -> int:
+ duration_ms = 10
+ await self.clock.sleep((count * duration_ms) / 1000)
+ progress = {"my_key": progress["my_key"] + 1}
+ await self.store.db_pool.runInteraction(
+ "update_progress",
+ self.updates._background_update_progress_txn,
+ "test_update",
+ progress,
+ )
+ return count
- def test_do_background_update(self):
+ def test_do_background_update(self) -> None:
# the time we claim it takes to update one item when running the update
duration_ms = 10
# the target runtime for each bg update
target_background_update_duration_ms = 100
- store = self.hs.get_datastores().main
self.get_success(
- store.db_pool.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
)
)
- # first step: make a bit of progress
- async def update(progress, count):
- await self.clock.sleep((count * duration_ms) / 1000)
- progress = {"my_key": progress["my_key"] + 1}
- await store.db_pool.runInteraction(
- "update_progress",
- self.updates._background_update_progress_txn,
- "test_update",
- progress,
- )
- return count
-
- self.update_handler.side_effect = update
+ self.update_handler.side_effect = self.update
self.update_handler.reset_mock()
res = self.get_success(
self.updates.do_next_background_update(False),
- by=0.01,
+ by=0.02,
)
self.assertFalse(res)
# on the first call, we should get run with the default background update size
self.update_handler.assert_called_once_with(
- {"my_key": 1}, self.updates.MINIMUM_BACKGROUND_BATCH_SIZE
+ {"my_key": 1}, self.updates.default_background_batch_size
)
# second step: complete the update
# we should now get run with a much bigger number of items to update
- async def update(progress, count):
+ async def update(progress: JsonDict, count: int) -> int:
self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual(
count,
@@ -99,16 +106,234 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.assertTrue(result)
self.assertFalse(self.update_handler.called)
+ @override_config(
+ yaml.safe_load(
+ """
+ background_updates:
+ default_batch_size: 20
+ """
+ )
+ )
+ def test_background_update_default_batch_set_by_config(self) -> None:
+ """
+ Test that the background update is run with the default_batch_size set by the config
+ """
+
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
+ )
+ )
+
+ self.update_handler.side_effect = self.update
+ self.update_handler.reset_mock()
+ res = self.get_success(
+ self.updates.do_next_background_update(False),
+ by=0.01,
+ )
+ self.assertFalse(res)
+
+ # on the first call, we should get run with the default background update size specified in the config
+ self.update_handler.assert_called_once_with({"my_key": 1}, 20)
+
+ def test_background_update_default_sleep_behavior(self) -> None:
+ """
+ Test default background update behavior, which is to sleep
+ """
+
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
+ )
+ )
+
+ self.update_handler.side_effect = self.update
+ self.update_handler.reset_mock()
+ self.updates.start_doing_background_updates()
+
+ # 2: advance the reactor less than the default sleep duration (1000ms)
+ self.reactor.pump([0.5])
+ # check that an update has not been run
+ self.update_handler.assert_not_called()
+
+ # advance reactor past default sleep duration
+ self.reactor.pump([1])
+ # check that update has been run
+ self.update_handler.assert_called()
+
+ @override_config(
+ yaml.safe_load(
+ """
+ background_updates:
+ sleep_duration_ms: 500
+ """
+ )
+ )
+ def test_background_update_sleep_set_in_config(self) -> None:
+ """
+ Test that changing the sleep time in the config changes how long it sleeps
+ """
+
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
+ )
+ )
+
+ self.update_handler.side_effect = self.update
+ self.update_handler.reset_mock()
+ self.updates.start_doing_background_updates()
+
+ # 2: advance the reactor less than the configured sleep duration (500ms)
+ self.reactor.pump([0.45])
+ # check that an update has not been run
+ self.update_handler.assert_not_called()
+
+ # advance reactor past config sleep duration but less than default duration
+ self.reactor.pump([0.75])
+ # check that update has been run
+ self.update_handler.assert_called()
+
+ @override_config(
+ yaml.safe_load(
+ """
+ background_updates:
+ sleep_enabled: false
+ """
+ )
+ )
+ def test_disabling_background_update_sleep(self) -> None:
+ """
+ Test that disabling sleep in the config results in bg update not sleeping
+ """
+
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
+ )
+ )
+
+ self.update_handler.side_effect = self.update
+ self.update_handler.reset_mock()
+ self.updates.start_doing_background_updates()
+
+ # 2: advance the reactor very little
+ self.reactor.pump([0.025])
+ # check that an update has run
+ self.update_handler.assert_called()
+
+ @override_config(
+ yaml.safe_load(
+ """
+ background_updates:
+ background_update_duration_ms: 500
+ """
+ )
+ )
+ def test_background_update_duration_set_in_config(self) -> None:
+ """
+ Test that the desired duration set in the config is used in determining batch size
+ """
+ # Duration of one background update item
+ duration_ms = 10
+
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
+ )
+ )
+
+ self.update_handler.side_effect = self.update
+ self.update_handler.reset_mock()
+ res = self.get_success(
+ self.updates.do_next_background_update(False),
+ by=0.02,
+ )
+ self.assertFalse(res)
+
+ # the first update was run with the default batch size, this should be run with 500ms as the
+ # desired duration
+ async def update(progress: JsonDict, count: int) -> int:
+ self.assertEqual(progress, {"my_key": 2})
+ self.assertAlmostEqual(
+ count,
+ 500 / duration_ms,
+ places=0,
+ )
+ await self.updates._end_background_update("test_update")
+ return count
+
+ self.update_handler.side_effect = update
+ self.get_success(self.updates.do_next_background_update(False))
+
+ @override_config(
+ yaml.safe_load(
+ """
+ background_updates:
+ min_batch_size: 5
+ """
+ )
+ )
+ def test_background_update_min_batch_set_in_config(self) -> None:
+ """
+ Test that the minimum batch size set in the config is used
+ """
+ # a very long-running individual update
+ duration_ms = 50
+
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
+ )
+ )
+
+ # Run the update with the long-running update item
+ async def update_long(progress: JsonDict, count: int) -> int:
+ await self.clock.sleep((count * duration_ms) / 1000)
+ progress = {"my_key": progress["my_key"] + 1}
+ await self.store.db_pool.runInteraction(
+ "update_progress",
+ self.updates._background_update_progress_txn,
+ "test_update",
+ progress,
+ )
+ return count
+
+ self.update_handler.side_effect = update_long
+ self.update_handler.reset_mock()
+ res = self.get_success(
+ self.updates.do_next_background_update(False),
+ by=1,
+ )
+ self.assertFalse(res)
+
+ # the first update was run with the default batch size, this should be run with minimum batch size
+ # as the first items took a very long time
+ async def update_short(progress: JsonDict, count: int) -> int:
+ self.assertEqual(progress, {"my_key": 2})
+ self.assertEqual(count, 5)
+ await self.updates._end_background_update("test_update")
+ return count
+
+ self.update_handler.side_effect = update_short
+ self.get_success(self.updates.do_next_background_update(False))
+
class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
# the base test class should have run the real bg updates for us
self.assertTrue(
self.get_success(self.updates.has_completed_background_updates())
)
- self.update_deferred = Deferred()
+ self.update_deferred: Deferred[int] = Deferred()
self.update_handler = Mock(return_value=self.update_deferred)
self.updates.register_background_update_handler(
"test_update", self.update_handler
@@ -137,7 +362,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
),
)
- def test_controller(self):
+ def test_controller(self) -> None:
store = self.hs.get_datastores().main
self.get_success(
store.db_pool.simple_insert(
@@ -147,7 +372,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
)
# Set the return value for the context manager.
- enter_defer = Deferred()
+ enter_defer: Deferred[int] = Deferred()
self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer)
# Start the background update.
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
index 6fbac0ab14..a40fc20ef9 100644
--- a/tests/storage/test_database.py
+++ b/tests/storage/test_database.py
@@ -12,25 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.database import make_tuple_comparison_clause
-from synapse.storage.engines import BaseDatabaseEngine
+from typing import Callable, Tuple
+from unittest.mock import Mock, call
-from tests import unittest
+from twisted.internet import defer
+from twisted.internet.defer import CancelledError, Deferred
+from twisted.test.proto_helpers import MemoryReactor
+from synapse.server import HomeServer
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingTransaction,
+ make_tuple_comparison_clause,
+)
+from synapse.util import Clock
-def _stub_db_engine(**kwargs) -> BaseDatabaseEngine:
- # returns a DatabaseEngine, circumventing the abc mechanism
- # any kwargs are set as attributes on the class before instantiating it
- t = type(
- "TestBaseDatabaseEngine",
- (BaseDatabaseEngine,),
- dict(BaseDatabaseEngine.__dict__),
- )
- # defeat the abc mechanism
- t.__abstractmethods__ = set()
- for k, v in kwargs.items():
- setattr(t, k, v)
- return t(None, None)
+from tests import unittest
class TupleComparisonClauseTestCase(unittest.TestCase):
@@ -38,3 +35,150 @@ class TupleComparisonClauseTestCase(unittest.TestCase):
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
self.assertEqual(clause, "(a,b) > (?,?)")
self.assertEqual(args, [1, 2])
+
+
+class CallbacksTestCase(unittest.HomeserverTestCase):
+ """Tests for transaction callbacks."""
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.db_pool: DatabasePool = self.store.db_pool
+
+ def _run_interaction(
+ self, func: Callable[[LoggingTransaction], object]
+ ) -> Tuple[Mock, Mock]:
+ """Run the given function in a database transaction, with callbacks registered.
+
+ Args:
+ func: The function to be run in a transaction. The transaction will be
+ retried if `func` raises an `OperationalError`.
+
+ Returns:
+ Two mocks, which were registered as an `after_callback` and an
+ `exception_callback` respectively, on every transaction attempt.
+ """
+ after_callback = Mock()
+ exception_callback = Mock()
+
+ def _test_txn(txn: LoggingTransaction) -> None:
+ txn.call_after(after_callback, 123, 456, extra=789)
+ txn.call_on_exception(exception_callback, 987, 654, extra=321)
+ func(txn)
+
+ try:
+ self.get_success_or_raise(
+ self.db_pool.runInteraction("test_transaction", _test_txn)
+ )
+ except Exception:
+ pass
+
+ return after_callback, exception_callback
+
+ def test_after_callback(self) -> None:
+ """Test that the after callback is called when a transaction succeeds."""
+ after_callback, exception_callback = self._run_interaction(lambda txn: None)
+
+ after_callback.assert_called_once_with(123, 456, extra=789)
+ exception_callback.assert_not_called()
+
+ def test_exception_callback(self) -> None:
+ """Test that the exception callback is called when a transaction fails."""
+ _test_txn = Mock(side_effect=ZeroDivisionError)
+ after_callback, exception_callback = self._run_interaction(_test_txn)
+
+ after_callback.assert_not_called()
+ exception_callback.assert_called_once_with(987, 654, extra=321)
+
+ def test_failed_retry(self) -> None:
+ """Test that the exception callback is called for every failed attempt."""
+ # Always raise an `OperationalError`.
+ _test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError)
+ after_callback, exception_callback = self._run_interaction(_test_txn)
+
+ after_callback.assert_not_called()
+ exception_callback.assert_has_calls(
+ [
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ ]
+ )
+ self.assertEqual(exception_callback.call_count, 6) # no additional calls
+
+ def test_successful_retry(self) -> None:
+ """Test callbacks for a failed transaction followed by a successful attempt."""
+ # Raise an `OperationalError` on the first attempt only.
+ _test_txn = Mock(
+ side_effect=[self.db_pool.engine.module.OperationalError, None]
+ )
+ after_callback, exception_callback = self._run_interaction(_test_txn)
+
+ # Calling both `after_callback`s when the first attempt failed is rather
+ # surprising (#12184). Let's document the behaviour in a test.
+ after_callback.assert_has_calls(
+ [
+ call(123, 456, extra=789),
+ call(123, 456, extra=789),
+ ]
+ )
+ self.assertEqual(after_callback.call_count, 2) # no additional calls
+ exception_callback.assert_not_called()
+
+
+class CancellationTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.db_pool: DatabasePool = self.store.db_pool
+
+ def test_after_callback(self) -> None:
+ """Test that the after callback is called when a transaction succeeds."""
+ d: "Deferred[None]"
+ after_callback = Mock()
+ exception_callback = Mock()
+
+ def _test_txn(txn: LoggingTransaction) -> None:
+ txn.call_after(after_callback, 123, 456, extra=789)
+ txn.call_on_exception(exception_callback, 987, 654, extra=321)
+ d.cancel()
+
+ d = defer.ensureDeferred(
+ self.db_pool.runInteraction("test_transaction", _test_txn)
+ )
+ self.get_failure(d, CancelledError)
+
+ after_callback.assert_called_once_with(123, 456, extra=789)
+ exception_callback.assert_not_called()
+
+ def test_exception_callback(self) -> None:
+ """Test that the exception callback is called when a transaction fails."""
+ d: "Deferred[None]"
+ after_callback = Mock()
+ exception_callback = Mock()
+
+ def _test_txn(txn: LoggingTransaction) -> None:
+ txn.call_after(after_callback, 123, 456, extra=789)
+ txn.call_on_exception(exception_callback, 987, 654, extra=321)
+ d.cancel()
+ # Simulate a retryable failure on every attempt.
+ raise self.db_pool.engine.module.OperationalError()
+
+ d = defer.ensureDeferred(
+ self.db_pool.runInteraction("test_transaction", _test_txn)
+ )
+ self.get_failure(d, CancelledError)
+
+ after_callback.assert_not_called()
+ exception_callback.assert_has_calls(
+ [
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ call(987, 654, extra=321),
+ ]
+ )
+ self.assertEqual(exception_callback.call_count, 6) # no additional calls
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 6ac4b93f98..395396340b 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -13,9 +13,13 @@
# limitations under the License.
from typing import List, Optional
-from synapse.storage.database import DatabasePool
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS
@@ -25,13 +29,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
- def _setup_db(self, txn):
+ def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
@@ -59,12 +63,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
- def _insert_rows(self, instance_name: str, number: int):
+ def _insert_rows(self, instance_name: str, number: int) -> None:
"""Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence.
"""
- def _insert(txn):
+ def _insert(txn: LoggingTransaction) -> None:
for _ in range(number):
txn.execute(
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
@@ -80,12 +84,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
- def _insert_row_with_id(self, instance_name: str, stream_id: int):
+ def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
"""Insert one row as the given instance with given stream_id, updating
the postgres sequence position to match.
"""
- def _insert(txn):
+ def _insert(txn: LoggingTransaction) -> None:
txn.execute(
"INSERT INTO foobar VALUES (?, ?)",
(
@@ -104,7 +108,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
- def test_empty(self):
+ def test_empty(self) -> None:
"""Test an ID generator against an empty database gives sensible
current positions.
"""
@@ -114,7 +118,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# The table is empty so we expect an empty map for positions
self.assertEqual(id_gen.get_positions(), {})
- def test_single_instance(self):
+ def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled
correctly.
"""
@@ -130,7 +134,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
@@ -142,7 +146,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
- def test_out_of_order_finish(self):
+ def test_out_of_order_finish(self) -> None:
"""Test that IDs persisted out of order are correctly handled"""
# Prefill table with 7 rows written by 'master'
@@ -191,7 +195,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 11})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
- def test_multi_instance(self):
+ def test_multi_instance(self) -> None:
"""Test that reads and writes from multiple processes are handled
correctly.
"""
@@ -215,7 +219,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with first_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
@@ -233,7 +237,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# ... but calling `get_next` on the second instance should give a unique
# stream ID
- async def _get_next_async():
+ async def _get_next_async2() -> None:
async with second_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 9)
@@ -241,7 +245,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen.get_positions(), {"first": 3, "second": 7}
)
- self.get_success(_get_next_async())
+ self.get_success(_get_next_async2())
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
@@ -249,7 +253,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen.advance("first", 8)
self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
- def test_get_next_txn(self):
+ def test_get_next_txn(self) -> None:
"""Test that the `get_next_txn` function works correctly."""
# Prefill table with 7 rows written by 'master'
@@ -263,7 +267,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
- def _get_next_txn(txn):
+ def _get_next_txn(txn: LoggingTransaction) -> None:
stream_id = id_gen.get_next_txn(txn)
self.assertEqual(stream_id, 8)
@@ -275,7 +279,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
- def test_get_persisted_upto_position(self):
+ def test_get_persisted_upto_position(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions.
"""
@@ -317,7 +321,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen.advance("second", 15)
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
- def test_get_persisted_upto_position_get_next(self):
+ def test_get_persisted_upto_position_get_next(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions when `get_next` is called.
"""
@@ -331,7 +335,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
@@ -344,7 +348,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
- def test_restart_during_out_of_order_persistence(self):
+ def test_restart_during_out_of_order_persistence(self) -> None:
"""Test that restarting a process while another process is writing out
of order updates are handled correctly.
"""
@@ -388,7 +392,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen_worker.advance("master", 9)
self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
- def test_writer_config_change(self):
+ def test_writer_config_change(self) -> None:
"""Test that changing the writer config correctly works."""
self._insert_row_with_id("first", 3)
@@ -421,7 +425,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Check that we get a sane next stream ID with this new config.
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with id_gen_3.get_next() as stream_id:
self.assertEqual(stream_id, 6)
@@ -435,7 +439,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6)
self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
- def test_sequence_consistency(self):
+ def test_sequence_consistency(self) -> None:
"""Test that we error out if the table and sequence diverges."""
# Prefill with some rows
@@ -458,13 +462,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
- def _setup_db(self, txn):
+ def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
@@ -493,10 +497,10 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success(self.db_pool.runWithConnection(_create))
- def _insert_row(self, instance_name: str, stream_id: int):
+ def _insert_row(self, instance_name: str, stream_id: int) -> None:
"""Insert one row as the given instance with given stream_id."""
- def _insert(txn):
+ def _insert(txn: LoggingTransaction) -> None:
txn.execute(
"INSERT INTO foobar VALUES (?, ?)",
(
@@ -514,13 +518,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
- def test_single_instance(self):
+ def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled
correctly.
"""
id_gen = self._create_id_generator()
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id:
self._insert_row("master", stream_id)
@@ -530,7 +534,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
- async def _get_next_async2():
+ async def _get_next_async2() -> None:
async with id_gen.get_next_mult(3) as stream_ids:
for stream_id in stream_ids:
self._insert_row("master", stream_id)
@@ -548,14 +552,14 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
- def test_multiple_instance(self):
+ def test_multiple_instance(self) -> None:
"""Tests that having multiple instances that get advanced over
federation works corretly.
"""
id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with id_gen_1.get_next() as stream_id:
self._insert_row("first", stream_id)
id_gen_2.advance("first", stream_id)
@@ -567,7 +571,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
- async def _get_next_async2():
+ async def _get_next_async2() -> None:
async with id_gen_2.get_next() as stream_id:
self._insert_row("second", stream_id)
id_gen_1.advance("second", stream_id)
@@ -584,13 +588,13 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
- def _setup_db(self, txn):
+ def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
@@ -642,7 +646,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
from the postgres sequence.
"""
- def _insert(txn):
+ def _insert(txn: LoggingTransaction) -> None:
for _ in range(number):
txn.execute(
"INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
@@ -659,7 +663,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
- def test_load_existing_stream(self):
+ def test_load_existing_stream(self) -> None:
"""Test creating ID gens with multiple tables that have rows from after
the position in `stream_positions` table.
"""
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index 6a1cf33054..eaa0d7d749 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -129,21 +129,19 @@ class PaginationTestCase(HomeserverTestCase):
def test_filter_relation_senders(self):
# Messages which second user reacted to.
- filter = {"io.element.relation_senders": [self.second_user_id]}
+ filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0].event_id, self.event_id_1)
# Messages which third user reacted to.
- filter = {"io.element.relation_senders": [self.third_user_id]}
+ filter = {"related_by_senders": [self.third_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0].event_id, self.event_id_2)
# Messages which either user reacted to.
- filter = {
- "io.element.relation_senders": [self.second_user_id, self.third_user_id]
- }
+ filter = {"related_by_senders": [self.second_user_id, self.third_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 2, chunk)
self.assertCountEqual(
@@ -152,20 +150,20 @@ class PaginationTestCase(HomeserverTestCase):
def test_filter_relation_type(self):
# Messages which have annotations.
- filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
+ filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0].event_id, self.event_id_1)
# Messages which have references.
- filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
+ filter = {"related_by_rel_types": [RelationTypes.REFERENCE]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0].event_id, self.event_id_2)
# Messages which have either annotations or references.
filter = {
- "io.element.relation_types": [
+ "related_by_rel_types": [
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
]
@@ -179,8 +177,8 @@ class PaginationTestCase(HomeserverTestCase):
def test_filter_relation_senders_and_type(self):
# Messages which second user reacted to.
filter = {
- "io.element.relation_senders": [self.second_user_id],
- "io.element.relation_types": [RelationTypes.ANNOTATION],
+ "related_by_senders": [self.second_user_id],
+ "related_by_rel_types": [RelationTypes.ANNOTATION],
}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
@@ -201,7 +199,7 @@ class PaginationTestCase(HomeserverTestCase):
tok=self.second_tok,
)
- filter = {"io.element.relation_senders": [self.second_user_id]}
+ filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0].event_id, self.event_id_1)
diff --git a/tests/storage/test_unsafe_locale.py b/tests/storage/test_unsafe_locale.py
new file mode 100644
index 0000000000..ba53c22818
--- /dev/null
+++ b/tests/storage/test_unsafe_locale.py
@@ -0,0 +1,46 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import MagicMock, patch
+
+from synapse.storage.database import make_conn
+from synapse.storage.engines._base import IncorrectDatabaseSetup
+
+from tests.unittest import HomeserverTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+
+class UnsafeLocaleTest(HomeserverTestCase):
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ @patch("synapse.storage.engines.postgres.PostgresEngine.get_db_locale")
+ def test_unsafe_locale(self, mock_db_locale: MagicMock) -> None:
+ mock_db_locale.return_value = ("B", "B")
+ database = self.hs.get_datastores().databases[0]
+
+ db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
+ with self.assertRaises(IncorrectDatabaseSetup):
+ database.engine.check_database(db_conn)
+ with self.assertRaises(IncorrectDatabaseSetup):
+ database.engine.check_new_database(db_conn)
+ db_conn.close()
+
+ def test_safe_locale(self) -> None:
+ database = self.hs.get_datastores().databases[0]
+
+ db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
+ with db_conn.cursor() as txn:
+ res = database.engine.get_db_locale(txn)
+ self.assertEqual(res, ("C", "C"))
+ db_conn.close()
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 219b5660b1..532e3fe9cd 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -13,11 +13,12 @@
# limitations under the License.
import logging
from typing import Optional
+from unittest.mock import patch
from synapse.api.room_versions import RoomVersions
-from synapse.events import EventBase
-from synapse.types import JsonDict
-from synapse.visibility import filter_events_for_server
+from synapse.events import EventBase, make_event_from_dict
+from synapse.types import JsonDict, create_requester
+from synapse.visibility import filter_events_for_client, filter_events_for_server
from tests import unittest
from tests.utils import create_room
@@ -185,3 +186,72 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.get_success(self.storage.persistence.persist_event(event, context))
return event
+
+
+class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
+ def test_out_of_band_invite_rejection(self):
+ # this is where we have received an invite event over federation, and then
+ # rejected it.
+ invite_pdu = {
+ "room_id": "!room:id",
+ "depth": 1,
+ "auth_events": [],
+ "prev_events": [],
+ "origin_server_ts": 1,
+ "sender": "@someone:" + self.OTHER_SERVER_NAME,
+ "type": "m.room.member",
+ "state_key": "@user:test",
+ "content": {"membership": "invite"},
+ }
+ self.add_hashes_and_signatures(invite_pdu)
+ invite_event_id = make_event_from_dict(invite_pdu, RoomVersions.V9).event_id
+
+ self.get_success(
+ self.hs.get_federation_server().on_invite_request(
+ self.OTHER_SERVER_NAME,
+ invite_pdu,
+ "9",
+ )
+ )
+
+ # stub out do_remotely_reject_invite so that we fall back to a locally-
+ # generated rejection
+ with patch.object(
+ self.hs.get_federation_handler(),
+ "do_remotely_reject_invite",
+ side_effect=Exception(),
+ ):
+ reject_event_id, _ = self.get_success(
+ self.hs.get_room_member_handler().remote_reject_invite(
+ invite_event_id,
+ txn_id=None,
+ requester=create_requester("@user:test"),
+ content={},
+ )
+ )
+
+ invite_event, reject_event = self.get_success(
+ self.hs.get_datastores().main.get_events_as_list(
+ [invite_event_id, reject_event_id]
+ )
+ )
+
+ # the invited user should be able to see both the invite and the rejection
+ self.assertEqual(
+ self.get_success(
+ filter_events_for_client(
+ self.hs.get_storage(), "@user:test", [invite_event, reject_event]
+ )
+ ),
+ [invite_event, reject_event],
+ )
+
+ # other users should see neither
+ self.assertEqual(
+ self.get_success(
+ filter_events_for_client(
+ self.hs.get_storage(), "@other:test", [invite_event, reject_event]
+ )
+ ),
+ [],
+ )
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 19741ffcda..48e616ac74 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -17,7 +17,7 @@ from typing import Set
from unittest import mock
from twisted.internet import defer, reactor
-from twisted.internet.defer import Deferred
+from twisted.internet.defer import CancelledError, Deferred
from synapse.api.errors import SynapseError
from synapse.logging.context import (
@@ -28,7 +28,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached, lru_cache
+from synapse.util.caches.descriptors import cached, cachedList, lru_cache
from tests import unittest
from tests.test_utils import get_awaitable_result
@@ -141,6 +141,84 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()
+ @defer.inlineCallbacks
+ def test_cache_uncached_args(self):
+ """
+ Only the arguments not named in uncached_args should matter to the cache
+
+ Note that this is identical to test_cache_num_args, but provides the
+ arguments differently.
+ """
+
+ class Cls:
+ # Note that it is important that this is not the last argument to
+ # test behaviour of skipping arguments properly.
+ @descriptors.cached(uncached_args=("arg2",))
+ def fn(self, arg1, arg2, arg3):
+ return self.mock(arg1, arg2, arg3)
+
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ obj = Cls()
+ obj.mock.return_value = "fish"
+ r = yield obj.fn(1, 2, 3)
+ self.assertEqual(r, "fish")
+ obj.mock.assert_called_once_with(1, 2, 3)
+ obj.mock.reset_mock()
+
+ # a call with different params should call the mock again
+ obj.mock.return_value = "chips"
+ r = yield obj.fn(2, 3, 4)
+ self.assertEqual(r, "chips")
+ obj.mock.assert_called_once_with(2, 3, 4)
+ obj.mock.reset_mock()
+
+ # the two values should now be cached; we should be able to vary
+ # the second argument and still get the cached result.
+ r = yield obj.fn(1, 4, 3)
+ self.assertEqual(r, "fish")
+ r = yield obj.fn(2, 5, 4)
+ self.assertEqual(r, "chips")
+ obj.mock.assert_not_called()
+
+ @defer.inlineCallbacks
+ def test_cache_kwargs(self):
+ """Test that keyword arguments are treated properly"""
+
+ class Cls:
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @descriptors.cached()
+ def fn(self, arg1, kwarg1=2):
+ return self.mock(arg1, kwarg1=kwarg1)
+
+ obj = Cls()
+ obj.mock.return_value = "fish"
+ r = yield obj.fn(1, kwarg1=2)
+ self.assertEqual(r, "fish")
+ obj.mock.assert_called_once_with(1, kwarg1=2)
+ obj.mock.reset_mock()
+
+ # a call with different params should call the mock again
+ obj.mock.return_value = "chips"
+ r = yield obj.fn(1, kwarg1=3)
+ self.assertEqual(r, "chips")
+ obj.mock.assert_called_once_with(1, kwarg1=3)
+ obj.mock.reset_mock()
+
+ # the values should now be cached.
+ r = yield obj.fn(1, kwarg1=2)
+ self.assertEqual(r, "fish")
+ # We should be able to not provide kwarg1 and get the cached value back.
+ r = yield obj.fn(1)
+ self.assertEqual(r, "fish")
+ # Keyword arguments can be in any order.
+ r = yield obj.fn(kwarg1=2, arg1=1)
+ self.assertEqual(r, "fish")
+ obj.mock.assert_not_called()
+
def test_cache_with_sync_exception(self):
"""If the wrapped function throws synchronously, things should continue to work"""
@@ -415,6 +493,74 @@ class DescriptorTestCase(unittest.TestCase):
obj.invalidate()
top_invalidate.assert_called_once()
+ def test_cancel(self):
+ """Test that cancelling a lookup does not cancel other lookups"""
+ complete_lookup: "Deferred[None]" = Deferred()
+
+ class Cls:
+ @cached()
+ async def fn(self, arg1):
+ await complete_lookup
+ return str(arg1)
+
+ obj = Cls()
+
+ d1 = obj.fn(123)
+ d2 = obj.fn(123)
+ self.assertFalse(d1.called)
+ self.assertFalse(d2.called)
+
+ # Cancel `d1`, which is the lookup that caused `fn` to run.
+ d1.cancel()
+
+ # `d2` should complete normally.
+ complete_lookup.callback(None)
+ self.failureResultOf(d1, CancelledError)
+ self.assertEqual(d2.result, "123")
+
+ def test_cancel_logcontexts(self):
+ """Test that cancellation does not break logcontexts.
+
+ * The `CancelledError` must be raised with the correct logcontext.
+ * The inner lookup must not resume with a finished logcontext.
+ * The inner lookup must not restore a finished logcontext when done.
+ """
+ complete_lookup: "Deferred[None]" = Deferred()
+
+ class Cls:
+ inner_context_was_finished = False
+
+ @cached()
+ async def fn(self, arg1):
+ await make_deferred_yieldable(complete_lookup)
+ self.inner_context_was_finished = current_context().finished
+ return str(arg1)
+
+ obj = Cls()
+
+ async def do_lookup():
+ with LoggingContext("c1") as c1:
+ try:
+ await obj.fn(123)
+ self.fail("No CancelledError thrown")
+ except CancelledError:
+ self.assertEqual(
+ current_context(),
+ c1,
+ "CancelledError was not raised with the correct logcontext",
+ )
+ # suppress the error and succeed
+
+ d = defer.ensureDeferred(do_lookup())
+ d.cancel()
+
+ complete_lookup.callback(None)
+ self.successResultOf(d)
+ self.assertFalse(
+ obj.inner_context_was_finished, "Tried to restart a finished logcontext"
+ )
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
"""More tests for @cached
@@ -656,7 +802,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
- @descriptors.cachedList("fn", "args1")
+ @descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1, arg2):
assert current_context().name == "c1"
# we want this to behave like an asynchronous function
@@ -715,7 +861,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1):
pass
- @descriptors.cachedList("fn", "args1")
+ @descriptors.cachedList(cached_method_name="fn", list_name="args1")
def list_fn(self, args1) -> "Deferred[dict]":
return self.mock(args1)
@@ -758,7 +904,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
- @descriptors.cachedList("fn", "args1")
+ @descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
await run_on_reactor()
@@ -787,3 +933,78 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj.fn.invalidate((10, 2))
invalidate0.assert_called_once()
invalidate1.assert_called_once()
+
+ def test_cancel(self):
+ """Test that cancelling a lookup does not cancel other lookups"""
+ complete_lookup: "Deferred[None]" = Deferred()
+
+ class Cls:
+ @cached()
+ def fn(self, arg1):
+ pass
+
+ @cachedList(cached_method_name="fn", list_name="args")
+ async def list_fn(self, args):
+ await complete_lookup
+ return {arg: str(arg) for arg in args}
+
+ obj = Cls()
+
+ d1 = obj.list_fn([123, 456])
+ d2 = obj.list_fn([123, 456, 789])
+ self.assertFalse(d1.called)
+ self.assertFalse(d2.called)
+
+ d1.cancel()
+
+ # `d2` should complete normally.
+ complete_lookup.callback(None)
+ self.failureResultOf(d1, CancelledError)
+ self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"})
+
+ def test_cancel_logcontexts(self):
+ """Test that cancellation does not break logcontexts.
+
+ * The `CancelledError` must be raised with the correct logcontext.
+ * The inner lookup must not resume with a finished logcontext.
+ * The inner lookup must not restore a finished logcontext when done.
+ """
+ complete_lookup: "Deferred[None]" = Deferred()
+
+ class Cls:
+ inner_context_was_finished = False
+
+ @cached()
+ def fn(self, arg1):
+ pass
+
+ @cachedList(cached_method_name="fn", list_name="args")
+ async def list_fn(self, args):
+ await make_deferred_yieldable(complete_lookup)
+ self.inner_context_was_finished = current_context().finished
+ return {arg: str(arg) for arg in args}
+
+ obj = Cls()
+
+ async def do_lookup():
+ with LoggingContext("c1") as c1:
+ try:
+ await obj.list_fn([123])
+ self.fail("No CancelledError thrown")
+ except CancelledError:
+ self.assertEqual(
+ current_context(),
+ c1,
+ "CancelledError was not raised with the correct logcontext",
+ )
+ # suppress the error and succeed
+
+ d = defer.ensureDeferred(do_lookup())
+ d.cancel()
+
+ complete_lookup.callback(None)
+ self.successResultOf(d)
+ self.assertFalse(
+ obj.inner_context_was_finished, "Tried to restart a finished logcontext"
+ )
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index 362014f4cb..e5bc416de1 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -13,6 +13,8 @@
# limitations under the License.
import traceback
+from parameterized import parameterized_class
+
from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
from twisted.internet.task import Clock
@@ -23,10 +25,12 @@ from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
current_context,
+ make_deferred_yieldable,
)
from synapse.util.async_helpers import (
ObservableDeferred,
concurrently_execute,
+ delay_cancellation,
stop_cancellation,
timeout_deferred,
)
@@ -100,6 +104,34 @@ class ObservableDeferredTest(TestCase):
self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
+ def test_cancellation(self):
+ """Test that cancelling an observer does not affect other observers."""
+ origin_d: "Deferred[int]" = Deferred()
+ observable = ObservableDeferred(origin_d, consumeErrors=True)
+
+ observer1 = observable.observe()
+ observer2 = observable.observe()
+ observer3 = observable.observe()
+
+ self.assertFalse(observer1.called)
+ self.assertFalse(observer2.called)
+ self.assertFalse(observer3.called)
+
+ # cancel the second observer
+ observer2.cancel()
+ self.assertFalse(observer1.called)
+ self.failureResultOf(observer2, CancelledError)
+ self.assertFalse(observer3.called)
+
+ # other observers resolve as normal
+ origin_d.callback(123)
+ self.assertEqual(observer1.result, 123, "observer 1 callback result")
+ self.assertEqual(observer3.result, 123, "observer 3 callback result")
+
+ # additional observers resolve as normal
+ observer4 = observable.observe()
+ self.assertEqual(observer4.result, 123, "observer 4 callback result")
+
class TimeoutDeferredTest(TestCase):
def setUp(self):
@@ -285,13 +317,27 @@ class ConcurrentlyExecuteTest(TestCase):
self.successResultOf(d2)
-class StopCancellationTests(TestCase):
- """Tests for the `stop_cancellation` function."""
+@parameterized_class(
+ ("wrapper",),
+ [("stop_cancellation",), ("delay_cancellation",)],
+)
+class CancellationWrapperTests(TestCase):
+ """Common tests for the `stop_cancellation` and `delay_cancellation` functions."""
+
+ wrapper: str
+
+ def wrap_deferred(self, deferred: "Deferred[str]") -> "Deferred[str]":
+ if self.wrapper == "stop_cancellation":
+ return stop_cancellation(deferred)
+ elif self.wrapper == "delay_cancellation":
+ return delay_cancellation(deferred)
+ else:
+ raise ValueError(f"Unsupported wrapper type: {self.wrapper}")
def test_succeed(self):
"""Test that the new `Deferred` receives the result."""
deferred: "Deferred[str]" = Deferred()
- wrapper_deferred = stop_cancellation(deferred)
+ wrapper_deferred = self.wrap_deferred(deferred)
# Success should propagate through.
deferred.callback("success")
@@ -301,7 +347,7 @@ class StopCancellationTests(TestCase):
def test_failure(self):
"""Test that the new `Deferred` receives the `Failure`."""
deferred: "Deferred[str]" = Deferred()
- wrapper_deferred = stop_cancellation(deferred)
+ wrapper_deferred = self.wrap_deferred(deferred)
# Failure should propagate through.
deferred.errback(ValueError("abc"))
@@ -309,6 +355,10 @@ class StopCancellationTests(TestCase):
self.failureResultOf(wrapper_deferred, ValueError)
self.assertIsNone(deferred.result, "`Failure` was not consumed")
+
+class StopCancellationTests(TestCase):
+ """Tests for the `stop_cancellation` function."""
+
def test_cancellation(self):
"""Test that cancellation of the new `Deferred` leaves the original running."""
deferred: "Deferred[str]" = Deferred()
@@ -319,11 +369,101 @@ class StopCancellationTests(TestCase):
self.assertTrue(wrapper_deferred.called)
self.failureResultOf(wrapper_deferred, CancelledError)
self.assertFalse(
- deferred.called, "Original `Deferred` was unexpectedly cancelled."
+ deferred.called, "Original `Deferred` was unexpectedly cancelled"
+ )
+
+ # Now make the original `Deferred` fail.
+ # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
+ # in logs.
+ deferred.errback(ValueError("abc"))
+ self.assertIsNone(deferred.result, "`Failure` was not consumed")
+
+
+class DelayCancellationTests(TestCase):
+ """Tests for the `delay_cancellation` function."""
+
+ def test_cancellation(self):
+ """Test that cancellation of the new `Deferred` waits for the original."""
+ deferred: "Deferred[str]" = Deferred()
+ wrapper_deferred = delay_cancellation(deferred)
+
+ # Cancel the new `Deferred`.
+ wrapper_deferred.cancel()
+ self.assertNoResult(wrapper_deferred)
+ self.assertFalse(
+ deferred.called, "Original `Deferred` was unexpectedly cancelled"
+ )
+
+ # Now make the original `Deferred` fail.
+ # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
+ # in logs.
+ deferred.errback(ValueError("abc"))
+ self.assertIsNone(deferred.result, "`Failure` was not consumed")
+
+ # Now that the original `Deferred` has failed, we should get a `CancelledError`.
+ self.failureResultOf(wrapper_deferred, CancelledError)
+
+ def test_suppresses_second_cancellation(self):
+ """Test that a second cancellation is suppressed.
+
+ Identical to `test_cancellation` except the new `Deferred` is cancelled twice.
+ """
+ deferred: "Deferred[str]" = Deferred()
+ wrapper_deferred = delay_cancellation(deferred)
+
+ # Cancel the new `Deferred`, twice.
+ wrapper_deferred.cancel()
+ wrapper_deferred.cancel()
+ self.assertNoResult(wrapper_deferred)
+ self.assertFalse(
+ deferred.called, "Original `Deferred` was unexpectedly cancelled"
)
- # Now make the inner `Deferred` fail.
+ # Now make the original `Deferred` fail.
# The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
# in logs.
deferred.errback(ValueError("abc"))
self.assertIsNone(deferred.result, "`Failure` was not consumed")
+
+ # Now that the original `Deferred` has failed, we should get a `CancelledError`.
+ self.failureResultOf(wrapper_deferred, CancelledError)
+
+ def test_propagates_cancelled_error(self):
+ """Test that a `CancelledError` from the original `Deferred` gets propagated."""
+ deferred: "Deferred[str]" = Deferred()
+ wrapper_deferred = delay_cancellation(deferred)
+
+ # Fail the original `Deferred` with a `CancelledError`.
+ cancelled_error = CancelledError()
+ deferred.errback(cancelled_error)
+
+ # The new `Deferred` should fail with exactly the same `CancelledError`.
+ self.assertTrue(wrapper_deferred.called)
+ self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value)
+
+ def test_preserves_logcontext(self):
+ """Test that logging contexts are preserved."""
+ blocking_d: "Deferred[None]" = Deferred()
+
+ async def inner():
+ await make_deferred_yieldable(blocking_d)
+
+ async def outer():
+ with LoggingContext("c") as c:
+ try:
+ await delay_cancellation(defer.ensureDeferred(inner()))
+ self.fail("`CancelledError` was not raised")
+ except CancelledError:
+ self.assertEqual(c, current_context())
+ # Succeed with no error, unless the logging context is wrong.
+
+ # Run and block inside `inner()`.
+ d = defer.ensureDeferred(outer())
+ self.assertEqual(SENTINEL_CONTEXT, current_context())
+
+ d.cancel()
+
+ # Now unblock. `outer()` will consume the `CancelledError` and check the
+ # logging context.
+ blocking_d.callback(None)
+ self.successResultOf(d)
diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
index 3c07252252..5d1aa025d1 100644
--- a/tests/util/test_check_dependencies.py
+++ b/tests/util/test_check_dependencies.py
@@ -12,7 +12,7 @@ from tests.unittest import TestCase
class DummyDistribution(metadata.Distribution):
- def __init__(self, version: str):
+ def __init__(self, version: object):
self._version = version
@property
@@ -27,7 +27,10 @@ class DummyDistribution(metadata.Distribution):
old = DummyDistribution("0.1.2")
+old_release_candidate = DummyDistribution("0.1.2rc3")
new = DummyDistribution("1.2.3")
+new_release_candidate = DummyDistribution("1.2.3rc4")
+distribution_with_no_version = DummyDistribution(None)
# could probably use stdlib TestCase --- no need for twisted here
@@ -65,6 +68,35 @@ class TestDependencyChecker(TestCase):
# should not raise
check_requirements()
+ def test_version_reported_as_none(self) -> None:
+ """Complain if importlib.metadata.version() returns None.
+
+ This shouldn't normally happen, but it was seen in the wild (#12223).
+ """
+ with patch(
+ "synapse.util.check_dependencies.metadata.requires",
+ return_value=["dummypkg >= 1"],
+ ):
+ with self.mock_installed_package(distribution_with_no_version):
+ self.assertRaises(DependencyException, check_requirements)
+
+ def test_checks_ignore_dev_dependencies(self) -> None:
+ """Bot generic and per-extra checks should ignore dev dependencies."""
+ with patch(
+ "synapse.util.check_dependencies.metadata.requires",
+ return_value=["dummypkg >= 1; extra == 'mypy'"],
+ ), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}):
+ # We're testing that none of these calls raise.
+ with self.mock_installed_package(None):
+ check_requirements()
+ check_requirements("cool-extra")
+ with self.mock_installed_package(old):
+ check_requirements()
+ check_requirements("cool-extra")
+ with self.mock_installed_package(new):
+ check_requirements()
+ check_requirements("cool-extra")
+
def test_generic_check_of_optional_dependency(self) -> None:
"""Complain if an optional package is old."""
with patch(
@@ -85,11 +117,28 @@ class TestDependencyChecker(TestCase):
with patch(
"synapse.util.check_dependencies.metadata.requires",
return_value=["dummypkg >= 1; extra == 'cool-extra'"],
- ), patch("synapse.util.check_dependencies.EXTRAS", {"cool-extra"}):
+ ), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}):
with self.mock_installed_package(None):
self.assertRaises(DependencyException, check_requirements, "cool-extra")
with self.mock_installed_package(old):
self.assertRaises(DependencyException, check_requirements, "cool-extra")
with self.mock_installed_package(new):
# should not raise
+ check_requirements("cool-extra")
+
+ def test_release_candidates_satisfy_dependency(self) -> None:
+ """
+ Tests that release candidates count as far as satisfying a dependency
+ is concerned.
+ (Regression test, see #12176.)
+ """
+ with patch(
+ "synapse.util.check_dependencies.metadata.requires",
+ return_value=["dummypkg >= 1"],
+ ):
+ with self.mock_installed_package(old_release_candidate):
+ self.assertRaises(DependencyException, check_requirements)
+
+ with self.mock_installed_package(new_release_candidate):
+ # should not raise
check_requirements()
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index 0774625b85..0c84226197 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import AsyncContextManager, Callable, Sequence, Tuple
+
from twisted.internet import defer
-from twisted.internet.defer import Deferred
+from twisted.internet.defer import CancelledError, Deferred
from synapse.util.async_helpers import ReadWriteLock
@@ -21,87 +23,187 @@ from tests import unittest
class ReadWriteLockTestCase(unittest.TestCase):
- def _assert_called_before_not_after(self, lst, first_false):
- for i, d in enumerate(lst[:first_false]):
- self.assertTrue(d.called, msg="%d was unexpectedly false" % i)
+ def _start_reader_or_writer(
+ self,
+ read_or_write: Callable[[str], AsyncContextManager],
+ key: str,
+ return_value: str,
+ ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]:
+ """Starts a reader or writer which acquires the lock, blocks, then completes.
+
+ Args:
+ read_or_write: A function returning a context manager for a lock.
+ Either a bound `ReadWriteLock.read` or `ReadWriteLock.write`.
+ key: The key to read or write.
+ return_value: A string that the reader or writer will resolve with when
+ done.
+
+ Returns:
+ A tuple of three `Deferred`s:
+ * A `Deferred` that resolves with `return_value` once the reader or writer
+ completes successfully.
+ * A `Deferred` that resolves once the reader or writer acquires the lock.
+ * A `Deferred` that blocks the reader or writer. Must be resolved by the
+ caller to allow the reader or writer to release the lock and complete.
+ """
+ acquired_d: "Deferred[None]" = Deferred()
+ unblock_d: "Deferred[None]" = Deferred()
+
+ async def reader_or_writer():
+ async with read_or_write(key):
+ acquired_d.callback(None)
+ await unblock_d
+ return return_value
+
+ d = defer.ensureDeferred(reader_or_writer())
+ return d, acquired_d, unblock_d
+
+ def _start_blocking_reader(
+ self, rwlock: ReadWriteLock, key: str, return_value: str
+ ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]:
+ """Starts a reader which acquires the lock, blocks, then releases the lock.
+
+ See the docstring for `_start_reader_or_writer` for details about the arguments
+ and return values.
+ """
+ return self._start_reader_or_writer(rwlock.read, key, return_value)
+
+ def _start_blocking_writer(
+ self, rwlock: ReadWriteLock, key: str, return_value: str
+ ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]:
+ """Starts a writer which acquires the lock, blocks, then releases the lock.
+
+ See the docstring for `_start_reader_or_writer` for details about the arguments
+ and return values.
+ """
+ return self._start_reader_or_writer(rwlock.write, key, return_value)
+
+ def _start_nonblocking_reader(
+ self, rwlock: ReadWriteLock, key: str, return_value: str
+ ) -> Tuple["Deferred[str]", "Deferred[None]"]:
+ """Starts a reader which acquires the lock, then releases it immediately.
+
+ See the docstring for `_start_reader_or_writer` for details about the arguments.
+
+ Returns:
+ A tuple of two `Deferred`s:
+ * A `Deferred` that resolves with `return_value` once the reader completes
+ successfully.
+ * A `Deferred` that resolves once the reader acquires the lock.
+ """
+ d, acquired_d, unblock_d = self._start_reader_or_writer(
+ rwlock.read, key, return_value
+ )
+ unblock_d.callback(None)
+ return d, acquired_d
+
+ def _start_nonblocking_writer(
+ self, rwlock: ReadWriteLock, key: str, return_value: str
+ ) -> Tuple["Deferred[str]", "Deferred[None]"]:
+ """Starts a writer which acquires the lock, then releases it immediately.
+
+ See the docstring for `_start_reader_or_writer` for details about the arguments.
+
+ Returns:
+ A tuple of two `Deferred`s:
+ * A `Deferred` that resolves with `return_value` once the writer completes
+ successfully.
+ * A `Deferred` that resolves once the writer acquires the lock.
+ """
+ d, acquired_d, unblock_d = self._start_reader_or_writer(
+ rwlock.write, key, return_value
+ )
+ unblock_d.callback(None)
+ return d, acquired_d
+
+ def _assert_first_n_resolved(
+ self, deferreds: Sequence["defer.Deferred[None]"], n: int
+ ) -> None:
+ """Assert that exactly the first n `Deferred`s in the given list are resolved.
- for i, d in enumerate(lst[first_false:]):
+ Args:
+ deferreds: The list of `Deferred`s to be checked.
+ n: The number of `Deferred`s at the start of `deferreds` that should be
+ resolved.
+ """
+ for i, d in enumerate(deferreds[:n]):
+ self.assertTrue(d.called, msg="deferred %d was unexpectedly unresolved" % i)
+
+ for i, d in enumerate(deferreds[n:]):
self.assertFalse(
- d.called, msg="%d was unexpectedly true" % (i + first_false)
+ d.called, msg="deferred %d was unexpectedly resolved" % (i + n)
)
def test_rwlock(self):
rwlock = ReadWriteLock()
-
- key = object()
+ key = "key"
ds = [
- rwlock.read(key), # 0
- rwlock.read(key), # 1
- rwlock.write(key), # 2
- rwlock.write(key), # 3
- rwlock.read(key), # 4
- rwlock.read(key), # 5
- rwlock.write(key), # 6
+ self._start_blocking_reader(rwlock, key, "0"),
+ self._start_blocking_reader(rwlock, key, "1"),
+ self._start_blocking_writer(rwlock, key, "2"),
+ self._start_blocking_writer(rwlock, key, "3"),
+ self._start_blocking_reader(rwlock, key, "4"),
+ self._start_blocking_reader(rwlock, key, "5"),
+ self._start_blocking_writer(rwlock, key, "6"),
]
- ds = [defer.ensureDeferred(d) for d in ds]
+ # `Deferred`s that resolve when each reader or writer acquires the lock.
+ acquired_ds = [acquired_d for _, acquired_d, _ in ds]
+ # `Deferred`s that will trigger the release of locks when resolved.
+ release_ds = [release_d for _, _, release_d in ds]
- self._assert_called_before_not_after(ds, 2)
+ # The first two readers should acquire their locks.
+ self._assert_first_n_resolved(acquired_ds, 2)
- with ds[0].result:
- self._assert_called_before_not_after(ds, 2)
- self._assert_called_before_not_after(ds, 2)
+ # Release one of the read locks. The next writer should not acquire the lock,
+ # because there is another reader holding the lock.
+ self._assert_first_n_resolved(acquired_ds, 2)
+ release_ds[0].callback(None)
+ self._assert_first_n_resolved(acquired_ds, 2)
- with ds[1].result:
- self._assert_called_before_not_after(ds, 2)
- self._assert_called_before_not_after(ds, 3)
+ # Release the other read lock. The next writer should acquire the lock.
+ self._assert_first_n_resolved(acquired_ds, 2)
+ release_ds[1].callback(None)
+ self._assert_first_n_resolved(acquired_ds, 3)
- with ds[2].result:
- self._assert_called_before_not_after(ds, 3)
- self._assert_called_before_not_after(ds, 4)
+ # Release the write lock. The next writer should acquire the lock.
+ self._assert_first_n_resolved(acquired_ds, 3)
+ release_ds[2].callback(None)
+ self._assert_first_n_resolved(acquired_ds, 4)
- with ds[3].result:
- self._assert_called_before_not_after(ds, 4)
- self._assert_called_before_not_after(ds, 6)
+ # Release the write lock. The next two readers should acquire locks.
+ self._assert_first_n_resolved(acquired_ds, 4)
+ release_ds[3].callback(None)
+ self._assert_first_n_resolved(acquired_ds, 6)
- with ds[5].result:
- self._assert_called_before_not_after(ds, 6)
- self._assert_called_before_not_after(ds, 6)
+ # Release one of the read locks. The next writer should not acquire the lock,
+ # because there is another reader holding the lock.
+ self._assert_first_n_resolved(acquired_ds, 6)
+ release_ds[5].callback(None)
+ self._assert_first_n_resolved(acquired_ds, 6)
- with ds[4].result:
- self._assert_called_before_not_after(ds, 6)
- self._assert_called_before_not_after(ds, 7)
+ # Release the other read lock. The next writer should acquire the lock.
+ self._assert_first_n_resolved(acquired_ds, 6)
+ release_ds[4].callback(None)
+ self._assert_first_n_resolved(acquired_ds, 7)
- with ds[6].result:
- pass
+ # Release the write lock.
+ release_ds[6].callback(None)
- d = defer.ensureDeferred(rwlock.write(key))
- self.assertTrue(d.called)
- with d.result:
- pass
+ # Acquire and release the write and read locks one last time for good measure.
+ _, acquired_d = self._start_nonblocking_writer(rwlock, key, "last writer")
+ self.assertTrue(acquired_d.called)
- d = defer.ensureDeferred(rwlock.read(key))
- self.assertTrue(d.called)
- with d.result:
- pass
+ _, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader")
+ self.assertTrue(acquired_d.called)
def test_lock_handoff_to_nonblocking_writer(self):
"""Test a writer handing the lock to another writer that completes instantly."""
rwlock = ReadWriteLock()
key = "key"
- unblock: "Deferred[None]" = Deferred()
-
- async def blocking_write():
- with await rwlock.write(key):
- await unblock
-
- async def nonblocking_write():
- with await rwlock.write(key):
- pass
-
- d1 = defer.ensureDeferred(blocking_write())
- d2 = defer.ensureDeferred(nonblocking_write())
+ d1, _, unblock = self._start_blocking_writer(rwlock, key, "write 1 completed")
+ d2, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed")
self.assertFalse(d1.called)
self.assertFalse(d2.called)
@@ -111,5 +213,182 @@ class ReadWriteLockTestCase(unittest.TestCase):
self.assertTrue(d2.called)
# The `ReadWriteLock` should operate as normal.
- d3 = defer.ensureDeferred(nonblocking_write())
+ d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed")
self.assertTrue(d3.called)
+
+ def test_cancellation_while_holding_read_lock(self):
+ """Test cancellation while holding a read lock.
+
+ A waiting writer should be given the lock when the reader holding the lock is
+ cancelled.
+ """
+ rwlock = ReadWriteLock()
+ key = "key"
+
+ # 1. A reader takes the lock and blocks.
+ reader_d, _, _ = self._start_blocking_reader(rwlock, key, "read completed")
+
+ # 2. A writer waits for the reader to complete.
+ writer_d, _ = self._start_nonblocking_writer(rwlock, key, "write completed")
+ self.assertFalse(writer_d.called)
+
+ # 3. The reader is cancelled.
+ reader_d.cancel()
+ self.failureResultOf(reader_d, CancelledError)
+
+ # 4. The writer should take the lock and complete.
+ self.assertTrue(
+ writer_d.called, "Writer is stuck waiting for a cancelled reader"
+ )
+ self.assertEqual("write completed", self.successResultOf(writer_d))
+
+ def test_cancellation_while_holding_write_lock(self):
+ """Test cancellation while holding a write lock.
+
+ A waiting reader should be given the lock when the writer holding the lock is
+ cancelled.
+ """
+ rwlock = ReadWriteLock()
+ key = "key"
+
+ # 1. A writer takes the lock and blocks.
+ writer_d, _, _ = self._start_blocking_writer(rwlock, key, "write completed")
+
+ # 2. A reader waits for the writer to complete.
+ reader_d, _ = self._start_nonblocking_reader(rwlock, key, "read completed")
+ self.assertFalse(reader_d.called)
+
+ # 3. The writer is cancelled.
+ writer_d.cancel()
+ self.failureResultOf(writer_d, CancelledError)
+
+ # 4. The reader should take the lock and complete.
+ self.assertTrue(
+ reader_d.called, "Reader is stuck waiting for a cancelled writer"
+ )
+ self.assertEqual("read completed", self.successResultOf(reader_d))
+
+ def test_cancellation_while_waiting_for_read_lock(self):
+ """Test cancellation while waiting for a read lock.
+
+ Tests that cancelling a waiting reader:
+ * does not cancel the writer it is waiting on
+ * does not cancel the next writer waiting on it
+ * does not allow the next writer to acquire the lock before an earlier writer
+ has finished
+ * does not keep the next writer waiting indefinitely
+
+ These correspond to the asserts with explicit messages.
+ """
+ rwlock = ReadWriteLock()
+ key = "key"
+
+ # 1. A writer takes the lock and blocks.
+ writer1_d, _, unblock_writer1 = self._start_blocking_writer(
+ rwlock, key, "write 1 completed"
+ )
+
+ # 2. A reader waits for the first writer to complete.
+ # This reader will be cancelled later.
+ reader_d, _ = self._start_nonblocking_reader(rwlock, key, "read completed")
+ self.assertFalse(reader_d.called)
+
+ # 3. A second writer waits for both the first writer and the reader to complete.
+ writer2_d, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed")
+ self.assertFalse(writer2_d.called)
+
+ # 4. The waiting reader is cancelled.
+ # Neither of the writers should be cancelled.
+ # The second writer should still be waiting, but only on the first writer.
+ reader_d.cancel()
+ self.failureResultOf(reader_d, CancelledError)
+ self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled")
+ self.assertFalse(
+ writer2_d.called,
+ "Second writer was unexpectedly cancelled or given the lock before the "
+ "first writer finished",
+ )
+
+ # 5. Unblock the first writer, which should complete.
+ unblock_writer1.callback(None)
+ self.assertEqual("write 1 completed", self.successResultOf(writer1_d))
+
+ # 6. The second writer should take the lock and complete.
+ self.assertTrue(
+ writer2_d.called, "Second writer is stuck waiting for a cancelled reader"
+ )
+ self.assertEqual("write 2 completed", self.successResultOf(writer2_d))
+
+ def test_cancellation_while_waiting_for_write_lock(self):
+ """Test cancellation while waiting for a write lock.
+
+ Tests that cancelling a waiting writer:
+ * does not cancel the reader or writer it is waiting on
+ * does not cancel the next writer waiting on it
+ * does not allow the next writer to acquire the lock before an earlier reader
+ and writer have finished
+ * does not keep the next writer waiting indefinitely
+
+ These correspond to the asserts with explicit messages.
+ """
+ rwlock = ReadWriteLock()
+ key = "key"
+
+ # 1. A reader takes the lock and blocks.
+ reader_d, _, unblock_reader = self._start_blocking_reader(
+ rwlock, key, "read completed"
+ )
+
+ # 2. A writer waits for the reader to complete.
+ writer1_d, _, unblock_writer1 = self._start_blocking_writer(
+ rwlock, key, "write 1 completed"
+ )
+
+ # 3. A second writer waits for both the reader and first writer to complete.
+ # This writer will be cancelled later.
+ writer2_d, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed")
+ self.assertFalse(writer2_d.called)
+
+ # 4. A third writer waits for the second writer to complete.
+ writer3_d, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed")
+ self.assertFalse(writer3_d.called)
+
+ # 5. The second writer is cancelled, but continues waiting for the lock.
+ # The reader, first writer and third writer should not be cancelled.
+ # The first writer should still be waiting on the reader.
+ # The third writer should still be waiting on the second writer.
+ writer2_d.cancel()
+ self.assertNoResult(writer2_d)
+ self.assertFalse(reader_d.called, "Reader was unexpectedly cancelled")
+ self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled")
+ self.assertFalse(
+ writer3_d.called,
+ "Third writer was unexpectedly cancelled or given the lock before the first "
+ "writer finished",
+ )
+
+ # 6. Unblock the reader, which should complete.
+ # The first writer should be given the lock and block.
+ # The third writer should still be waiting on the second writer.
+ unblock_reader.callback(None)
+ self.assertEqual("read completed", self.successResultOf(reader_d))
+ self.assertNoResult(writer2_d)
+ self.assertFalse(
+ writer3_d.called,
+ "Third writer was unexpectedly given the lock before the first writer "
+ "finished",
+ )
+
+ # 7. Unblock the first writer, which should complete.
+ unblock_writer1.callback(None)
+ self.assertEqual("write 1 completed", self.successResultOf(writer1_d))
+
+ # 8. The second writer should take the lock and release it immediately, since it
+ # has been cancelled.
+ self.failureResultOf(writer2_d, CancelledError)
+
+ # 9. The third writer should take the lock and complete.
+ self.assertTrue(
+ writer3_d.called, "Third writer is stuck waiting for a cancelled writer"
+ )
+ self.assertEqual("write 3 completed", self.successResultOf(writer3_d))
diff --git a/tests/utils.py b/tests/utils.py
index ef99c72e0b..f6b1d60371 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,13 +15,8 @@
import atexit
import os
-from unittest.mock import Mock, patch
-from urllib import parse as urlparse
-
-from twisted.internet import defer
from synapse.api.constants import EventTypes
-from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
@@ -187,111 +182,6 @@ def mock_getRawHeaders(headers=None):
return getRawHeaders
-# This is a mock /resource/ not an entire server
-class MockHttpResource:
- def __init__(self, prefix=""):
- self.callbacks = [] # 3-tuple of method/pattern/function
- self.prefix = prefix
-
- def trigger_get(self, path):
- return self.trigger(b"GET", path, None)
-
- @patch("twisted.web.http.Request")
- @defer.inlineCallbacks
- def trigger(
- self, http_method, path, content, mock_request, federation_auth_origin=None
- ):
- """Fire an HTTP event.
-
- Args:
- http_method : The HTTP method
- path : The HTTP path
- content : The HTTP body
- mock_request : Mocked request to pass to the event so it can get
- content.
- federation_auth_origin (bytes|None): domain to authenticate as, for federation
- Returns:
- A tuple of (code, response)
- Raises:
- KeyError If no event is found which will handle the path.
- """
- path = self.prefix + path
-
- # annoyingly we return a twisted http request which has chained calls
- # to get at the http content, hence mock it here.
- mock_content = Mock()
- config = {"read.return_value": content}
- mock_content.configure_mock(**config)
- mock_request.content = mock_content
-
- mock_request.method = http_method.encode("ascii")
- mock_request.uri = path.encode("ascii")
-
- mock_request.getClientIP.return_value = "-"
-
- headers = {}
- if federation_auth_origin is not None:
- headers[b"Authorization"] = [
- b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
- ]
- mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
-
- # return the right path if the event requires it
- mock_request.path = path
-
- # add in query params to the right place
- try:
- mock_request.args = urlparse.parse_qs(path.split("?")[1])
- mock_request.path = path.split("?")[0]
- path = mock_request.path
- except Exception:
- pass
-
- if isinstance(path, bytes):
- path = path.decode("utf8")
-
- for (method, pattern, func) in self.callbacks:
- if http_method != method:
- continue
-
- matcher = pattern.match(path)
- if matcher:
- try:
- args = [urlparse.unquote(u) for u in matcher.groups()]
-
- (code, response) = yield defer.ensureDeferred(
- func(mock_request, *args)
- )
- return code, response
- except CodeMessageException as e:
- return e.code, cs_error(e.msg, code=e.errcode)
-
- raise KeyError("No event can handle %s" % path)
-
- def register_paths(self, method, path_patterns, callback, servlet_name):
- for path_pattern in path_patterns:
- self.callbacks.append((method, path_pattern, callback))
-
-
-class MockKey:
- alg = "mock_alg"
- version = "mock_version"
- signature = b"\x9a\x87$"
-
- @property
- def verify_key(self):
- return self
-
- def sign(self, message):
- return self
-
- def verify(self, message, sig):
- assert sig == b"\x9a\x87$"
-
- def encode(self):
- return b"<fake_encoded_key>"
-
-
class MockClock:
now = 1000
|