diff options
author | Brendan Abolivier <babolivier@matrix.org> | 2022-03-28 13:54:02 +0100 |
---|---|---|
committer | Brendan Abolivier <babolivier@matrix.org> | 2022-03-28 13:54:02 +0100 |
commit | 25507bffc67c40e83cbcd4a79fdfee3667855a7c (patch) | |
tree | 5620b2a06a5a9894ac875ddcf3b232db45cae48d /tests | |
parent | Merge branch 'develop' of github.com:matrix-org/synapse into babolivier/sign_... (diff) | |
parent | Add restrictions by default to open registration in Synapse (#12091) (diff) | |
download | synapse-babolivier/sign_json_module.tar.xz |
Merge branch 'develop' into babolivier/sign_json_module github/babolivier/sign_json_module babolivier/sign_json_module
Diffstat (limited to 'tests')
59 files changed, 3330 insertions, 1510 deletions
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 9bd6275e92..edc584d0cf 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -36,7 +36,10 @@ class ApplicationServiceTestCase(unittest.TestCase): hostname="matrix.org", # only used by get_groups_for_user ) self.event = Mock( - type="m.something", room_id="!foo:bar", sender="@someone:somewhere" + event_id="$abc:xyz", + type="m.something", + room_id="!foo:bar", + sender="@someone:somewhere", ) self.store = Mock() @@ -50,7 +53,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -62,7 +67,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertFalse( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -76,7 +83,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -90,7 +99,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -104,7 +115,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertFalse( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -121,7 +134,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -174,7 +189,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertFalse( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -191,7 +208,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -207,7 +226,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) @@ -225,7 +246,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(event=self.event, store=self.store) + self.service.is_interested_in_event( + self.event.event_id, self.event, self.store + ) ) ) ) diff --git a/tests/config/test_background_update.py b/tests/config/test_background_update.py new file mode 100644 index 0000000000..0c32c1ca29 --- /dev/null +++ b/tests/config/test_background_update.py @@ -0,0 +1,58 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import yaml + +from synapse.storage.background_updates import BackgroundUpdater + +from tests.unittest import HomeserverTestCase, override_config + + +class BackgroundUpdateConfigTestCase(HomeserverTestCase): + # Tests that the default values in the config are correctly loaded. Note that the default + # values are loaded when the corresponding config options are commented out, which is why there isn't + # a config specified here. + def test_default_configuration(self): + background_updater = BackgroundUpdater( + self.hs, self.hs.get_datastores().main.db_pool + ) + + self.assertEqual(background_updater.minimum_background_batch_size, 1) + self.assertEqual(background_updater.default_background_batch_size, 100) + self.assertEqual(background_updater.sleep_enabled, True) + self.assertEqual(background_updater.sleep_duration_ms, 1000) + self.assertEqual(background_updater.update_duration_ms, 100) + + # Tests that non-default values for the config options are properly picked up and passed on. + @override_config( + yaml.safe_load( + """ + background_updates: + background_update_duration_ms: 1000 + sleep_enabled: false + sleep_duration_ms: 600 + min_batch_size: 5 + default_batch_size: 50 + """ + ) + ) + def test_custom_configuration(self): + background_updater = BackgroundUpdater( + self.hs, self.hs.get_datastores().main.db_pool + ) + + self.assertEqual(background_updater.minimum_background_batch_size, 5) + self.assertEqual(background_updater.default_background_batch_size, 50) + self.assertEqual(background_updater.sleep_enabled, False) + self.assertEqual(background_updater.sleep_duration_ms, 600) + self.assertEqual(background_updater.update_duration_ms, 1000) diff --git a/tests/config/test_registration_config.py b/tests/config/test_registration_config.py index 17a84d20d8..2acdb6ac61 100644 --- a/tests/config/test_registration_config.py +++ b/tests/config/test_registration_config.py @@ -11,14 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import synapse.app.homeserver from synapse.config import ConfigError from synapse.config.homeserver import HomeServerConfig -from tests.unittest import TestCase +from tests.config.utils import ConfigFileTestCase from tests.utils import default_config -class RegistrationConfigTestCase(TestCase): +class RegistrationConfigTestCase(ConfigFileTestCase): def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self): """ session_lifetime should logically be larger than, or at least as large as, @@ -76,3 +78,19 @@ class RegistrationConfigTestCase(TestCase): HomeServerConfig().parse_config_dict( {"session_lifetime": "31m", "refresh_token_lifetime": "31m", **config_dict} ) + + def test_refuse_to_start_if_open_registration_and_no_verification(self): + self.generate_config() + self.add_lines_to_config( + [ + " ", + "enable_registration: true", + "registrations_require_3pid: []", + "enable_registration_captcha: false", + "registration_requires_token: false", + ] + ) + + # Test that allowing open registration without verification raises an error + with self.assertRaises(ConfigError): + synapse.app.homeserver.setup(["-c", self.config_file]) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 45e3395b33..00ad19e446 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -16,6 +16,7 @@ from synapse.api.constants import EventContentFields from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.events.utils import ( + SerializeEventConfig, copy_power_levels_contents, prune_event, serialize_event, @@ -392,7 +393,9 @@ class PruneEventTestCase(unittest.TestCase): class SerializeEventTestCase(unittest.TestCase): def serialize(self, ev, fields): - return serialize_event(ev, 1479807801915, only_event_fields=fields) + return serialize_event( + ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields) + ) def test_event_fields_works_with_keys(self): self.assertEqual( diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 60e0c31f43..e90592855a 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -201,9 +201,12 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.assertEqual(len(self.edus), 1) stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # a second call should produce no new device EDUs self.hs.get_federation_sender().send_device_messages("host2") - self.pump() self.assertEqual(self.edus, []) # a second device @@ -232,6 +235,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1") device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2") + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # expect two more edus self.assertEqual(len(self.edus), 2) stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) @@ -265,6 +272,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys) ) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # expect signing key update edu self.assertEqual(len(self.edus), 2) self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update") @@ -284,6 +295,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): ) self.assertEqual(ret["failures"], {}) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # expect two edus, in one or two transactions. We don't know what order the # devices will be updated. self.assertEqual(len(self.edus), 2) @@ -307,6 +322,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.login("user", "pass", device_id="D2") self.login("user", "pass", device_id="D3") + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # expect three edus self.assertEqual(len(self.edus), 3) stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) @@ -318,6 +337,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) ) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # expect three edus, in an unknown order self.assertEqual(len(self.edus), 3) for edu in self.edus: @@ -350,12 +373,19 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) ) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + self.assertGreaterEqual(mock_send_txn.call_count, 4) # recover the server mock_send_txn.side_effect = self.record_transaction self.hs.get_federation_sender().send_device_messages("host2") - self.pump() + + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) # for each device, there should be a single update self.assertEqual(len(self.edus), 3) @@ -390,6 +420,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) ) + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + self.assertGreaterEqual(mock_send_txn.call_count, 4) # run the prune job @@ -401,7 +435,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # recover the server mock_send_txn.side_effect = self.record_transaction self.hs.get_federation_sender().send_device_messages("host2") - self.pump() + + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) # there should be a single update for this user. self.assertEqual(len(self.edus), 1) @@ -435,6 +472,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.login("user", "pass", device_id="D2") self.login("user", "pass", device_id="D3") + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) + # delete them again self.get_success( self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) @@ -451,7 +492,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # recover the server mock_send_txn.side_effect = self.record_transaction self.hs.get_federation_sender().send_device_messages("host2") - self.pump() + + # We queue up device list updates to be sent over federation, so we + # advance to clear the queue. + self.reactor.advance(1) # ... and we should get a single update for this user. self.assertEqual(len(self.edus), 1) diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index abf2a0fe0d..c1579dac61 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -15,11 +15,15 @@ from collections import Counter from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin import synapse.storage from synapse.api.constants import EventTypes, JoinRules from synapse.api.room_versions import RoomVersions from synapse.rest.client import knock, login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -32,7 +36,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): knock.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_handler = hs.get_admin_handler() self.user1 = self.register_user("user1", "password") @@ -41,7 +45,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): self.user2 = self.register_user("user2", "password") self.token2 = self.login("user2", "password") - def test_single_public_joined_room(self): + def test_single_public_joined_room(self) -> None: """Test that we write *all* events for a public room""" room_id = self.helper.create_room_as( self.user1, tok=self.token1, is_public=True @@ -74,7 +78,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) self.assertEqual(counter[(EventTypes.Member, self.user2)], 1) - def test_single_private_joined_room(self): + def test_single_private_joined_room(self) -> None: """Tests that we correctly write state when we can't see all events in a room. """ @@ -112,7 +116,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) self.assertEqual(counter[(EventTypes.Member, self.user2)], 1) - def test_single_left_room(self): + def test_single_left_room(self) -> None: """Tests that we don't see events in the room after we leave.""" room_id = self.helper.create_room_as(self.user1, tok=self.token1) self.helper.send(room_id, body="Hello!", tok=self.token1) @@ -144,7 +148,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) self.assertEqual(counter[(EventTypes.Member, self.user2)], 2) - def test_single_left_rejoined_private_room(self): + def test_single_left_rejoined_private_room(self) -> None: """Tests that see the correct events in private rooms when we repeatedly join and leave. """ @@ -185,7 +189,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) self.assertEqual(counter[(EventTypes.Member, self.user2)], 3) - def test_invite(self): + def test_invite(self) -> None: """Tests that pending invites get handled correctly.""" room_id = self.helper.create_room_as(self.user1, tok=self.token1) self.helper.send(room_id, body="Hello!", tok=self.token1) @@ -204,7 +208,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): self.assertEqual(args[1].content["membership"], "invite") self.assertTrue(args[2]) # Assert there is at least one bit of state - def test_knock(self): + def test_knock(self) -> None: """Tests that knock get handled correctly.""" # create a knockable v7 room room_id = self.helper.create_room_as( diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 072e6bbcdd..cead9f90df 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -59,11 +59,11 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.event_source = hs.get_event_sources() def test_notify_interested_services(self): - interested_service = self._mkservice(is_interested=True) + interested_service = self._mkservice(is_interested_in_event=True) services = [ - self._mkservice(is_interested=False), + self._mkservice(is_interested_in_event=False), interested_service, - self._mkservice(is_interested=False), + self._mkservice(is_interested_in_event=False), ] self.mock_as_api.query_user.return_value = make_awaitable(True) @@ -85,7 +85,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): def test_query_user_exists_unknown_user(self): user_id = "@someone:anywhere" - services = [self._mkservice(is_interested=True)] + services = [self._mkservice(is_interested_in_event=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services self.mock_store.get_user_by_id.return_value = make_awaitable(None) @@ -102,7 +102,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): def test_query_user_exists_known_user(self): user_id = "@someone:anywhere" - services = [self._mkservice(is_interested=True)] + services = [self._mkservice(is_interested_in_event=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id}) @@ -127,11 +127,11 @@ class AppServiceHandlerTestCase(unittest.TestCase): room_id = "!alpha:bet" servers = ["aperture"] - interested_service = self._mkservice_alias(is_interested_in_alias=True) + interested_service = self._mkservice_alias(is_room_alias_in_namespace=True) services = [ - self._mkservice_alias(is_interested_in_alias=False), + self._mkservice_alias(is_room_alias_in_namespace=False), interested_service, - self._mkservice_alias(is_interested_in_alias=False), + self._mkservice_alias(is_room_alias_in_namespace=False), ] self.mock_as_api.query_alias.return_value = make_awaitable(True) @@ -275,7 +275,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): to be pushed out to interested appservices, and that the stream ID is updated accordingly. """ - interested_service = self._mkservice(is_interested=True) + interested_service = self._mkservice(is_interested_in_event=True) services = [interested_service] self.mock_store.get_app_services.return_value = services self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( @@ -304,7 +304,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): Test sending out of order ephemeral events to the appservice handler are ignored. """ - interested_service = self._mkservice(is_interested=True) + interested_service = self._mkservice(is_interested_in_event=True) services = [interested_service] self.mock_store.get_app_services.return_value = services @@ -325,17 +325,45 @@ class AppServiceHandlerTestCase(unittest.TestCase): interested_service, ephemeral=[] ) - def _mkservice(self, is_interested, protocols=None): + def _mkservice( + self, is_interested_in_event: bool, protocols: Optional[Iterable] = None + ) -> Mock: + """ + Create a new mock representing an ApplicationService. + + Args: + is_interested_in_event: Whether this application service will be considered + interested in all events. + protocols: The third-party protocols that this application service claims to + support. + + Returns: + A mock representing the ApplicationService. + """ service = Mock() - service.is_interested.return_value = make_awaitable(is_interested) + service.is_interested_in_event.return_value = make_awaitable( + is_interested_in_event + ) service.token = "mock_service_token" service.url = "mock_service_url" service.protocols = protocols return service - def _mkservice_alias(self, is_interested_in_alias): + def _mkservice_alias(self, is_room_alias_in_namespace: bool) -> Mock: + """ + Create a new mock representing an ApplicationService that is or is not interested + any given room aliase. + + Args: + is_room_alias_in_namespace: If true, the application service will be interested + in all room aliases that are queried against it. If false, the application + service will not be interested in any room aliases. + + Returns: + A mock representing the ApplicationService. + """ service = Mock() - service.is_interested_in_alias.return_value = is_interested_in_alias + service.is_room_alias_in_namespace.return_value = is_room_alias_in_namespace service.token = "mock_service_token" service.url = "mock_service_url" return service diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 0c6e55e725..67a7829769 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -15,8 +15,12 @@ from unittest.mock import Mock import pymacaroons +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.errors import AuthError, ResourceLimitError from synapse.rest import admin +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable @@ -27,7 +31,7 @@ class AuthTestCase(unittest.HomeserverTestCase): admin.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.auth_handler = hs.get_auth_handler() self.macaroon_generator = hs.get_macaroon_generator() @@ -42,23 +46,23 @@ class AuthTestCase(unittest.HomeserverTestCase): self.user1 = self.register_user("a_user", "pass") - def test_macaroon_caveats(self): + def test_macaroon_caveats(self) -> None: token = self.macaroon_generator.generate_guest_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) - def verify_gen(caveat): + def verify_gen(caveat: str) -> bool: return caveat == "gen = 1" - def verify_user(caveat): + def verify_user(caveat: str) -> bool: return caveat == "user_id = a_user" - def verify_type(caveat): + def verify_type(caveat: str) -> bool: return caveat == "type = access" - def verify_nonce(caveat): + def verify_nonce(caveat: str) -> bool: return caveat.startswith("nonce =") - def verify_guest(caveat): + def verify_guest(caveat: str) -> bool: return caveat == "guest = true" v = pymacaroons.Verifier() @@ -69,7 +73,7 @@ class AuthTestCase(unittest.HomeserverTestCase): v.satisfy_general(verify_guest) v.verify(macaroon, self.hs.config.key.macaroon_secret_key) - def test_short_term_login_token_gives_user_id(self): + def test_short_term_login_token_gives_user_id(self) -> None: token = self.macaroon_generator.generate_short_term_login_token( self.user1, "", duration_in_ms=5000 ) @@ -84,7 +88,7 @@ class AuthTestCase(unittest.HomeserverTestCase): AuthError, ) - def test_short_term_login_token_gives_auth_provider(self): + def test_short_term_login_token_gives_auth_provider(self) -> None: token = self.macaroon_generator.generate_short_term_login_token( self.user1, auth_provider_id="my_idp" ) @@ -92,7 +96,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(self.user1, res.user_id) self.assertEqual("my_idp", res.auth_provider_id) - def test_short_term_login_token_cannot_replace_user_id(self): + def test_short_term_login_token_cannot_replace_user_id(self) -> None: token = self.macaroon_generator.generate_short_term_login_token( self.user1, "", duration_in_ms=5000 ) @@ -112,7 +116,7 @@ class AuthTestCase(unittest.HomeserverTestCase): AuthError, ) - def test_mau_limits_disabled(self): + def test_mau_limits_disabled(self) -> None: self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception self.get_success( @@ -127,7 +131,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) ) - def test_mau_limits_exceeded_large(self): + def test_mau_limits_exceeded_large(self) -> None: self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) @@ -150,7 +154,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ResourceLimitError, ) - def test_mau_limits_parity(self): + def test_mau_limits_parity(self) -> None: # Ensure we're not at the unix epoch. self.reactor.advance(1) self.auth_blocking._limit_usage_by_mau = True @@ -189,7 +193,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) ) - def test_mau_limits_not_exceeded(self): + def test_mau_limits_not_exceeded(self) -> None: self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastores().main.get_monthly_active_count = Mock( @@ -211,7 +215,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) ) - def _get_macaroon(self): + def _get_macaroon(self) -> pymacaroons.Macaroon: token = self.macaroon_generator.generate_short_term_login_token( self.user1, "", duration_in_ms=5000 ) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index a267228846..a54aa29cf1 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -11,9 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.handlers.cas import CasResponse +from synapse.server import HomeServer +from synapse.util import Clock from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config @@ -24,7 +29,7 @@ SERVER_URL = "https://issuer/" class CasHandlerTestCase(HomeserverTestCase): - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL cas_config = { @@ -40,7 +45,7 @@ class CasHandlerTestCase(HomeserverTestCase): return config - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver() self.handler = hs.get_cas_handler() @@ -51,7 +56,7 @@ class CasHandlerTestCase(HomeserverTestCase): return hs - def test_map_cas_user_to_user(self): + def test_map_cas_user_to_user(self) -> None: """Ensure that mapping the CAS user returned from a provider to an MXID works properly.""" # stub out the auth handler @@ -75,7 +80,7 @@ class CasHandlerTestCase(HomeserverTestCase): auth_provider_session_id=None, ) - def test_map_cas_user_to_existing_user(self): + def test_map_cas_user_to_existing_user(self) -> None: """Existing users can log in with CAS account.""" store = self.hs.get_datastores().main self.get_success( @@ -119,7 +124,7 @@ class CasHandlerTestCase(HomeserverTestCase): auth_provider_session_id=None, ) - def test_map_cas_user_to_invalid_localpart(self): + def test_map_cas_user_to_invalid_localpart(self) -> None: """CAS automaps invalid characters to base-64 encoding.""" # stub out the auth handler @@ -150,7 +155,7 @@ class CasHandlerTestCase(HomeserverTestCase): } } ) - def test_required_attributes(self): + def test_required_attributes(self) -> None: """The required attributes must be met from the CAS response.""" # stub out the auth handler @@ -166,7 +171,7 @@ class CasHandlerTestCase(HomeserverTestCase): auth_handler.complete_sso_login.assert_not_called() # The response doesn't have any department. - cas_response = CasResponse("test_user", {"userGroup": "staff"}) + cas_response = CasResponse("test_user", {"userGroup": ["staff"]}) request.reset_mock() self.get_success( self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py index ddda36c5a9..3a10791226 100644 --- a/tests/handlers/test_deactivate_account.py +++ b/tests/handlers/test_deactivate_account.py @@ -39,7 +39,7 @@ class DeactivateAccountTestCase(HomeserverTestCase): self.user = self.register_user("user", "pass") self.token = self.login("user", "pass") - def _deactivate_my_account(self): + def _deactivate_my_account(self) -> None: """ Deactivates the account `self.user` using `self.token` and asserts that it returns a 200 success code. diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 683677fd07..01ea7d2a42 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -14,9 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import synapse.api.errors -import synapse.handlers.device -import synapse.storage +from typing import Optional + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.errors import NotFoundError, SynapseError +from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -25,28 +30,27 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) self.handler = hs.get_device_handler() self.store = hs.get_datastores().main return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # These tests assume that it starts 1000 seconds in. self.reactor.advance(1000) - def test_device_is_created_with_invalid_name(self): + def test_device_is_created_with_invalid_name(self) -> None: self.get_failure( self.handler.check_device_registered( user_id="@boris:foo", device_id="foo", - initial_device_display_name="a" - * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1), + initial_device_display_name="a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1), ), - synapse.api.errors.SynapseError, + SynapseError, ) - def test_device_is_created_if_doesnt_exist(self): + def test_device_is_created_if_doesnt_exist(self) -> None: res = self.get_success( self.handler.check_device_registered( user_id="@boris:foo", @@ -59,7 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) self.assertEqual(dev["display_name"], "display name") - def test_device_is_preserved_if_exists(self): + def test_device_is_preserved_if_exists(self) -> None: res1 = self.get_success( self.handler.check_device_registered( user_id="@boris:foo", @@ -81,7 +85,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) self.assertEqual(dev["display_name"], "display name") - def test_device_id_is_made_up_if_unspecified(self): + def test_device_id_is_made_up_if_unspecified(self) -> None: device_id = self.get_success( self.handler.check_device_registered( user_id="@theresa:foo", @@ -93,7 +97,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) self.assertEqual(dev["display_name"], "display") - def test_get_devices_by_user(self): + def test_get_devices_by_user(self) -> None: self._record_users() res = self.get_success(self.handler.get_devices_by_user(user1)) @@ -131,7 +135,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): device_map["abc"], ) - def test_get_device(self): + def test_get_device(self) -> None: self._record_users() res = self.get_success(self.handler.get_device(user1, "abc")) @@ -146,21 +150,19 @@ class DeviceTestCase(unittest.HomeserverTestCase): res, ) - def test_delete_device(self): + def test_delete_device(self) -> None: self._record_users() # delete the device self.get_success(self.handler.delete_device(user1, "abc")) # check the device was deleted - self.get_failure( - self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError - ) + self.get_failure(self.handler.get_device(user1, "abc"), NotFoundError) # we'd like to check the access token was invalidated, but that's a # bit of a PITA. - def test_delete_device_and_device_inbox(self): + def test_delete_device_and_device_inbox(self) -> None: self._record_users() # add an device_inbox @@ -191,7 +193,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): ) self.assertIsNone(res) - def test_update_device(self): + def test_update_device(self) -> None: self._record_users() update = {"display_name": "new display"} @@ -200,32 +202,29 @@ class DeviceTestCase(unittest.HomeserverTestCase): res = self.get_success(self.handler.get_device(user1, "abc")) self.assertEqual(res["display_name"], "new display") - def test_update_device_too_long_display_name(self): + def test_update_device_too_long_display_name(self) -> None: """Update a device with a display name that is invalid (too long).""" self._record_users() # Request to update a device display name with a new value that is longer than allowed. - update = { - "display_name": "a" - * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1) - } + update = {"display_name": "a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1)} self.get_failure( self.handler.update_device(user1, "abc", update), - synapse.api.errors.SynapseError, + SynapseError, ) # Ensure the display name was not updated. res = self.get_success(self.handler.get_device(user1, "abc")) self.assertEqual(res["display_name"], "display 2") - def test_update_unknown_device(self): + def test_update_unknown_device(self) -> None: update = {"display_name": "new_display"} self.get_failure( self.handler.update_device("user_id", "unknown_device_id", update), - synapse.api.errors.NotFoundError, + NotFoundError, ) - def _record_users(self): + def _record_users(self) -> None: # check this works for both devices which have a recorded client_ip, # and those which don't. self._record_user(user1, "xyz", "display 0") @@ -238,8 +237,13 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.reactor.advance(10000) def _record_user( - self, user_id, device_id, display_name, access_token=None, ip=None - ): + self, + user_id: str, + device_id: str, + display_name: str, + access_token: Optional[str] = None, + ip: Optional[str] = None, + ) -> None: device_id = self.get_success( self.handler.check_device_registered( user_id=user_id, @@ -248,7 +252,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): ) ) - if ip is not None: + if access_token is not None and ip is not None: self.get_success( self.store.insert_client_ip( user_id, access_token, ip, "user_agent", device_id @@ -258,7 +262,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): class DehydrationTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) self.handler = hs.get_device_handler() self.registration = hs.get_registration_handler() @@ -266,7 +270,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastores().main return hs - def test_dehydrate_and_rehydrate_device(self): + def test_dehydrate_and_rehydrate_device(self) -> None: user_id = "@boris:dehydration" self.get_success(self.store.register_user(user_id, "foobar")) @@ -303,7 +307,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): access_token=access_token, device_id="not the right device ID", ), - synapse.api.errors.NotFoundError, + NotFoundError, ) # dehydrating the right devices should succeed and change our device ID @@ -331,7 +335,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): # make sure that the device ID that we were initially assigned no longer exists self.get_failure( self.handler.get_device(user_id, device_id), - synapse.api.errors.NotFoundError, + NotFoundError, ) # make sure that there's no device available for dehydrating now diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 6e403a87c5..11ad44223d 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -12,14 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import Any, Awaitable, Callable, Dict from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + import synapse.api.errors import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client import directory, login, room -from synapse.types import RoomAlias, create_requester +from synapse.server import HomeServer +from synapse.types import JsonDict, RoomAlias, create_requester +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable @@ -28,13 +32,15 @@ from tests.test_utils import make_awaitable class DirectoryTestCase(unittest.HomeserverTestCase): """Tests the directory service.""" - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.mock_federation = Mock() self.mock_registry = Mock() - self.query_handlers = {} + self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} - def register_query_handler(query_type, handler): + def register_query_handler( + query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] + ) -> None: self.query_handlers[query_type] = handler self.mock_registry.register_query_handler = register_query_handler @@ -54,7 +60,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): return hs - def test_get_local_association(self): + def test_get_local_association(self) -> None: self.get_success( self.store.create_room_alias_association( self.my_room, "!8765qwer:test", ["test"] @@ -65,7 +71,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result) - def test_get_remote_association(self): + def test_get_remote_association(self) -> None: self.mock_federation.make_query.return_value = make_awaitable( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} ) @@ -83,7 +89,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ignore_backoff=True, ) - def test_incoming_fed_query(self): + def test_incoming_fed_query(self) -> None: self.get_success( self.store.create_room_alias_association( self.your_room, "!8765asdf:test", ["test"] @@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_directory_handler() # Create user @@ -125,7 +131,7 @@ class TestCreateAlias(unittest.HomeserverTestCase): self.test_user_tok = self.login("user", "pass") self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) - def test_create_alias_joined_room(self): + def test_create_alias_joined_room(self) -> None: """A user can create an alias for a room they're in.""" self.get_success( self.handler.create_association( @@ -135,7 +141,7 @@ class TestCreateAlias(unittest.HomeserverTestCase): ) ) - def test_create_alias_other_room(self): + def test_create_alias_other_room(self) -> None: """A user cannot create an alias for a room they're NOT in.""" other_room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok @@ -150,7 +156,7 @@ class TestCreateAlias(unittest.HomeserverTestCase): synapse.api.errors.SynapseError, ) - def test_create_alias_admin(self): + def test_create_alias_admin(self) -> None: """An admin can create an alias for a room they're NOT in.""" other_room_id = self.helper.create_room_as( self.test_user, tok=self.test_user_tok @@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() @@ -195,7 +201,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): self.test_user_tok = self.login("user", "pass") self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) - def _create_alias(self, user): + def _create_alias(self, user) -> None: # Create a new alias to this room. self.get_success( self.store.create_room_alias_association( @@ -203,7 +209,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): ) ) - def test_delete_alias_not_allowed(self): + def test_delete_alias_not_allowed(self) -> None: """A user that doesn't meet the expected guidelines cannot delete an alias.""" self._create_alias(self.admin_user) self.get_failure( @@ -213,7 +219,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): synapse.api.errors.AuthError, ) - def test_delete_alias_creator(self): + def test_delete_alias_creator(self) -> None: """An alias creator can delete their own alias.""" # Create an alias from a different user. self._create_alias(self.test_user) @@ -232,7 +238,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): synapse.api.errors.SynapseError, ) - def test_delete_alias_admin(self): + def test_delete_alias_admin(self) -> None: """A server admin can delete an alias created by another user.""" # Create an alias from a different user. self._create_alias(self.test_user) @@ -251,7 +257,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): synapse.api.errors.SynapseError, ) - def test_delete_alias_sufficient_power(self): + def test_delete_alias_sufficient_power(self) -> None: """A user with a sufficient power level should be able to delete an alias.""" self._create_alias(self.admin_user) @@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() @@ -317,7 +323,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) return room_alias - def _set_canonical_alias(self, content): + def _set_canonical_alias(self, content) -> None: """Configure the canonical alias state on the room.""" self.helper.send_state( self.room_id, @@ -334,7 +340,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) ) - def test_remove_alias(self): + def test_remove_alias(self) -> None: """Removing an alias that is the canonical alias should remove it there too.""" # Set this new alias as the canonical alias for this room self._set_canonical_alias( @@ -356,7 +362,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): self.assertNotIn("alias", data["content"]) self.assertNotIn("alt_aliases", data["content"]) - def test_remove_other_alias(self): + def test_remove_other_alias(self) -> None: """Removing an alias listed as in alt_aliases should remove it there too.""" # Create a second alias. other_test_alias = "#test2:test" @@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): servlets = [directory.register_servlets, room.register_servlets] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # Add custom alias creation rules to the config. @@ -403,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): return config - def test_denied(self): + def test_denied(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request( @@ -413,7 +419,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): ) self.assertEqual(403, channel.code, channel.result) - def test_allowed(self): + def test_allowed(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request( @@ -423,7 +429,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): ) self.assertEqual(200, channel.code, channel.result) - def test_denied_during_creation(self): + def test_denied_during_creation(self) -> None: """A room alias that is not allowed should be rejected during creation.""" # Invalid room alias. self.helper.create_room_as( @@ -432,7 +438,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): extra_content={"room_alias_name": "foo"}, ) - def test_allowed_during_creation(self): + def test_allowed_during_creation(self) -> None: """A valid room alias should be allowed during creation.""" room_id = self.helper.create_room_as( self.user_id, @@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): data = {"room_alias_name": "unofficial_test"} allowed_localpart = "allowed" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # Add custom room list publication rules to the config. @@ -474,7 +480,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): return config - def prepare(self, reactor, clock, hs): + def prepare( + self, reactor: MemoryReactor, clock: Clock, hs: HomeServer + ) -> HomeServer: self.allowed_user_id = self.register_user(self.allowed_localpart, "pass") self.allowed_access_token = self.login(self.allowed_localpart, "pass") @@ -483,7 +491,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): return hs - def test_denied_without_publication_permission(self): + def test_denied_without_publication_permission(self) -> None: """ Try to create a room, register an alias for it, and publish it, as a user without permission to publish rooms. @@ -497,7 +505,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): expect_code=403, ) - def test_allowed_when_creating_private_room(self): + def test_allowed_when_creating_private_room(self) -> None: """ Try to create a room, register an alias for it, and NOT publish it, as a user without permission to publish rooms. @@ -511,7 +519,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): expect_code=200, ) - def test_allowed_with_publication_permission(self): + def test_allowed_with_publication_permission(self) -> None: """ Try to create a room, register an alias for it, and publish it, as a user WITH permission to publish rooms. @@ -525,7 +533,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): expect_code=200, ) - def test_denied_publication_with_invalid_alias(self): + def test_denied_publication_with_invalid_alias(self) -> None: """ Try to create a room, register an alias for it, and publish it, as a user WITH permission to publish rooms. @@ -538,7 +546,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): expect_code=403, ) - def test_can_create_as_private_room_after_rejection(self): + def test_can_create_as_private_room_after_rejection(self) -> None: """ After failing to publish a room with an alias as a user without publish permission, retry as the same user, but without publishing the room. @@ -549,7 +557,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): self.test_denied_without_publication_permission() self.test_allowed_when_creating_private_room() - def test_can_create_with_permission_after_rejection(self): + def test_can_create_with_permission_after_rejection(self) -> None: """ After failing to publish a room with an alias as a user without publish permission, retry as someone with permission, using the same alias. @@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): servlets = [directory.register_servlets, room.register_servlets] - def prepare(self, reactor, clock, hs): + def prepare( + self, reactor: MemoryReactor, clock: Clock, hs: HomeServer + ) -> HomeServer: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request( @@ -579,7 +589,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): return hs - def test_disabling_room_list(self): + def test_disabling_room_list(self) -> None: self.room_list_handler.enable_room_list_search = True self.directory_handler.enable_room_list_search = True diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 9338ab92e9..ac21a28c43 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -20,33 +20,37 @@ from parameterized import parameterized from signedjson import key as key, sign as sign from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import Codes, SynapseError +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return self.setup_test_homeserver(federation_client=mock.Mock()) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_e2e_keys_handler() self.store = self.hs.get_datastores().main - def test_query_local_devices_no_devices(self): + def test_query_local_devices_no_devices(self) -> None: """If the user has no devices, we expect an empty list.""" local_user = "@boris:" + self.hs.hostname res = self.get_success(self.handler.query_local_devices({local_user: None})) self.assertDictEqual(res, {local_user: {}}) - def test_reupload_one_time_keys(self): + def test_reupload_one_time_keys(self) -> None: """we should be able to re-upload the same keys""" local_user = "@boris:" + self.hs.hostname device_id = "xyz" - keys = { + keys: JsonDict = { "alg1:k1": "key1", "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}}, "alg2:k3": {"key": "key3"}, @@ -74,7 +78,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}} ) - def test_change_one_time_keys(self): + def test_change_one_time_keys(self) -> None: """attempts to change one-time-keys should be rejected""" local_user = "@boris:" + self.hs.hostname @@ -134,7 +138,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): SynapseError, ) - def test_claim_one_time_key(self): + def test_claim_one_time_key(self) -> None: local_user = "@boris:" + self.hs.hostname device_id = "xyz" keys = {"alg1:k1": "key1"} @@ -161,7 +165,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) - def test_fallback_key(self): + def test_fallback_key(self) -> None: local_user = "@boris:" + self.hs.hostname device_id = "xyz" fallback_key = {"alg1:k1": "fallback_key1"} @@ -294,7 +298,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, ) - def test_replace_master_key(self): + def test_replace_master_key(self) -> None: """uploading a new signing key should make the old signing key unavailable""" local_user = "@boris:" + self.hs.hostname keys1 = { @@ -328,7 +332,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) - def test_reupload_signatures(self): + def test_reupload_signatures(self) -> None: """re-uploading a signature should not fail""" local_user = "@boris:" + self.hs.hostname keys1 = { @@ -433,7 +437,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1) self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2) - def test_self_signing_key_doesnt_show_up_as_device(self): + def test_self_signing_key_doesnt_show_up_as_device(self) -> None: """signing keys should be hidden when fetching a user's devices""" local_user = "@boris:" + self.hs.hostname keys1 = { @@ -462,7 +466,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): res = self.get_success(self.handler.query_local_devices({local_user: None})) self.assertDictEqual(res, {local_user: {}}) - def test_upload_signatures(self): + def test_upload_signatures(self) -> None: """should check signatures that are uploaded""" # set up a user with cross-signing keys and a device. This user will # try uploading signatures @@ -686,7 +690,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], ) - def test_query_devices_remote_no_sync(self): + def test_query_devices_remote_no_sync(self) -> None: """Tests that querying keys for a remote user that we don't share a room with returns the cross signing keys correctly. """ @@ -759,7 +763,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) - def test_query_devices_remote_sync(self): + def test_query_devices_remote_sync(self) -> None: """Tests that querying keys for a remote user that we share a room with, but haven't yet fetched the keys for, returns the cross signing keys correctly. @@ -845,7 +849,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): (["device_1", "device_2"],), ] ) - def test_query_all_devices_caches_result(self, device_ids: Iterable[str]): + def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None: """Test that requests for all of a remote user's devices are cached. We do this by asserting that only one call over federation was made, and that @@ -853,7 +857,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): """ local_user_id = "@test:test" remote_user_id = "@test:other" - request_body = {"device_keys": {remote_user_id: []}} + request_body: JsonDict = {"device_keys": {remote_user_id: []}} response_devices = [ { diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index e8b4e39d1a..89078fc637 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List +from typing import List, cast from unittest import TestCase +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError from synapse.api.room_versions import RoomVersions @@ -23,7 +25,9 @@ from synapse.federation.federation_base import event_from_pdu_json from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.types import create_requester +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest @@ -42,7 +46,7 @@ class FederationTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() self.store = hs.get_datastores().main @@ -50,7 +54,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self._event_auth_handler = hs.get_event_auth_handler() return hs - def test_exchange_revoked_invite(self): + def test_exchange_revoked_invite(self) -> None: user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") @@ -96,7 +100,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure) self.assertEqual(failure.msg, "You are not invited to this room.") - def test_rejected_message_event_state(self): + def test_rejected_message_event_state(self) -> None: """ Check that we store the state group correctly for rejected non-state events. @@ -126,7 +130,7 @@ class FederationTestCase(unittest.HomeserverTestCase): "content": {}, "room_id": room_id, "sender": "@yetanotheruser:" + OTHER_SERVER, - "depth": join_event["depth"] + 1, + "depth": cast(int, join_event["depth"]) + 1, "prev_events": [join_event.event_id], "auth_events": [], "origin_server_ts": self.clock.time_msec(), @@ -149,7 +153,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(sg, sg2) - def test_rejected_state_event_state(self): + def test_rejected_state_event_state(self) -> None: """ Check that we store the state group correctly for rejected state events. @@ -180,7 +184,7 @@ class FederationTestCase(unittest.HomeserverTestCase): "content": {}, "room_id": room_id, "sender": "@yetanotheruser:" + OTHER_SERVER, - "depth": join_event["depth"] + 1, + "depth": cast(int, join_event["depth"]) + 1, "prev_events": [join_event.event_id], "auth_events": [], "origin_server_ts": self.clock.time_msec(), @@ -203,7 +207,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(sg, sg2) - def test_backfill_with_many_backward_extremities(self): + def test_backfill_with_many_backward_extremities(self) -> None: """ Check that we can backfill with many backward extremities. The goal is to make sure that when we only use a portion @@ -262,7 +266,7 @@ class FederationTestCase(unittest.HomeserverTestCase): ) self.get_success(d) - def test_backfill_floating_outlier_membership_auth(self): + def test_backfill_floating_outlier_membership_auth(self) -> None: """ As the local homeserver, check that we can properly process a federated event from the OTHER_SERVER with auth_events that include a floating @@ -377,7 +381,7 @@ class FederationTestCase(unittest.HomeserverTestCase): for ae in auth_events ] - self.handler.federation_client.get_event_auth = get_event_auth + self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment] with LoggingContext("receive_pdu"): # Fake the OTHER_SERVER federating the message event over to our local homeserver @@ -397,7 +401,7 @@ class FederationTestCase(unittest.HomeserverTestCase): @unittest.override_config( {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} ) - def test_invite_by_user_ratelimit(self): + def test_invite_by_user_ratelimit(self) -> None: """Tests that invites from federation to a particular user are actually rate-limited. """ @@ -446,7 +450,9 @@ class FederationTestCase(unittest.HomeserverTestCase): exc=LimitExceededError, ) - def _build_and_send_join_event(self, other_server, other_user, room_id): + def _build_and_send_join_event( + self, other_server: str, other_user: str, room_id: str + ) -> EventBase: join_event = self.get_success( self.handler.on_make_join_request(other_server, room_id, other_user) ) @@ -469,7 +475,7 @@ class FederationTestCase(unittest.HomeserverTestCase): class EventFromPduTestCase(TestCase): - def test_valid_json(self): + def test_valid_json(self) -> None: """Valid JSON should be turned into an event.""" ev = event_from_pdu_json( { @@ -487,7 +493,7 @@ class EventFromPduTestCase(TestCase): self.assertIsInstance(ev, EventBase) - def test_invalid_numbers(self): + def test_invalid_numbers(self) -> None: """Invalid values for an integer should be rejected, all floats should be rejected.""" for value in [ -(2 ** 53), @@ -512,7 +518,7 @@ class EventFromPduTestCase(TestCase): RoomVersions.V6, ) - def test_invalid_nested(self): + def test_invalid_nested(self) -> None: """List and dictionaries are recursively searched.""" with self.assertRaises(SynapseError): event_from_pdu_json( diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index e8418b6638..014815db6e 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,14 +13,18 @@ # limitations under the License. import json import os +from typing import Any, Dict from unittest.mock import ANY, Mock, patch from urllib.parse import parse_qs, urlparse import pymacaroons +from twisted.test.proto_helpers import MemoryReactor + from synapse.handlers.sso import MappingException from synapse.server import HomeServer -from synapse.types import UserID +from synapse.types import JsonDict, UserID +from synapse.util import Clock from synapse.util.macaroons import get_value_from_macaroon from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock @@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider): } -async def get_json(url): +async def get_json(url: str) -> JsonDict: # Mock get_json calls to handle jwks & oidc discovery endpoints if url == WELL_KNOWN: # Minimal discovery document, as defined in OpenID.Discovery @@ -116,6 +120,8 @@ async def get_json(url): elif url == JWKS_URI: return {"keys": []} + return {} + def _key_file_path() -> str: """path to a file containing the private half of a test key""" @@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase): if not HAS_OIDC: skip = "requires OIDC" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL return config - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.http_client = Mock(spec=["get_json"]) self.http_client.get_json.side_effect = get_json self.http_client.user_agent = b"Synapse Test" @@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase): sso_handler = hs.get_sso_handler() # Mock the render error method. self.render_error = Mock(return_value=None) - sso_handler.render_error = self.render_error + sso_handler.render_error = self.render_error # type: ignore[assignment] # Reduce the number of attempts when generating MXIDs. sso_handler._MAP_USERNAME_RETRIES = 3 @@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase): return args @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_config(self): + def test_config(self) -> None: """Basic config correctly sets up the callback URL and client auth correctly.""" self.assertEqual(self.provider._callback_url, CALLBACK_URL) self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID) self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET) @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}}) - def test_discovery(self): + def test_discovery(self) -> None: """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid metadata = self.get_success(self.provider.load_metadata()) @@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.http_client.get_json.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) - def test_no_discovery(self): + def test_no_discovery(self) -> None: """When discovery is disabled, it should not try to load from discovery document.""" self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) - def test_load_jwks(self): + def test_load_jwks(self) -> None: """JWKS loading is done once (then cached) if used.""" jwks = self.get_success(self.provider.load_jwks()) self.http_client.get_json.assert_called_once_with(JWKS_URI) @@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_failure(self.provider.load_jwks(force=True), RuntimeError) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_validate_config(self): + def test_validate_config(self) -> None: """Provider metadatas are extensively validated.""" h = self.provider @@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase): force_load_metadata() @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}}) - def test_skip_verification(self): + def test_skip_verification(self) -> None: """Provider metadata validation can be disabled by config.""" with self.metadata_edit({"issuer": "http://insecure"}): # This should not throw get_awaitable_result(self.provider.load_metadata()) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_redirect_request(self): + def test_redirect_request(self) -> None: """The redirect request has the right arguments & generates a valid session cookie.""" req = Mock(spec=["cookies"]) req.cookies = [] @@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(redirect, "http://client/redirect") @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_callback_error(self): + def test_callback_error(self) -> None: """Errors from the provider returned in the callback are displayed.""" request = Mock(args={}) request.args[b"error"] = [b"invalid_client"] @@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("invalid_client", "some description") @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_callback(self): + def test_callback(self) -> None: """Code callback works and display errors if something went wrong. A lot of scenarios are tested here: @@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": username, } expected_user_id = "@%s:%s" % (username, self.hs.hostname) - self.provider._exchange_code = simple_async_mock(return_value=token) - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) - self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) + self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] + self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] + self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("mapping_error") # Handle ID token errors - self.provider._parse_id_token = simple_async_mock(raises=Exception()) + self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment] self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") @@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase): "type": "bearer", "access_token": "access_token", } - self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( @@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase): id_token = { "sid": "abcdefgh", } - self.provider._parse_id_token = simple_async_mock(return_value=id_token) - self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment] + self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] auth_handler.complete_sso_login.reset_mock() self.provider._fetch_userinfo.reset_mock() self.get_success(self.handler.handle_oidc_callback(request)) @@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase): self.render_error.assert_not_called() # Handle userinfo fetching error - self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) + self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment] self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") # Handle code exchange failure from synapse.handlers.oidc import OidcError - self.provider._exchange_code = simple_async_mock( + self.provider._exchange_code = simple_async_mock( # type: ignore[assignment] raises=OidcError("invalid_request") ) self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request") @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_callback_session(self): + def test_callback_session(self) -> None: """The callback verifies the session presence and validity""" request = Mock(spec=["args", "getCookie", "cookies"]) @@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config( {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}} ) - def test_exchange_code(self): + def test_exchange_code(self) -> None: """Code exchange behaves correctly and handles various error scenarios.""" token = {"type": "bearer"} token_json = json.dumps(token).encode("utf-8") @@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_exchange_code_jwt_key(self): + def test_exchange_code_jwt_key(self) -> None: """Test that code exchange works with a JWK client secret.""" from authlib.jose import jwt @@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_exchange_code_no_auth(self): + def test_exchange_code_no_auth(self) -> None: """Test that code exchange works with no client secret.""" token = {"type": "bearer"} self.http_client.request = simple_async_mock( @@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_extra_attributes(self): + def test_extra_attributes(self) -> None: """ Login while using a mapping provider that implements get_extra_attributes. """ @@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "foo", "phone": "1234567", } - self.provider._exchange_code = simple_async_mock(return_value=token) - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) + self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] + self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase): ) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_map_userinfo_to_user(self): + def test_map_userinfo_to_user(self) -> None: """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() - userinfo = { + userinfo: dict = { "sub": "test_user", "username": "test_user", } @@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase): ) @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}}) - def test_map_userinfo_to_existing_user(self): + def test_map_userinfo_to_existing_user(self) -> None: """Existing users can log in with OpenID Connect when allow_existing_users is True.""" store = self.hs.get_datastores().main user = UserID.from_string("@test_user:test") @@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase): ) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_map_userinfo_to_invalid_localpart(self): + def test_map_userinfo_to_invalid_localpart(self) -> None: """If the mapping provider generates an invalid localpart it should be rejected.""" self.get_success( _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) @@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_map_userinfo_to_user_retries(self): + def test_map_userinfo_to_user_retries(self) -> None: """The mapping provider can retry generating an MXID if the MXID is already in use.""" auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase): ) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_empty_localpart(self): + def test_empty_localpart(self) -> None: """Attempts to map onto an empty localpart should be rejected.""" userinfo = { "sub": "tester", @@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_null_localpart(self): + def test_null_localpart(self) -> None: """Mapping onto a null localpart via an empty OIDC attribute should be rejected""" userinfo = { "sub": "tester", @@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_attribute_requirements(self): + def test_attribute_requirements(self) -> None: """The required attributes must be met from the OIDC userinfo response.""" auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_attribute_requirements_contains(self): + def test_attribute_requirements_contains(self) -> None: """Test that auth succeeds if userinfo attribute CONTAINS required value""" auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_attribute_requirements_mismatch(self): + def test_attribute_requirements_mismatch(self) -> None: """ Test that auth fails if attributes exist but don't match, or are non-string values. @@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase): auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": "not_foobar" attribute should fail - userinfo = { + userinfo: dict = { "sub": "tester", "username": "tester", "test": "not_foobar", @@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo( handler = hs.get_oidc_handler() provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) - provider._parse_id_token = simple_async_mock(return_value=userinfo) - provider._fetch_userinfo = simple_async_mock(return_value=userinfo) + provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment] + provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] + provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] state = "state" session = handler._token_generator.generate_oidc_session_token( diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 49d832de81..d401fda938 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -124,7 +124,6 @@ class PasswordCustomAuthProvider: ("m.login.password", ("password",)): self.check_auth, } ) - pass def check_auth(self, *args): return mock_password_provider.check_auth(*args) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 6ddec9ecf1..b2ed9cbe37 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -331,11 +331,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): # Extract presence update user ID and state information into lists of tuples db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]] - presence_states = [(ps.user_id, ps.state) for ps in presence_states] + presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states] # Compare what we put into the storage with what we got out. # They should be identical. - self.assertEqual(presence_states, db_presence_states) + self.assertEqual(presence_states_compare, db_presence_states) class PresenceTimeoutTestCase(unittest.TestCase): @@ -357,6 +357,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.UNAVAILABLE) self.assertEqual(new_state.status_msg, status_msg) @@ -380,6 +381,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.BUSY) self.assertEqual(new_state.status_msg, status_msg) @@ -399,6 +401,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.OFFLINE) self.assertEqual(new_state.status_msg, status_msg) @@ -420,6 +423,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.ONLINE) self.assertEqual(new_state.status_msg, status_msg) @@ -477,6 +481,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.OFFLINE) self.assertEqual(new_state.status_msg, status_msg) @@ -653,13 +658,13 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase): self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None) def _set_presencestate_with_status_msg( - self, user_id: str, state: PresenceState, status_msg: Optional[str] + self, user_id: str, state: str, status_msg: Optional[str] ): """Set a PresenceState and status_msg and check the result. Args: user_id: User for that the status is to be set. - PresenceState: The new PresenceState. + state: The new PresenceState. status_msg: Status message that is to be set. """ self.get_success( diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 972cbac6e4..1ec105c373 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -11,14 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any, Awaitable, Callable, Dict from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + import synapse.types from synapse.api.errors import AuthError, SynapseError from synapse.rest import admin from synapse.server import HomeServer -from synapse.types import UserID +from synapse.types import JsonDict, UserID +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable @@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase): servlets = [admin.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.mock_federation = Mock() self.mock_registry = Mock() - self.query_handlers = {} + self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} - def register_query_handler(query_type, handler): + def register_query_handler( + query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] + ) -> None: self.query_handlers[query_type] = handler self.mock_registry.register_query_handler = register_query_handler @@ -47,7 +52,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) return hs - def prepare(self, reactor, clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.frank = UserID.from_string("@1234abcd:test") @@ -58,7 +63,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.handler = hs.get_profile_handler() - def test_get_my_name(self): + def test_get_my_name(self) -> None: self.get_success( self.store.set_profile_displayname(self.frank.localpart, "Frank") ) @@ -67,7 +72,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual("Frank", displayname) - def test_set_my_name(self): + def test_set_my_name(self) -> None: self.get_success( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.frank), "Frank Jr." @@ -110,7 +115,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.get_success(self.store.get_profile_displayname(self.frank.localpart)) ) - def test_set_my_name_if_disabled(self): + def test_set_my_name_if_disabled(self) -> None: self.hs.config.registration.enable_set_displayname = False # Setting displayname for the first time is allowed @@ -135,7 +140,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): SynapseError, ) - def test_set_my_name_noauth(self): + def test_set_my_name_noauth(self) -> None: self.get_failure( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.bob), "Frank Jr." @@ -143,7 +148,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): AuthError, ) - def test_get_other_name(self): + def test_get_other_name(self) -> None: self.mock_federation.make_query.return_value = make_awaitable( {"displayname": "Alice"} ) @@ -158,7 +163,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ignore_backoff=True, ) - def test_incoming_fed_query(self): + def test_incoming_fed_query(self) -> None: self.get_success(self.store.create_profile("caroline")) self.get_success(self.store.set_profile_displayname("caroline", "Caroline")) @@ -174,7 +179,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual({"displayname": "Caroline"}, response) - def test_get_my_avatar(self): + def test_get_my_avatar(self) -> None: self.get_success( self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png" @@ -184,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual("http://my.server/me.png", avatar_url) - def test_set_my_avatar(self): + def test_set_my_avatar(self) -> None: self.get_success( self.handler.set_avatar_url( self.frank, @@ -225,7 +230,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), ) - def test_set_my_avatar_if_disabled(self): + def test_set_my_avatar_if_disabled(self) -> None: self.hs.config.registration.enable_set_avatar_url = False # Setting displayname for the first time is allowed @@ -250,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): SynapseError, ) - def test_avatar_constraints_no_config(self): + def test_avatar_constraints_no_config(self) -> None: """Tests that the method to check an avatar against configured constraints skips all of its check if no constraint is configured. """ @@ -263,7 +268,13 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertTrue(res) @unittest.override_config({"max_avatar_size": 50}) - def test_avatar_constraints_missing(self): + def test_avatar_constraints_allow_empty_avatar_url(self) -> None: + """An empty avatar is always permitted.""" + res = self.get_success(self.handler.check_avatar_size_and_mime_type("")) + self.assertTrue(res) + + @unittest.override_config({"max_avatar_size": 50}) + def test_avatar_constraints_missing(self) -> None: """Tests that an avatar isn't allowed if the file at the given MXC URI couldn't be found. """ @@ -273,7 +284,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertFalse(res) @unittest.override_config({"max_avatar_size": 50}) - def test_avatar_constraints_file_size(self): + def test_avatar_constraints_file_size(self) -> None: """Tests that a file that's above the allowed file size is forbidden but one that's below it is allowed. """ @@ -295,7 +306,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertFalse(res) @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) - def test_avatar_constraint_mime_type(self): + def test_avatar_constraint_mime_type(self) -> None: """Tests that a file with an unauthorised MIME type is forbidden but one with an authorised content type is allowed. """ diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index cff07a8973..d37292ce13 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -172,6 +172,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result_room_ids = [] result_children_ids = [] for result_room in result["rooms"]: + # Ensure federation results are not leaking over the client-server API. + self.assertNotIn("allowed_room_ids", result_room) + result_room_ids.append(result_room["room_id"]) result_children_ids.append( [ diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 23941abed8..8d4404eda1 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, Optional from unittest.mock import Mock import attr +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.errors import RedirectException +from synapse.server import HomeServer +from synapse.util import Clock from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config @@ -81,10 +85,10 @@ class TestRedirectMappingProvider(TestMappingProvider): class SamlHandlerTestCase(HomeserverTestCase): - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL - saml_config = { + saml_config: Dict[str, Any] = { "sp_config": {"metadata": {}}, # Disable grandfathering. "grandfathered_mxid_source_attribute": None, @@ -98,7 +102,7 @@ class SamlHandlerTestCase(HomeserverTestCase): return config - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver() self.handler = hs.get_saml_handler() @@ -114,7 +118,7 @@ class SamlHandlerTestCase(HomeserverTestCase): elif not has_xmlsec1: skip = "Requires xmlsec1" - def test_map_saml_response_to_user(self): + def test_map_saml_response_to_user(self) -> None: """Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" # stub out the auth handler @@ -140,7 +144,7 @@ class SamlHandlerTestCase(HomeserverTestCase): ) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) - def test_map_saml_response_to_existing_user(self): + def test_map_saml_response_to_existing_user(self) -> None: """Existing users can log in with SAML account.""" store = self.hs.get_datastores().main self.get_success( @@ -186,7 +190,7 @@ class SamlHandlerTestCase(HomeserverTestCase): auth_provider_session_id=None, ) - def test_map_saml_response_to_invalid_localpart(self): + def test_map_saml_response_to_invalid_localpart(self) -> None: """If the mapping provider generates an invalid localpart it should be rejected.""" # stub out the auth handler @@ -207,7 +211,7 @@ class SamlHandlerTestCase(HomeserverTestCase): ) auth_handler.complete_sso_login.assert_not_called() - def test_map_saml_response_to_user_retries(self): + def test_map_saml_response_to_user_retries(self) -> None: """The mapping provider can retry generating an MXID if the MXID is already in use.""" # stub out the auth handler and error renderer @@ -271,7 +275,7 @@ class SamlHandlerTestCase(HomeserverTestCase): } } ) - def test_map_saml_response_redirect(self): + def test_map_saml_response_redirect(self) -> None: """Test a mapping provider that raises a RedirectException""" saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) @@ -292,7 +296,7 @@ class SamlHandlerTestCase(HomeserverTestCase): }, } ) - def test_attribute_requirements(self): + def test_attribute_requirements(self) -> None: """The required attributes must be met from the SAML response.""" # stub out the auth handler diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index f91a80b9fa..ffd5c4cb93 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -18,11 +18,14 @@ from typing import Dict from unittest.mock import ANY, Mock, call from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource from synapse.api.errors import AuthError from synapse.federation.transport.server import TransportLayerServer -from synapse.types import UserID, create_requester +from synapse.server import HomeServer +from synapse.types import JsonDict, UserID, create_requester +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable @@ -42,7 +45,9 @@ ROOM_ID = "a-room" OTHER_ROOM_ID = "another-room" -def _expect_edu_transaction(edu_type, content, origin="test"): +def _expect_edu_transaction( + edu_type: str, content: JsonDict, origin: str = "test" +) -> JsonDict: return { "origin": origin, "origin_server_ts": 1000000, @@ -51,12 +56,12 @@ def _expect_edu_transaction(edu_type, content, origin="test"): } -def _make_edu_transaction_json(edu_type, content): +def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes: return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8") class TypingNotificationsTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # we mock out the keyring so as to skip the authentication check on the # federation API call. mock_keyring = Mock(spec=["verify_json_for_server"]) @@ -83,7 +88,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): d["/_matrix/federation"] = TransportLayerServer(self.hs) return d - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: mock_notifier = hs.get_notifier() self.on_new_event = mock_notifier.on_new_event @@ -111,24 +116,24 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.room_members = [] - async def check_user_in_room(room_id, user_id): + async def check_user_in_room(room_id: str, user_id: str) -> None: if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") return None hs.get_auth().check_user_in_room = check_user_in_room - async def check_host_in_room(room_id, server_name): + async def check_host_in_room(room_id: str, server_name: str) -> bool: return room_id == ROOM_ID hs.get_event_auth_handler().check_host_in_room = check_host_in_room - def get_joined_hosts_for_room(room_id): + def get_joined_hosts_for_room(room_id: str): return {member.domain for member in self.room_members} self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room - async def get_users_in_room(room_id): + async def get_users_in_room(room_id: str): return {str(u) for u in self.room_members} self.datastore.get_users_in_room = get_users_in_room @@ -153,7 +158,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): lambda *args, **kwargs: make_awaitable(None) ) - def test_started_typing_local(self): + def test_started_typing_local(self) -> None: self.room_members = [U_APPLE, U_BANANA] self.assertEqual(self.event_source.get_current_key(), 0) @@ -187,7 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) @override_config({"send_federation": True}) - def test_started_typing_remote_send(self): + def test_started_typing_remote_send(self) -> None: self.room_members = [U_APPLE, U_ONION] self.get_success( @@ -217,7 +222,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): try_trailing_slash_on_400=True, ) - def test_started_typing_remote_recv(self): + def test_started_typing_remote_recv(self) -> None: self.room_members = [U_APPLE, U_ONION] self.assertEqual(self.event_source.get_current_key(), 0) @@ -256,7 +261,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) - def test_started_typing_remote_recv_not_in_room(self): + def test_started_typing_remote_recv_not_in_room(self) -> None: self.room_members = [U_APPLE, U_ONION] self.assertEqual(self.event_source.get_current_key(), 0) @@ -292,7 +297,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEqual(events[1], 0) @override_config({"send_federation": True}) - def test_stopped_typing(self): + def test_stopped_typing(self) -> None: self.room_members = [U_APPLE, U_BANANA, U_ONION] # Gut-wrenching @@ -343,7 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) - def test_typing_timeout(self): + def test_typing_timeout(self) -> None: self.room_members = [U_APPLE, U_BANANA] self.assertEqual(self.event_source.get_current_key(), 0) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index c3f20f9692..10dd94b549 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -86,6 +86,16 @@ class ModuleApiTestCase(HomeserverTestCase): displayname = self.get_success(self.store.get_profile_displayname("bob")) self.assertEqual(displayname, "Bobberino") + def test_can_register_admin_user(self): + user_id = self.get_success( + self.register_user( + "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True + ) + ) + found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) + self.assertEqual(found_user.user_id.to_string(), user_id) + self.assertIdentical(found_user.is_admin, True) + def test_get_userinfo_by_id(self): user_id = self.register_user("alice", "1234") found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index c284beb37c..ba158f5d93 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -11,14 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, List, Optional, Tuple from unittest.mock import Mock from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable from synapse.push import PusherConfigException -from synapse.rest.client import login, receipts, room +from synapse.rest.client import login, push_rule, receipts, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests.unittest import HomeserverTestCase, override_config @@ -29,22 +34,23 @@ class HTTPPusherTests(HomeserverTestCase): room.register_servlets, login.register_servlets, receipts.register_servlets, + push_rule.register_servlets, ] user_id = True hijack_auth = False - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["start_pushers"] = True return config - def make_homeserver(self, reactor, clock): - self.push_attempts = [] + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.push_attempts: List[Tuple[Deferred, str, dict]] = [] m = Mock() def post_json_get_json(url, body): - d = Deferred() + d: Deferred = Deferred() self.push_attempts.append((d, url, body)) return make_deferred_yieldable(d) @@ -54,7 +60,7 @@ class HTTPPusherTests(HomeserverTestCase): return hs - def test_invalid_configuration(self): + def test_invalid_configuration(self) -> None: """Invalid push configurations should be rejected.""" # Register the user who gets notified user_id = self.register_user("user", "pass") @@ -66,7 +72,7 @@ class HTTPPusherTests(HomeserverTestCase): ) token_id = user_tuple.token_id - def test_data(data): + def test_data(data: Optional[JsonDict]) -> None: self.get_failure( self.hs.get_pusherpool().add_pusher( user_id=user_id, @@ -93,7 +99,7 @@ class HTTPPusherTests(HomeserverTestCase): # A url with an incorrect path isn't accepted. test_data({"url": "http://example.com/foo"}) - def test_sends_http(self): + def test_sends_http(self) -> None: """ The HTTP pusher will send pushes for each message to a HTTP endpoint when configured to do so. @@ -198,7 +204,7 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) - def test_sends_high_priority_for_encrypted(self): + def test_sends_high_priority_for_encrypted(self) -> None: """ The HTTP pusher will send pushes at high priority if they correspond to an encrypted message. @@ -319,7 +325,7 @@ class HTTPPusherTests(HomeserverTestCase): ) self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") - def test_sends_high_priority_for_one_to_one_only(self): + def test_sends_high_priority_for_one_to_one_only(self) -> None: """ The HTTP pusher will send pushes at high priority if they correspond to a message in a one-to-one room. @@ -402,7 +408,7 @@ class HTTPPusherTests(HomeserverTestCase): # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") - def test_sends_high_priority_for_mention(self): + def test_sends_high_priority_for_mention(self) -> None: """ The HTTP pusher will send pushes at high priority if they correspond to a message containing the user's display name. @@ -478,7 +484,7 @@ class HTTPPusherTests(HomeserverTestCase): # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") - def test_sends_high_priority_for_atroom(self): + def test_sends_high_priority_for_atroom(self) -> None: """ The HTTP pusher will send pushes at high priority if they correspond to a message that contains @room. @@ -561,7 +567,7 @@ class HTTPPusherTests(HomeserverTestCase): # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") - def test_push_unread_count_group_by_room(self): + def test_push_unread_count_group_by_room(self) -> None: """ The HTTP pusher will group unread count by number of unread rooms. """ @@ -574,7 +580,7 @@ class HTTPPusherTests(HomeserverTestCase): self._check_push_attempt(6, 1) @override_config({"push": {"group_unread_count_by_room": False}}) - def test_push_unread_count_message_count(self): + def test_push_unread_count_message_count(self) -> None: """ The HTTP pusher will send the total unread message count. """ @@ -587,7 +593,7 @@ class HTTPPusherTests(HomeserverTestCase): # last read receipt self._check_push_attempt(6, 3) - def _test_push_unread_count(self): + def _test_push_unread_count(self) -> None: """ Tests that the correct unread count appears in sent push notifications @@ -679,7 +685,7 @@ class HTTPPusherTests(HomeserverTestCase): self.helper.send(room_id, body="HELLO???", tok=other_access_token) - def _advance_time_and_make_push_succeed(self, expected_push_attempts): + def _advance_time_and_make_push_succeed(self, expected_push_attempts: int) -> None: self.pump() self.push_attempts[expected_push_attempts - 1][0].callback({}) @@ -706,7 +712,9 @@ class HTTPPusherTests(HomeserverTestCase): expected_unread_count_last_push, ) - def _send_read_request(self, access_token, message_event_id, room_id): + def _send_read_request( + self, access_token: str, message_event_id: str, room_id: str + ) -> None: # Now set the user's read receipt position to the first event # # This will actually trigger a new notification to be sent out so that @@ -719,3 +727,67 @@ class HTTPPusherTests(HomeserverTestCase): access_token=access_token, ) self.assertEqual(channel.code, 200, channel.json_body) + + def _make_user_with_pusher(self, username: str) -> Tuple[str, str]: + user_id = self.register_user(username, "pass") + access_token = self.login(username, "pass") + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastores().main.get_user_by_access_token(access_token) + ) + token_id = user_tuple.token_id + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "http://example.com/_matrix/push/v1/notify"}, + ) + ) + + return user_id, access_token + + def test_dont_notify_rule_overrides_message(self) -> None: + """ + The override push rule will suppress notification + """ + + user_id, access_token = self._make_user_with_pusher("user") + other_user_id, other_access_token = self._make_user_with_pusher("otheruser") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # Disable user notifications for this room -> user + body = { + "conditions": [{"kind": "event_match", "key": "room_id", "pattern": room}], + "actions": ["dont_notify"], + } + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend", + body, + access_token=access_token, + ) + self.assertEqual(channel.code, 200) + + # Check we start with no pushes + self.assertEqual(len(self.push_attempts), 0) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # The other user sends a message (ignored by dont_notify push rule set above) + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 0) + + # The user sends a message back (sends a notification) + self.helper.send(room, body="Hello", tok=access_token) + self.assertEqual(len(self.push_attempts), 1) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 3849beb9d6..5dba187076 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Dict, Optional, Union import frozendict @@ -20,12 +20,13 @@ from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent from synapse.push import push_rule_evaluator from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent +from synapse.types import JsonDict from tests import unittest class PushRuleEvaluatorTestCase(unittest.TestCase): - def _get_evaluator(self, content): + def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent: event = FrozenEvent( { "event_id": "$event_id", @@ -39,12 +40,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): ) room_member_count = 0 sender_power_level = 0 - power_levels = {} + power_levels: Dict[str, Union[int, Dict[str, int]]] = {} return PushRuleEvaluatorForEvent( event, room_member_count, sender_power_level, power_levels ) - def test_display_name(self): + def test_display_name(self) -> None: """Check for a matching display name in the body of the event.""" evaluator = self._get_evaluator({"body": "foo bar baz"}) @@ -71,20 +72,20 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) def _assert_matches( - self, condition: Dict[str, Any], content: Dict[str, Any], msg=None + self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None ) -> None: evaluator = self._get_evaluator(content) self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg) def _assert_not_matches( - self, condition: Dict[str, Any], content: Dict[str, Any], msg=None + self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None ) -> None: evaluator = self._get_evaluator(content) self.assertFalse( evaluator.matches(condition, "@user:test", "display_name"), msg ) - def test_event_match_body(self): + def test_event_match_body(self) -> None: """Check that event_match conditions on content.body work as expected""" # if the key is `content.body`, the pattern matches substrings. @@ -165,7 +166,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): r"? after \ should match any character", ) - def test_event_match_non_body(self): + def test_event_match_non_body(self) -> None: """Check that event_match conditions on other keys work as expected""" # if the key is anything other than 'content.body', the pattern must match the @@ -241,7 +242,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "pattern should not match before a newline", ) - def test_no_body(self): + def test_no_body(self) -> None: """Not having a body shouldn't break the evaluator.""" evaluator = self._get_evaluator({}) @@ -250,7 +251,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): } self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) - def test_invalid_body(self): + def test_invalid_body(self) -> None: """A non-string body should not break the evaluator.""" condition = { "kind": "contains_display_name", @@ -260,7 +261,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): evaluator = self._get_evaluator({"body": body}) self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) - def test_tweaks_for_actions(self): + def test_tweaks_for_actions(self) -> None: """ This tests the behaviour of tweaks_for_actions. """ diff --git a/tests/replication/_base.py b/tests/replication/_base.py index a7a05a564f..9c5df266bd 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -251,7 +251,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.connect_any_redis_attempts, ) - self.hs.get_tcp_replication().start_replication(self.hs) + self.hs.get_replication_command_handler().start_replication(self.hs) # When we see a connection attempt to the master replication listener we # automatically set up the connection. This is so that tests don't @@ -375,7 +375,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): ) if worker_hs.config.redis.redis_enabled: - worker_hs.get_tcp_replication().start_replication(worker_hs) + worker_hs.get_replication_command_handler().start_replication(worker_hs) return worker_hs diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index f9d5da723c..641a94133b 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -420,7 +420,7 @@ class EventsStreamTestCase(BaseStreamTestCase): # Manually send an old RDATA command, which should get dropped. This # re-uses the row from above, but with an earlier stream token. - self.hs.get_tcp_replication().send_command( + self.hs.get_replication_command_handler().send_command( RdataCommand("events", "master", 1, row) ) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index 3ff5afc6e5..9a229dd23f 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -118,7 +118,7 @@ class TypingStreamTestCase(BaseStreamTestCase): # Reset the typing handler self.hs.get_replication_streams()["typing"].last_token = 0 - self.hs.get_tcp_replication()._streams["typing"].last_token = 0 + self.hs.get_replication_command_handler()._streams["typing"].last_token = 0 typing._latest_room_serial = 0 typing._typing_stream_change_cache = StreamChangeCache( "TypingStreamChangeCache", typing._latest_room_serial diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 1b6a4bf4b0..26b8bd512a 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -48,7 +48,7 @@ class FederationAckTestCase(HomeserverTestCase): transport, rather than assuming that the implementation has a ReplicationCommandHandler. """ - rch = self.hs.get_tcp_replication() + rch = self.hs.get_replication_command_handler() # wire up the ReplicationCommandHandler to a mock connection, which needs # to implement IReplicationConnection. (Note that Mock doesn't understand diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index fb36aa9940..6cf56b1e35 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -39,6 +39,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") + self.updater = BackgroundUpdater(hs, self.store.db_pool) @parameterized.expand( [ @@ -135,10 +136,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): """Test the status API works with a background update.""" # Create a new background update - self._register_bg_update() self.store.db_pool.updates.start_doing_background_updates() + self.reactor.pump([1.0, 1.0, 1.0]) channel = self.make_request( @@ -155,10 +156,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + self.updater.default_background_batch_size ), } }, @@ -210,10 +211,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + self.updater.default_background_batch_size ), } }, @@ -239,10 +240,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.1, "total_duration_ms": 1000.0, "total_item_count": ( - BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + self.updater.default_background_batch_size ), } }, @@ -278,11 +279,9 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.001, + "average_items_per_ms": 0.05263157894736842, "total_duration_ms": 2000.0, - "total_item_count": ( - 2 * BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE - ), + "total_item_count": (110), } }, "enabled": True, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index a60ea0a563..bef911d5df 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1050,6 +1050,25 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self._is_erased("@user:test", True) + @override_config({"max_avatar_size": 1234}) + def test_deactivate_user_erase_true_avatar_nonnull_but_empty(self) -> None: + """Check we can erase a user whose avatar is the empty string. + + Reproduces #12257. + """ + # Patch `self.other_user` to have an empty string as their avatar. + self.get_success(self.store.set_profile_avatar_url("user", "")) + + # Check we can still erase them. + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"erase": True}, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self._is_erased("@user:test", True) + def test_deactivate_user_erase_false(self) -> None: """ Test deactivating a user and set `erase` to `false` diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index def836054d..27946febff 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -31,7 +31,7 @@ from synapse.rest import admin from synapse.rest.client import account, login, register, room from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from synapse.server import HomeServer -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util import Clock from tests import unittest @@ -1222,6 +1222,62 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_failures=[users[2]], ) + @unittest.override_config( + { + "use_account_validity_in_account_status": True, + } + ) + def test_no_account_validity(self) -> None: + """Tests that if we decide to include account validity in the response but no + account validity 'is_user_expired' callback is provided, we default to marking all + users as not expired. + """ + user = self.register_user("someuser", "password") + + self._test_status( + users=[user], + expected_statuses={ + user: { + "exists": True, + "deactivated": False, + "org.matrix.expired": False, + }, + }, + expected_failures=[], + ) + + @unittest.override_config( + { + "use_account_validity_in_account_status": True, + } + ) + def test_account_validity_expired(self) -> None: + """Test that if we decide to include account validity in the response and the user + is expired, we return the correct info. + """ + user = self.register_user("someuser", "password") + + async def is_expired(user_id: str) -> bool: + # We can't blindly say everyone is expired, otherwise the request to get the + # account status will fail. + return UserID.from_string(user_id).localpart == "someuser" + + self.hs.get_account_validity_handler()._is_user_expired_callbacks.append( + is_expired + ) + + self._test_status( + users=[user], + expected_statuses={ + user: { + "exists": True, + "deactivated": False, + "org.matrix.expired": True, + }, + }, + expected_failures=[], + ) + def _test_status( self, users: Optional[List[str]], diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_mutual_rooms.py index 3818b7b14b..7b7d283bb6 100644 --- a/tests/rest/client/test_shared_rooms.py +++ b/tests/rest/client/test_mutual_rooms.py @@ -14,7 +14,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.rest.client import login, room, shared_rooms +from synapse.rest.client import login, mutual_rooms, room from synapse.server import HomeServer from synapse.util import Clock @@ -22,16 +22,16 @@ from tests import unittest from tests.server import FakeChannel -class UserSharedRoomsTest(unittest.HomeserverTestCase): +class UserMutualRoomsTest(unittest.HomeserverTestCase): """ - Tests the UserSharedRoomsServlet. + Tests the UserMutualRoomsServlet. """ servlets = [ login.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, - shared_rooms.register_servlets, + mutual_rooms.register_servlets, ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: @@ -43,10 +43,10 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.store = hs.get_datastores().main self.handler = hs.get_user_directory_handler() - def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel: + def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel: return self.make_request( "GET", - "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" + "/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms/%s" % other_user, access_token=token, ) @@ -56,14 +56,14 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): A room should show up in the shared list of rooms between two users if it is public. """ - self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) + self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=True) def test_shared_room_list_private(self) -> None: """ A room should show up in the shared list of rooms between two users if it is private. """ - self._check_shared_rooms_with( + self._check_mutual_rooms_with( room_one_is_public=False, room_two_is_public=False ) @@ -72,9 +72,9 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): The shared room list between two users should contain both public and private rooms. """ - self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False) + self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=False) - def _check_shared_rooms_with( + def _check_mutual_rooms_with( self, room_one_is_public: bool, room_two_is_public: bool ) -> None: """Checks that shared public or private rooms between two users appear in @@ -94,7 +94,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): # Check shared rooms from user1's perspective. # We should see the one room in common - channel = self._get_shared_rooms(u1_token, u2) + channel = self._get_mutual_rooms(u1_token, u2) self.assertEqual(200, channel.code, channel.result) self.assertEqual(len(channel.json_body["joined"]), 1) self.assertEqual(channel.json_body["joined"][0], room_id_one) @@ -107,7 +107,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.join(room_id_two, user=u2, tok=u2_token) # Check shared rooms again. We should now see both rooms. - channel = self._get_shared_rooms(u1_token, u2) + channel = self._get_mutual_rooms(u1_token, u2) self.assertEqual(200, channel.code, channel.result) self.assertEqual(len(channel.json_body["joined"]), 2) for room_id_id in channel.json_body["joined"]: @@ -128,7 +128,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.join(room, user=u2, tok=u2_token) # Assert user directory is not empty - channel = self._get_shared_rooms(u1_token, u2) + channel = self._get_mutual_rooms(u1_token, u2) self.assertEqual(200, channel.code, channel.result) self.assertEqual(len(channel.json_body["joined"]), 1) self.assertEqual(channel.json_body["joined"][0], room) @@ -136,11 +136,11 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.leave(room, user=u1, tok=u1_token) # Check user1's view of shared rooms with user2 - channel = self._get_shared_rooms(u1_token, u2) + channel = self._get_mutual_rooms(u1_token, u2) self.assertEqual(200, channel.code, channel.result) self.assertEqual(len(channel.json_body["joined"]), 0) # Check user2's view of shared rooms with user1 - channel = self._get_shared_rooms(u2_token, u1) + channel = self._get_mutual_rooms(u2_token, u1) self.assertEqual(200, channel.code, channel.result) self.assertEqual(len(channel.json_body["joined"]), 0) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 709f851a38..fe97a0b3dd 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -15,17 +15,16 @@ import itertools import urllib.parse -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from unittest.mock import patch from twisted.test.proto_helpers import MemoryReactor -from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.constants import AccountDataTypes, EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, register, relations, room, sync from synapse.server import HomeServer -from synapse.storage.relations import RelationPaginationToken -from synapse.types import JsonDict, StreamToken +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest @@ -80,6 +79,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): content: Optional[dict] = None, access_token: Optional[str] = None, parent_id: Optional[str] = None, + expected_response_code: int = 200, ) -> FakeChannel: """Helper function to send a relation pointing at `self.parent_id` @@ -116,16 +116,60 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): content, access_token=access_token, ) + self.assertEqual(expected_response_code, channel.code, channel.json_body) return channel + def _get_related_events(self) -> List[str]: + """ + Requests /relations on the parent ID and returns a list of event IDs. + """ + # Request the relations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + return [ev["event_id"] for ev in channel.json_body["chunk"]] + + def _get_bundled_aggregations(self) -> JsonDict: + """ + Requests /event on the parent ID and returns the m.relations field (from unsigned), if it exists. + """ + # Fetch the bundled aggregations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + return channel.json_body["unsigned"].get("m.relations", {}) + + def _get_aggregations(self) -> List[JsonDict]: + """Request /aggregations on the parent ID and includes the returned chunk.""" + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + return channel.json_body["chunk"] + + def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. + """ + for event in events: + if event["event_id"] == self.parent_id: + return event + + raise AssertionError(f"Event {self.parent_id} not found in chunk") + class RelationsTestCase(BaseRelationsTestCase): def test_send_relation(self) -> None: """Tests that sending a relation works.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEqual(200, channel.code, channel.json_body) - event_id = channel.json_body["event_id"] channel = self.make_request( @@ -152,13 +196,13 @@ class RelationsTestCase(BaseRelationsTestCase): def test_deny_invalid_event(self) -> None: """Test that we deny relations on non-existant events""" - channel = self._send_relation( + self._send_relation( RelationTypes.ANNOTATION, EventTypes.Message, parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, + expected_response_code=400, ) - self.assertEqual(400, channel.code, channel.json_body) # Unless that event is referenced from another event! self.get_success( @@ -172,13 +216,12 @@ class RelationsTestCase(BaseRelationsTestCase): desc="test_deny_invalid_event", ) ) - channel = self._send_relation( + self._send_relation( RelationTypes.THREAD, EventTypes.Message, parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) def test_deny_invalid_room(self) -> None: """Test that we deny relations on non-existant events""" @@ -188,18 +231,20 @@ class RelationsTestCase(BaseRelationsTestCase): parent_id = res["event_id"] # Attempt to send an annotation to that event. - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A" + self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + parent_id=parent_id, + key="A", + expected_response_code=400, ) - self.assertEqual(400, channel.code, channel.json_body) def test_deny_double_react(self) -> None: """Test that we deny relations on membership events""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(400, channel.code, channel.json_body) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", expected_response_code=400 + ) def test_deny_forked_thread(self) -> None: """It is invalid to start a thread off a thread.""" @@ -209,386 +254,24 @@ class RelationsTestCase(BaseRelationsTestCase): content={"msgtype": "m.text", "body": "foo"}, parent_id=self.parent_id, ) - self.assertEqual(200, channel.code, channel.json_body) parent_id = channel.json_body["event_id"] - channel = self._send_relation( + self._send_relation( RelationTypes.THREAD, "m.room.message", content={"msgtype": "m.text", "body": "foo"}, parent_id=parent_id, - ) - self.assertEqual(400, channel.code, channel.json_body) - - def test_basic_paginate_relations(self) -> None: - """Tests that calling pagination API correctly the latest relations.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - first_annotation_id = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEqual(200, channel.code, channel.json_body) - second_annotation_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # We expect to get back a single pagination result, which is the latest - # full relation event we sent above. - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - self.assert_dict( - { - "event_id": second_annotation_id, - "sender": self.user_id, - "type": "m.reaction", - }, - channel.json_body["chunk"][0], - ) - - # We also expect to get the original event (the id of which is self.parent_id) - self.assertEqual( - channel.json_body["original_event"]["event_id"], self.parent_id - ) - - # Make sure next_batch has something in it that looks like it could be a - # valid token. - self.assertIsInstance( - channel.json_body.get("next_batch"), str, channel.json_body - ) - - # Request the relations again, but with a different direction. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations" - f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # We expect to get back a single pagination result, which is the earliest - # full relation event we sent above. - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - self.assert_dict( - { - "event_id": first_annotation_id, - "sender": self.user_id, - "type": "m.reaction", - }, - channel.json_body["chunk"][0], - ) - - def _stream_token_to_relation_token(self, token: str) -> str: - """Convert a StreamToken into a legacy token (RelationPaginationToken).""" - room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key - return self.get_success( - RelationPaginationToken( - topological=room_key.topological, stream=room_key.stream - ).to_string(self.store) - ) - - def test_repeated_paginate_relations(self) -> None: - """Test that if we paginate using a limit and tokens then we get the - expected events. - """ - - expected_event_ids = [] - for idx in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) - ) - self.assertEqual(200, channel.code, channel.json_body) - expected_event_ids.append(channel.json_body["event_id"]) - - prev_token = "" - found_event_ids: List[str] = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - - # Reset and try again, but convert the tokens to the legacy format. - prev_token = "" - found_event_ids = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + self._stream_token_to_relation_token(prev_token) - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - - def test_pagination_from_sync_and_messages(self) -> None: - """Pagination tokens from /sync and /messages can be used to paginate /relations.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") - self.assertEqual(200, channel.code, channel.json_body) - annotation_id = channel.json_body["event_id"] - # Send an event after the relation events. - self.helper.send(self.room, body="Latest event", tok=self.user_token) - - # Request /sync, limiting it such that only the latest event is returned - # (and not the relation). - filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}') - channel = self.make_request( - "GET", f"/sync?filter={filter}", access_token=self.user_token - ) - self.assertEqual(200, channel.code, channel.json_body) - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - sync_prev_batch = room_timeline["prev_batch"] - self.assertIsNotNone(sync_prev_batch) - # Ensure the relation event is not in the batch returned from /sync. - self.assertNotIn( - annotation_id, [ev["event_id"] for ev in room_timeline["events"]] - ) - - # Request /messages, limiting it such that only the latest event is - # returned (and not the relation). - channel = self.make_request( - "GET", - f"/rooms/{self.room}/messages?dir=b&limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - messages_end = channel.json_body["end"] - self.assertIsNotNone(messages_end) - # Ensure the relation event is not in the chunk returned from /messages. - self.assertNotIn( - annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + expected_response_code=400, ) - # Request /relations with the pagination tokens received from both the - # /sync and /messages responses above, in turn. - # - # This is a tiny bit silly since the client wouldn't know the parent ID - # from the requests above; consider the parent ID to be known from a - # previous /sync. - for from_token in (sync_prev_batch, messages_end): - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # The relation should be in the returned chunk. - self.assertIn( - annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] - ) - - def test_aggregation_pagination_groups(self) -> None: - """Test that we can paginate annotation groups correctly.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} - for key in itertools.chain.from_iterable( - itertools.repeat(key, num) for key, num in sent_groups.items() - ): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key=key, - access_token=access_tokens[idx], - ) - self.assertEqual(200, channel.code, channel.json_body) - - idx += 1 - idx %= len(access_tokens) - - prev_token: Optional[str] = None - found_groups: Dict[str, int] = {} - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - for groups in channel.json_body["chunk"]: - # We only expect reactions - self.assertEqual(groups["type"], "m.reaction", channel.json_body) - - # We should only see each key once - self.assertNotIn(groups["key"], found_groups, channel.json_body) - - found_groups[groups["key"]] = groups["count"] - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - self.assertEqual(sent_groups, found_groups) - - def test_aggregation_pagination_within_group(self) -> None: - """Test that we can paginate within an annotation group.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - expected_event_ids = [] - for _ in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="👍", - access_token=access_tokens[idx], - ) - self.assertEqual(200, channel.code, channel.json_body) - expected_event_ids.append(channel.json_body["event_id"]) - - idx += 1 - - # Also send a different type of reaction so that we test we don't see it - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEqual(200, channel.code, channel.json_body) - - prev_token = "" - found_event_ids: List[str] = [] - encoded_key = urllib.parse.quote_plus("👍".encode()) - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}" - f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" - f"/m.reaction/{encoded_key}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - - # Reset and try again, but convert the tokens to the legacy format. - prev_token = "" - found_event_ids = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + self._stream_token_to_relation_token(prev_token) - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}" - f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" - f"/m.reaction/{encoded_key}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - def test_aggregation(self) -> None: """Test that annotations get correctly aggregated.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation( + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEqual(200, channel.code, channel.json_body) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") channel = self.make_request( "GET", @@ -618,220 +301,6 @@ class RelationsTestCase(BaseRelationsTestCase): ) self.assertEqual(400, channel.code, channel.json_body) - @unittest.override_config( - {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}} - ) - def test_bundled_aggregations(self) -> None: - """ - Test that annotations, references, and threads get correctly bundled. - - Note that this doesn't test against /relations since only thread relations - get bundled via that API. See test_aggregation_get_event_for_thread. - - See test_edit for a similar test for edits. - """ - # Setup by sending a variety of relations. - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - reply_1 = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - reply_2 = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - thread_2 = channel.json_body["event_id"] - - def assert_bundle(event_json: JsonDict) -> None: - """Assert the expected values of the bundled aggregations.""" - relations_dict = event_json["unsigned"].get("m.relations") - - # Ensure the fields are as expected. - self.assertCountEqual( - relations_dict.keys(), - ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.THREAD, - ), - ) - - # Check the values of each field. - self.assertEqual( - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - relations_dict[RelationTypes.ANNOTATION], - ) - - self.assertEqual( - {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, - relations_dict[RelationTypes.REFERENCE], - ) - - self.assertEqual( - 2, - relations_dict[RelationTypes.THREAD].get("count"), - ) - self.assertTrue( - relations_dict[RelationTypes.THREAD].get("current_user_participated") - ) - # The latest thread event has some fields that don't matter. - self.assert_dict( - { - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } - }, - "event_id": thread_2, - "room_id": self.room, - "sender": self.user_id, - "type": "m.room.test", - "user_id": self.user_id, - }, - relations_dict[RelationTypes.THREAD].get("latest_event"), - ) - - # Request the event directly. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body) - - # Request the room messages. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/messages?dir=b", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) - - # Request the room context. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/context/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["event"]) - - # Request sync. - channel = self.make_request("GET", "/sync", access_token=self.user_token) - self.assertEqual(200, channel.code, channel.json_body) - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - self.assertTrue(room_timeline["limited"]) - assert_bundle(self._find_event_in_chunk(room_timeline["events"])) - - # Request search. - channel = self.make_request( - "POST", - "/search", - # Search term matches the parent message. - content={"search_categories": {"room_events": {"search_term": "Hi"}}}, - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - chunk = [ - result["result"] - for result in channel.json_body["search_categories"]["room_events"][ - "results" - ] - ] - assert_bundle(self._find_event_in_chunk(chunk)) - - def test_aggregation_get_event_for_annotation(self) -> None: - """Test that annotations do not get bundled aggregations included - when directly requested. - """ - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - annotation_id = channel.json_body["event_id"] - - # Annotate the annotation. - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id - ) - self.assertEqual(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{annotation_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) - - def test_aggregation_get_event_for_thread(self) -> None: - """Test that threads get bundled aggregations included when directly requested.""" - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) - thread_id = channel.json_body["event_id"] - - # Annotate the annotation. - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id - ) - self.assertEqual(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{thread_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual( - channel.json_body["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, - }, - ) - - # It should also be included when the entire thread is requested. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(len(channel.json_body["chunk"]), 1) - - thread_message = channel.json_body["chunk"][0] - self.assertEqual( - thread_message["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, - }, - ) - - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. @@ -953,8 +422,6 @@ class RelationsTestCase(BaseRelationsTestCase): "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) - self.assertEqual(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] def assert_bundle(event_json: JsonDict) -> None: @@ -1030,7 +497,7 @@ class RelationsTestCase(BaseRelationsTestCase): shouldn't be allowed, are correctly handled. """ - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={ @@ -1039,7 +506,6 @@ class RelationsTestCase(BaseRelationsTestCase): "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) - self.assertEqual(200, channel.code, channel.json_body) new_body = {"msgtype": "m.text", "body": "I've been edited!"} channel = self._send_relation( @@ -1047,11 +513,9 @@ class RelationsTestCase(BaseRelationsTestCase): "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) - self.assertEqual(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message.WRONG_TYPE", content={ @@ -1060,7 +524,6 @@ class RelationsTestCase(BaseRelationsTestCase): "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, }, ) - self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -1091,7 +554,6 @@ class RelationsTestCase(BaseRelationsTestCase): "m.room.message", content={"msgtype": "m.text", "body": "A reply!"}, ) - self.assertEqual(200, channel.code, channel.json_body) reply = channel.json_body["event_id"] new_body = {"msgtype": "m.text", "body": "I've been edited!"} @@ -1101,8 +563,6 @@ class RelationsTestCase(BaseRelationsTestCase): content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, parent_id=reply, ) - self.assertEqual(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] channel = self.make_request( @@ -1138,7 +598,6 @@ class RelationsTestCase(BaseRelationsTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_edit_thread(self) -> None: """Test that editing a thread works.""" @@ -1148,17 +607,15 @@ class RelationsTestCase(BaseRelationsTestCase): "m.room.message", content={"msgtype": "m.text", "body": "A threaded reply!"}, ) - self.assertEqual(200, channel.code, channel.json_body) threaded_event_id = channel.json_body["event_id"] new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, parent_id=threaded_event_id, ) - self.assertEqual(200, channel.code, channel.json_body) # Fetch the thread root, to get the bundled aggregation for the thread. channel = self.make_request( @@ -1190,11 +647,10 @@ class RelationsTestCase(BaseRelationsTestCase): "m.new_content": new_body, }, ) - self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] # Edit the edit event. - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={ @@ -1204,7 +660,6 @@ class RelationsTestCase(BaseRelationsTestCase): }, parent_id=edit_event_id, ) - self.assertEqual(200, channel.code, channel.json_body) # Request the original event. channel = self.make_request( @@ -1231,7 +686,6 @@ class RelationsTestCase(BaseRelationsTestCase): def test_unknown_relations(self) -> None: """Unknown relations should be accepted.""" channel = self._send_relation("m.relation.test", "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) event_id = channel.json_body["event_id"] channel = self.make_request( @@ -1272,28 +726,15 @@ class RelationsTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) - def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: - """ - Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. - """ - for event in events: - if event["event_id"] == self.parent_id: - return event - - raise AssertionError(f"Event {self.parent_id} not found in chunk") - def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEqual(200, channel.code, channel.json_body) annotation_event_id_good = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A") - self.assertEqual(200, channel.code, channel.json_body) annotation_event_id_bad = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) thread_event_id = channel.json_body["event_id"] # Clean-up the table as if the inserts did not happen during event creation. @@ -1345,8 +786,638 @@ class RelationsTestCase(BaseRelationsTestCase): ) +class RelationPaginationTestCase(BaseRelationsTestCase): + def test_basic_paginate_relations(self) -> None: + """Tests that calling pagination API correctly the latest relations.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + first_annotation_id = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + second_annotation_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # We expect to get back a single pagination result, which is the latest + # full relation event we sent above. + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + { + "event_id": second_annotation_id, + "sender": self.user_id, + "type": "m.reaction", + }, + channel.json_body["chunk"][0], + ) + + # We also expect to get the original event (the id of which is self.parent_id) + self.assertEqual( + channel.json_body["original_event"]["event_id"], self.parent_id + ) + + # Make sure next_batch has something in it that looks like it could be a + # valid token. + self.assertIsInstance( + channel.json_body.get("next_batch"), str, channel.json_body + ) + + # Request the relations again, but with a different direction. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations" + f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # We expect to get back a single pagination result, which is the earliest + # full relation event we sent above. + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + { + "event_id": first_annotation_id, + "sender": self.user_id, + "type": "m.reaction", + }, + channel.json_body["chunk"][0], + ) + + def test_repeated_paginate_relations(self) -> None: + """Test that if we paginate using a limit and tokens then we get the + expected events. + """ + + expected_event_ids = [] + for idx in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) + ) + expected_event_ids.append(channel.json_body["event_id"]) + + prev_token = "" + found_event_ids: List[str] = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") + + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEqual(found_event_ids, expected_event_ids) + + def test_pagination_from_sync_and_messages(self) -> None: + """Pagination tokens from /sync and /messages can be used to paginate /relations.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") + annotation_id = channel.json_body["event_id"] + # Send an event after the relation events. + self.helper.send(self.room, body="Latest event", tok=self.user_token) + + # Request /sync, limiting it such that only the latest event is returned + # (and not the relation). + filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}') + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + sync_prev_batch = room_timeline["prev_batch"] + self.assertIsNotNone(sync_prev_batch) + # Ensure the relation event is not in the batch returned from /sync. + self.assertNotIn( + annotation_id, [ev["event_id"] for ev in room_timeline["events"]] + ) + + # Request /messages, limiting it such that only the latest event is + # returned (and not the relation). + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b&limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + messages_end = channel.json_body["end"] + self.assertIsNotNone(messages_end) + # Ensure the relation event is not in the chunk returned from /messages. + self.assertNotIn( + annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + ) + + # Request /relations with the pagination tokens received from both the + # /sync and /messages responses above, in turn. + # + # This is a tiny bit silly since the client wouldn't know the parent ID + # from the requests above; consider the parent ID to be known from a + # previous /sync. + for from_token in (sync_prev_batch, messages_end): + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # The relation should be in the returned chunk. + self.assertIn( + annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + ) + + def test_aggregation_pagination_groups(self) -> None: + """Test that we can paginate annotation groups correctly.""" + + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} + for key in itertools.chain.from_iterable( + itertools.repeat(key, num) for key, num in sent_groups.items() + ): + self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key=key, + access_token=access_tokens[idx], + ) + + idx += 1 + idx %= len(access_tokens) + + prev_token: Optional[str] = None + found_groups: Dict[str, int] = {} + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + for groups in channel.json_body["chunk"]: + # We only expect reactions + self.assertEqual(groups["type"], "m.reaction", channel.json_body) + + # We should only see each key once + self.assertNotIn(groups["key"], found_groups, channel.json_body) + + found_groups[groups["key"]] = groups["count"] + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + self.assertEqual(sent_groups, found_groups) + + def test_aggregation_pagination_within_group(self) -> None: + """Test that we can paginate within an annotation group.""" + + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key="👍", + access_token=access_tokens[idx], + ) + expected_event_ids.append(channel.json_body["event_id"]) + + idx += 1 + + # Also send a different type of reaction so that we test we don't see it + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + + prev_token = "" + found_event_ids: List[str] = [] + encoded_key = urllib.parse.quote_plus("👍".encode()) + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}" + f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" + f"/m.reaction/{encoded_key}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEqual(found_event_ids, expected_event_ids) + + +class BundledAggregationsTestCase(BaseRelationsTestCase): + """ + See RelationsTestCase.test_edit for a similar test for edits. + + Note that this doesn't test against /relations since only thread relations + get bundled via that API. See test_aggregation_get_event_for_thread. + """ + + def _test_bundled_aggregations( + self, + relation_type: str, + assertion_callable: Callable[[JsonDict], None], + expected_db_txn_for_event: int, + ) -> None: + """ + Makes requests to various endpoints which should include bundled aggregations + and then calls an assertion function on the bundled aggregations. + + Args: + relation_type: The field to search for in the `m.relations` field in unsigned. + assertion_callable: Called with the contents of unsigned["m.relations"][relation_type] + for relation-specific assertions. + expected_db_txn_for_event: The number of database transactions which + are expected for a call to /event/. + """ + + def assert_bundle(event_json: JsonDict) -> None: + """Assert the expected values of the bundled aggregations.""" + relations_dict = event_json["unsigned"].get("m.relations") + + # Ensure the fields are as expected. + self.assertCountEqual(relations_dict.keys(), (relation_type,)) + assertion_callable(relations_dict[relation_type]) + + # Request the event directly. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(channel.json_body) + assert channel.resource_usage is not None + self.assertEqual(channel.resource_usage.db_txn_count, expected_db_txn_for_event) + + # Request the room messages. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) + + # Request the room context. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]) + + # Request sync. + filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}') + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + assert_bundle(self._find_event_in_chunk(room_timeline["events"])) + + # Request search. + channel = self.make_request( + "POST", + "/search", + # Search term matches the parent message. + content={"search_categories": {"room_events": {"search_term": "Hi"}}}, + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + chunk = [ + result["result"] + for result in channel.json_body["search_categories"]["room_events"][ + "results" + ] + ] + assert_bundle(self._find_event_in_chunk(chunk)) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_annotation(self) -> None: + """ + Test that annotations get correctly bundled. + """ + # Setup by sending a variety of relations. + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual( + { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + bundled_aggregations, + ) + + self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_reference(self) -> None: + """ + Test that references get correctly bundled. + """ + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + reply_1 = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + reply_2 = channel.json_body["event_id"] + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual( + {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, + bundled_aggregations, + ) + + self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_thread(self) -> None: + """ + Test that threads get correctly bundled. + """ + self._send_relation(RelationTypes.THREAD, "m.room.test") + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_2 = channel.json_body["event_id"] + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual(2, bundled_aggregations.get("count")) + self.assertTrue(bundled_aggregations.get("current_user_participated")) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "sender": self.user_id, + "type": "m.room.test", + }, + bundled_aggregations.get("latest_event"), + ) + + self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9) + + def test_aggregation_get_event_for_annotation(self) -> None: + """Test that annotations do not get bundled aggregations included + when directly requested. + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + annotation_id = channel.json_body["event_id"] + + # Annotate the annotation. + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id + ) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{annotation_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) + + def test_aggregation_get_event_for_thread(self) -> None: + """Test that threads get bundled aggregations included when directly requested.""" + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_id = channel.json_body["event_id"] + + # Annotate the annotation. + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id + ) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{thread_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( + channel.json_body["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + }, + }, + ) + + # It should also be included when the entire thread is requested. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1) + + thread_message = channel.json_body["chunk"][0] + self.assertEqual( + thread_message["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + }, + }, + ) + + def test_bundled_aggregations_with_filter(self) -> None: + """ + If "unsigned" is an omitted field (due to filtering), adding the bundled + aggregations should not break. + + Note that the spec allows for a server to return additional fields beyond + what is specified. + """ + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + + # Note that the sync filter does not include "unsigned" as a field. + filter = urllib.parse.quote_plus( + b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}' + ) + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + + # Ensure the timeline is limited, find the parent event. + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + parent_event = self._find_event_in_chunk(room_timeline["events"]) + + # Ensure there's bundled aggregations on it. + self.assertIn("unsigned", parent_event) + self.assertIn("m.relations", parent_event["unsigned"]) + + +class RelationIgnoredUserTestCase(BaseRelationsTestCase): + """Relations sent from an ignored user should be ignored.""" + + def _test_ignored_user( + self, allowed_event_ids: List[str], ignored_event_ids: List[str] + ) -> None: + """ + Fetch the relations and ensure they're all there, then ignore user2, and + repeat. + """ + # Get the relations. + event_ids = self._get_related_events() + self.assertCountEqual(event_ids, allowed_event_ids + ignored_event_ids) + + # Ignore user2 and re-do the requests. + self.get_success( + self.store.add_account_data_for_user( + self.user_id, + AccountDataTypes.IGNORED_USER_LIST, + {"ignored_users": {self.user2_id: {}}}, + ) + ) + + # Get the relations. + event_ids = self._get_related_events() + self.assertCountEqual(event_ids, allowed_event_ids) + + def test_annotation(self) -> None: + """Annotations should ignore""" + # Send 2 from us, 2 from the to be ignored user. + allowed_event_ids = [] + ignored_event_ids = [] + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + allowed_event_ids.append(channel.json_body["event_id"]) + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="b") + allowed_event_ids.append(channel.json_body["event_id"]) + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key="a", + access_token=self.user2_token, + ) + ignored_event_ids.append(channel.json_body["event_id"]) + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key="c", + access_token=self.user2_token, + ) + ignored_event_ids.append(channel.json_body["event_id"]) + + self._test_ignored_user(allowed_event_ids, ignored_event_ids) + + def test_reference(self) -> None: + """Annotations should ignore""" + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + allowed_event_ids = [channel.json_body["event_id"]] + + channel = self._send_relation( + RelationTypes.REFERENCE, "m.room.test", access_token=self.user2_token + ) + ignored_event_ids = [channel.json_body["event_id"]] + + self._test_ignored_user(allowed_event_ids, ignored_event_ids) + + def test_thread(self) -> None: + """Annotations should ignore""" + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + allowed_event_ids = [channel.json_body["event_id"]] + + channel = self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + ignored_event_ids = [channel.json_body["event_id"]] + + self._test_ignored_user(allowed_event_ids, ignored_event_ids) + + class RelationRedactionTestCase(BaseRelationsTestCase): - """Test the behaviour of relations when the parent or child event is redacted.""" + """ + Test the behaviour of relations when the parent or child event is redacted. + + The behaviour of each relation type is subtly different which causes the tests + to be a bit repetitive, they follow a naming scheme of: + + test_redact_(relation|parent)_{relation_type} + + The first bit of "relation" means that the event with the relation defined + on it (the child event) is to be redacted. A "parent" means that the target + of the relation (the parent event) is to be redacted. + + The relation_type describes which type of relation is under test (i.e. it is + related to the value of rel_type in the event content). + """ def _redact(self, event_id: str) -> None: channel = self.make_request( @@ -1358,40 +1429,116 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) def test_redact_relation_annotation(self) -> None: - """Test that annotations of an event are properly handled after the + """ + Test that annotations of an event are properly handled after the annotation is redacted. + + The redacted relation should not be included in bundled aggregations or + the response to relations. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) to_redact_event_id = channel.json_body["event_id"] channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEqual(200, channel.code, channel.json_body) + unredacted_event_id = channel.json_body["event_id"] + + # Both relations should exist. + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() + self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id]) + self.assertEquals( + relations["m.annotation"], + {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]}, + ) + + # Both relations appear in the aggregation. + chunk = self._get_aggregations() + self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}]) # Redact one of the reactions. self._redact(to_redact_event_id) - # Ensure that the aggregations are correct. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, + # The unredacted relation should still exist. + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() + self.assertEquals(event_ids, [unredacted_event_id]) + self.assertEquals( + relations["m.annotation"], + {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, ) - self.assertEqual(200, channel.code, channel.json_body) + # The unredacted aggregation should still exist. + chunk = self._get_aggregations() + self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}]) + + def test_redact_relation_thread(self) -> None: + """ + Test that thread replies are properly handled after the thread reply redacted. + + The redacted event should not be included in bundled aggregations or + the response to relations. + """ + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": "reply 1", "msgtype": "m.text"}, + ) + unredacted_event_id = channel.json_body["event_id"] + + # Note that the *last* event in the thread is redacted, as that gets + # included in the bundled aggregation. + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": "reply 2", "msgtype": "m.text"}, + ) + to_redact_event_id = channel.json_body["event_id"] + + # Both relations exist. + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() + self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id]) + self.assertDictContainsSubset( + { + "count": 2, + "current_user_participated": True, + }, + relations[RelationTypes.THREAD], + ) + # And the latest event returned is the event that will be redacted. + self.assertEqual( + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + to_redact_event_id, + ) + + # Redact one of the reactions. + self._redact(to_redact_event_id) + + # The unredacted relation should still exist. + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() + self.assertEquals(event_ids, [unredacted_event_id]) + self.assertDictContainsSubset( + { + "count": 1, + "current_user_participated": True, + }, + relations[RelationTypes.THREAD], + ) + # And the latest event is now the unredacted event. self.assertEqual( - channel.json_body, - {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + unredacted_event_id, ) - def test_redact_relation_edit(self) -> None: + def test_redact_parent_edit(self) -> None: """Test that edits of an event are redacted when the original event is redacted. """ # Add a relation - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", parent_id=self.parent_id, @@ -1401,54 +1548,83 @@ class RelationRedactionTestCase(BaseRelationsTestCase): "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) - self.assertEqual(200, channel.code, channel.json_body) # Check the relation is returned - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations" - f"/{self.parent_id}/m.replace/m.room.message", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertIn("chunk", channel.json_body) - self.assertEqual(len(channel.json_body["chunk"]), 1) + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() + self.assertEqual(len(event_ids), 1) + self.assertIn(RelationTypes.REPLACE, relations) # Redact the original event self._redact(self.parent_id) - # Try to check for remaining m.replace relations - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations" - f"/{self.parent_id}/m.replace/m.room.message", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) + # The relations are not returned. + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() + self.assertEqual(len(event_ids), 0) + self.assertEqual(relations, {}) - # Check that no relations are returned - self.assertIn("chunk", channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) - - def test_redact_parent(self) -> None: - """Test that annotations of an event are redacted when the original event + def test_redact_parent_annotation(self) -> None: + """Test that annotations of an event are viewable when the original event is redacted. """ # Add a relation channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEqual(200, channel.code, channel.json_body) + related_event_id = channel.json_body["event_id"] + + # The relations should exist. + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() + self.assertEqual(len(event_ids), 1) + self.assertIn(RelationTypes.ANNOTATION, relations) + + # The aggregation should exist. + chunk = self._get_aggregations() + self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}]) # Redact the original event. self._redact(self.parent_id) - # Check that aggregations returns zero - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}/m.annotation/m.reaction", - access_token=self.user_token, + # The relations are returned. + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() + self.assertEquals(event_ids, [related_event_id]) + self.assertEquals( + relations["m.annotation"], + {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertIn("chunk", channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) + # There's nothing to aggregate. + chunk = self._get_aggregations() + self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}]) + + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_redact_parent_thread(self) -> None: + """ + Test that thread replies are still available when the root event is redacted. + """ + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": "reply 1", "msgtype": "m.text"}, + ) + related_event_id = channel.json_body["event_id"] + + # Redact one of the reactions. + self._redact(self.parent_id) + + # The unredacted relation should still exist. + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() + self.assertEquals(len(event_ids), 1) + self.assertDictContainsSubset( + { + "count": 1, + "current_user_participated": True, + }, + relations[RelationTypes.THREAD], + ) + self.assertEqual( + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + related_event_id, + ) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index f3bf8d0934..7b8fe6d025 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -24,6 +24,7 @@ from synapse.util import Clock from synapse.visibility import filter_events_for_client from tests import unittest +from tests.unittest import override_config one_hour_ms = 3600000 one_day_ms = one_hour_ms * 24 @@ -38,7 +39,10 @@ class RetentionTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() - config["retention"] = { + + # merge this default retention config with anything that was specified in + # @override_config + retention_config = { "enabled": True, "default_policy": { "min_lifetime": one_day_ms, @@ -47,6 +51,8 @@ class RetentionTestCase(unittest.HomeserverTestCase): "allowed_lifetime_min": one_day_ms, "allowed_lifetime_max": one_day_ms * 3, } + retention_config.update(config.get("retention", {})) + config["retention"] = retention_config self.hs = self.setup_test_homeserver(config=config) @@ -115,22 +121,20 @@ class RetentionTestCase(unittest.HomeserverTestCase): self._test_retention_event_purged(room_id, one_day_ms * 2) + @override_config({"retention": {"purge_jobs": [{"interval": "5d"}]}}) def test_visibility(self) -> None: """Tests that synapse.visibility.filter_events_for_client correctly filters out - outdated events + outdated events, even if the purge job hasn't got to them yet. + + We do this by setting a very long time between purge jobs. """ store = self.hs.get_datastores().main storage = self.hs.get_storage() room_id = self.helper.create_room_as(self.user_id, tok=self.token) - events = [] # Send a first event, which should be filtered out at the end of the test. resp = self.helper.send(room_id=room_id, body="1", tok=self.token) - - # Get the event from the store so that we end up with a FrozenEvent that we can - # give to filter_events_for_client. We need to do this now because the event won't - # be in the database anymore after it has expired. - events.append(self.get_success(store.get_event(resp.get("event_id")))) + first_event_id = resp.get("event_id") # Advance the time by 2 days. We're using the default retention policy, therefore # after this the first event will still be valid. @@ -138,16 +142,17 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Send another event, which shouldn't get filtered out. resp = self.helper.send(room_id=room_id, body="2", tok=self.token) - valid_event_id = resp.get("event_id") - events.append(self.get_success(store.get_event(valid_event_id))) - # Advance the time by another 2 days. After this, the first event should be # outdated but not the second one. self.reactor.advance(one_day_ms * 2 / 1000) - # Run filter_events_for_client with our list of FrozenEvents. + # Fetch the events, and run filter_events_for_client on them + events = self.get_success( + store.get_events_as_list([first_event_id, valid_event_id]) + ) + self.assertEqual(2, len(events), "events retrieved from database") filtered_events = self.get_success( filter_events_for_client(storage, self.user_id, events) ) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 37866ee330..3a9617d6da 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2141,21 +2141,19 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_senders(self) -> None: # Messages which second user reacted to. - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_1) # Messages which third user reacted to. - filter = {"io.element.relation_senders": [self.third_user_id]} + filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_2) # Messages which either user reacted to. - filter = { - "io.element.relation_senders": [self.second_user_id, self.third_user_id] - } + filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 2, chunk) self.assertCountEqual( @@ -2164,20 +2162,20 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_type(self) -> None: # Messages which have annotations. - filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_1) # Messages which have references. - filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_2) # Messages which have either annotations or references. filter = { - "io.element.relation_types": [ + "related_by_rel_types": [ RelationTypes.ANNOTATION, RelationTypes.REFERENCE, ] @@ -2191,8 +2189,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_senders_and_type(self) -> None: # Messages which second user reacted to. filter = { - "io.element.relation_senders": [self.second_user_id], - "io.element.relation_types": [RelationTypes.ANNOTATION], + "related_by_senders": [self.second_user_id], + "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 58f1ea11b7..e7de67e3a3 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -775,3 +775,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual(args[0], user_id) self.assertFalse(args[1]) self.assertTrue(args[2]) + + def test_check_can_deactivate_user(self) -> None: + """Tests that the on_user_deactivation_status_changed module callback is called + correctly when processing a user's deactivation. + """ + # Register a mocked callback. + deactivation_mock = Mock(return_value=make_awaitable(False)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._check_can_deactivate_user_callbacks.append( + deactivation_mock, + ) + + # Register a user that we'll deactivate. + user_id = self.register_user("altan", "password") + tok = self.login("altan", "password") + + # Deactivate that user. + channel = self.make_request( + "POST", + "/_matrix/client/v3/account/deactivate", + { + "auth": { + "type": LoginType.PASSWORD, + "password": "password", + "identifier": { + "type": "m.id.user", + "user": user_id, + }, + }, + "erase": True, + }, + access_token=tok, + ) + + # Check that the deactivation was blocked + self.assertEqual(channel.code, 403, channel.json_body) + + # Check that the mock was called once. + deactivation_mock.assert_called_once() + args = deactivation_mock.call_args[0] + + # Check that the mock was called with the right user ID + self.assertEqual(args[0], user_id) + + # Check that the request was not made by an admin + self.assertEqual(args[1], False) + + def test_check_can_deactivate_user_admin(self) -> None: + """Tests that the on_user_deactivation_status_changed module callback is called + correctly when processing a user's deactivation triggered by a server admin. + """ + # Register a mocked callback. + deactivation_mock = Mock(return_value=make_awaitable(False)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._check_can_deactivate_user_callbacks.append( + deactivation_mock, + ) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Register a user that we'll deactivate. + user_id = self.register_user("altan", "password") + + # Deactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + {"deactivated": True}, + access_token=admin_tok, + ) + + # Check that the deactivation was blocked + self.assertEqual(channel.code, 403, channel.json_body) + + # Check that the mock was called once. + deactivation_mock.assert_called_once() + args = deactivation_mock.call_args[0] + + # Check that the mock was called with the right user ID + self.assertEqual(args[0], user_id) + + # Check that the mock was made by an admin + self.assertEqual(args[1], True) + + def test_check_can_shutdown_room(self) -> None: + """Tests that the check_can_shutdown_room module callback is called + correctly when processing an admin's shutdown room request. + """ + # Register a mocked callback. + shutdown_mock = Mock(return_value=make_awaitable(False)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._check_can_shutdown_room_callbacks.append( + shutdown_mock, + ) + + # Register an admin user. + admin_user_id = self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Shutdown the room. + channel = self.make_request( + "DELETE", + "/_synapse/admin/v2/rooms/%s" % self.room_id, + {}, + access_token=admin_tok, + ) + + # Check that the shutdown was blocked + self.assertEqual(channel.code, 403, channel.json_body) + + # Check that the mock was called once. + shutdown_mock.assert_called_once() + args = shutdown_mock.call_args[0] + + # Check that the mock was called with the right user ID + self.assertEqual(args[0], admin_user_id) + + # Check that the mock was called with the right room ID + self.assertEqual(args[1], self.room_id) diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 3b5747cb12..8d8251b2ac 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -1,3 +1,18 @@ +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http import HTTPStatus from unittest.mock import Mock, call from twisted.internet import defer, reactor @@ -11,14 +26,14 @@ from tests.utils import MockClock class HttpTransactionCacheTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.clock = MockClock() self.hs = Mock() self.hs.get_clock = Mock(return_value=self.clock) self.hs.get_auth = Mock() self.cache = HttpTransactionCache(self.hs) - self.mock_http_response = (200, "GOOD JOB!") + self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!") self.mock_key = "foo" @defer.inlineCallbacks diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 4672a68596..978c252f84 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -13,19 +13,24 @@ # limitations under the License. import urllib.parse from io import BytesIO, StringIO +from typing import Any, Dict, Optional, Union from unittest.mock import Mock import signedjson.key from canonicaljson import encode_canonical_json -from nacl.signing import SigningKey from signedjson.sign import sign_json +from signedjson.types import SigningKey -from twisted.web.resource import NoResource +from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import NoResource, Resource from synapse.crypto.keyring import PerspectivesKeyFetcher from synapse.http.site import SynapseRequest from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.server import HomeServer from synapse.storage.keys import FetchKeyResult +from synapse.types import JsonDict +from synapse.util import Clock from synapse.util.httpresourcetree import create_resource_tree from synapse.util.stringutils import random_string @@ -35,11 +40,11 @@ from tests.utils import default_config class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.http_client = Mock() return self.setup_test_homeserver(federation_http_client=self.http_client) - def create_test_resource(self): + def create_test_resource(self) -> Resource: return create_resource_tree( {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() ) @@ -51,7 +56,12 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): Tell the mock http client to expect an outgoing GET request for the given key """ - async def get_json(destination, path, ignore_backoff=False, **kwargs): + async def get_json( + destination: str, + path: str, + ignore_backoff: bool = False, + **kwargs: Any, + ) -> Union[JsonDict, list]: self.assertTrue(ignore_backoff) self.assertEqual(destination, server_name) key_id = "%s:%s" % (signing_key.alg, signing_key.version) @@ -84,7 +94,8 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): Checks that the response is a 200 and returns the decoded json body. """ channel = FakeChannel(self.site, self.reactor) - req = SynapseRequest(channel, self.site) + # channel is a `FakeChannel` but `HTTPChannel` is expected + req = SynapseRequest(channel, self.site) # type: ignore[arg-type] req.content = BytesIO(b"") req.requestReceived( b"GET", @@ -97,7 +108,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): resp = channel.json_body return resp - def test_get_key(self): + def test_get_key(self) -> None: """Fetch a remote key""" SERVER_NAME = "remote.server" testkey = signedjson.key.generate_signing_key("ver1") @@ -114,7 +125,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): self.assertIn(SERVER_NAME, keys[0]["signatures"]) self.assertIn(self.hs.hostname, keys[0]["signatures"]) - def test_get_own_key(self): + def test_get_own_key(self) -> None: """Fetch our own key""" testkey = signedjson.key.generate_signing_key("ver1") self.expect_outgoing_key_request(self.hs.hostname, testkey) @@ -141,7 +152,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): endpoint, to check that the two implementations are compatible. """ - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # replace the signing key with our own @@ -152,7 +163,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): return config - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # make a second homeserver, configured to use the first one as a key notary self.http_client2 = Mock() config = default_config(name="keyclient") @@ -175,7 +186,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): # wire up outbound POST /key/v2/query requests from hs2 so that they # will be forwarded to hs1 - async def post_json(destination, path, data): + async def post_json( + destination: str, path: str, data: Optional[JsonDict] = None + ) -> Union[JsonDict, list]: self.assertEqual(destination, self.hs.hostname) self.assertEqual( path, @@ -183,7 +196,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): ) channel = FakeChannel(self.site, self.reactor) - req = SynapseRequest(channel, self.site) + # channel is a `FakeChannel` but `HTTPChannel` is expected + req = SynapseRequest(channel, self.site) # type: ignore[arg-type] req.content = BytesIO(encode_canonical_json(data)) req.requestReceived( @@ -198,7 +212,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): self.http_client2.post_json.side_effect = post_json - def test_get_key(self): + def test_get_key(self) -> None: """Fetch a key belonging to a random server""" # make up a key to be fetched. testkey = signedjson.key.generate_signing_key("abc") @@ -218,7 +232,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): signedjson.key.encode_verify_key_base64(testkey.verify_key), ) - def test_get_notary_key(self): + def test_get_notary_key(self) -> None: """Fetch a key belonging to the notary server""" # make up a key to be fetched. We randomise the keyid to try to get it to # appear before the key server signing key sometimes (otherwise we bail out @@ -240,7 +254,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): signedjson.key.encode_verify_key_base64(testkey.verify_key), ) - def test_get_notary_keyserver_key(self): + def test_get_notary_keyserver_key(self) -> None: """Fetch the notary's keyserver key""" # we expect hs1 to make a regular key request to itself self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key) diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py index f761e23f1b..c73179151a 100644 --- a/tests/rest/media/v1/test_base.py +++ b/tests/rest/media/v1/test_base.py @@ -28,11 +28,11 @@ class GetFileNameFromHeadersTests(unittest.TestCase): b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar", } - def tests(self): + def tests(self) -> None: for hdr, expected in self.TEST_CASES.items(): res = get_filename_from_headers({b"Content-Disposition": [hdr]}) self.assertEqual( res, expected, - "expected output for %s to be %s but was %s" % (hdr, expected, res), + f"expected output for {hdr!r} to be {expected} but was {res}", ) diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py index 913bc530aa..43e6f0f70a 100644 --- a/tests/rest/media/v1/test_filepath.py +++ b/tests/rest/media/v1/test_filepath.py @@ -21,12 +21,12 @@ from tests import unittest class MediaFilePathsTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.filepaths = MediaFilePaths("/media_store") - def test_local_media_filepath(self): + def test_local_media_filepath(self) -> None: """Test local media paths""" self.assertEqual( self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"), @@ -37,7 +37,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_local_media_thumbnail(self): + def test_local_media_thumbnail(self) -> None: """Test local media thumbnail paths""" self.assertEqual( self.filepaths.local_media_thumbnail_rel( @@ -52,14 +52,14 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", ) - def test_local_media_thumbnail_dir(self): + def test_local_media_thumbnail_dir(self) -> None: """Test local media thumbnail directory paths""" self.assertEqual( self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"), "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_remote_media_filepath(self): + def test_remote_media_filepath(self) -> None: """Test remote media paths""" self.assertEqual( self.filepaths.remote_media_filepath_rel( @@ -74,7 +74,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_remote_media_thumbnail(self): + def test_remote_media_thumbnail(self) -> None: """Test remote media thumbnail paths""" self.assertEqual( self.filepaths.remote_media_thumbnail_rel( @@ -99,7 +99,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", ) - def test_remote_media_thumbnail_legacy(self): + def test_remote_media_thumbnail_legacy(self) -> None: """Test old-style remote media thumbnail paths""" self.assertEqual( self.filepaths.remote_media_thumbnail_rel_legacy( @@ -108,7 +108,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg", ) - def test_remote_media_thumbnail_dir(self): + def test_remote_media_thumbnail_dir(self) -> None: """Test remote media thumbnail directory paths""" self.assertEqual( self.filepaths.remote_media_thumbnail_dir( @@ -117,7 +117,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_url_cache_filepath(self): + def test_url_cache_filepath(self) -> None: """Test URL cache paths""" self.assertEqual( self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"), @@ -128,7 +128,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar", ) - def test_url_cache_filepath_legacy(self): + def test_url_cache_filepath_legacy(self) -> None: """Test old-style URL cache paths""" self.assertEqual( self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"), @@ -139,7 +139,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_url_cache_filepath_dirs_to_delete(self): + def test_url_cache_filepath_dirs_to_delete(self) -> None: """Test URL cache cleanup paths""" self.assertEqual( self.filepaths.url_cache_filepath_dirs_to_delete( @@ -148,7 +148,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ["/media_store/url_cache/2020-01-02"], ) - def test_url_cache_filepath_dirs_to_delete_legacy(self): + def test_url_cache_filepath_dirs_to_delete_legacy(self) -> None: """Test old-style URL cache cleanup paths""" self.assertEqual( self.filepaths.url_cache_filepath_dirs_to_delete( @@ -160,7 +160,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_url_cache_thumbnail(self): + def test_url_cache_thumbnail(self) -> None: """Test URL cache thumbnail paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_rel( @@ -175,7 +175,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale", ) - def test_url_cache_thumbnail_legacy(self): + def test_url_cache_thumbnail_legacy(self) -> None: """Test old-style URL cache thumbnail paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_rel( @@ -190,7 +190,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", ) - def test_url_cache_thumbnail_directory(self): + def test_url_cache_thumbnail_directory(self) -> None: """Test URL cache thumbnail directory paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_directory_rel( @@ -203,7 +203,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar", ) - def test_url_cache_thumbnail_directory_legacy(self): + def test_url_cache_thumbnail_directory_legacy(self) -> None: """Test old-style URL cache thumbnail directory paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_directory_rel( @@ -216,7 +216,7 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", ) - def test_url_cache_thumbnail_dirs_to_delete(self): + def test_url_cache_thumbnail_dirs_to_delete(self) -> None: """Test URL cache thumbnail cleanup paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_dirs_to_delete( @@ -228,7 +228,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_url_cache_thumbnail_dirs_to_delete_legacy(self): + def test_url_cache_thumbnail_dirs_to_delete_legacy(self) -> None: """Test old-style URL cache thumbnail cleanup paths""" self.assertEqual( self.filepaths.url_cache_thumbnail_dirs_to_delete( @@ -241,7 +241,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_server_name_validation(self): + def test_server_name_validation(self) -> None: """Test validation of server names""" self._test_path_validation( [ @@ -274,7 +274,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_file_id_validation(self): + def test_file_id_validation(self) -> None: """Test validation of local, remote and legacy URL cache file / media IDs""" # File / media IDs get split into three parts to form paths, consisting of the # first two characters, next two characters and rest of the ID. @@ -357,7 +357,7 @@ class MediaFilePathsTestCase(unittest.TestCase): invalid_values=invalid_file_ids, ) - def test_url_cache_media_id_validation(self): + def test_url_cache_media_id_validation(self) -> None: """Test validation of URL cache media IDs""" self._test_path_validation( [ @@ -387,7 +387,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_content_type_validation(self): + def test_content_type_validation(self) -> None: """Test validation of thumbnail content types""" self._test_path_validation( [ @@ -410,7 +410,7 @@ class MediaFilePathsTestCase(unittest.TestCase): ], ) - def test_thumbnail_method_validation(self): + def test_thumbnail_method_validation(self) -> None: """Test validation of thumbnail methods""" self._test_path_validation( [ @@ -440,7 +440,7 @@ class MediaFilePathsTestCase(unittest.TestCase): parameter: str, valid_values: Iterable[str], invalid_values: Iterable[str], - ): + ) -> None: """Test that the specified methods validate the named parameter as expected Args: diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py index a4b57e3d1f..62e308814d 100644 --- a/tests/rest/media/v1/test_html_preview.py +++ b/tests/rest/media/v1/test_html_preview.py @@ -16,7 +16,6 @@ from synapse.rest.media.v1.preview_html import ( _get_html_media_encodings, decode_body, parse_html_to_open_graph, - rebase_url, summarize_paragraphs, ) @@ -32,7 +31,7 @@ class SummarizeTestCase(unittest.TestCase): if not lxml: skip = "url preview feature requires lxml" - def test_long_summarize(self): + def test_long_summarize(self) -> None: example_paras = [ """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in @@ -90,7 +89,7 @@ class SummarizeTestCase(unittest.TestCase): " Tromsøya had a population of 36,088. Substantial parts of the urban…", ) - def test_short_summarize(self): + def test_short_summarize(self) -> None: example_paras = [ "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -117,7 +116,7 @@ class SummarizeTestCase(unittest.TestCase): " most of the year.", ) - def test_small_then_large_summarize(self): + def test_small_then_large_summarize(self) -> None: example_paras = [ "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -150,7 +149,7 @@ class CalcOgTestCase(unittest.TestCase): if not lxml: skip = "url preview feature requires lxml" - def test_simple(self): + def test_simple(self) -> None: html = b""" <html> <head><title>Foo</title></head> @@ -161,11 +160,11 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_comment(self): + def test_comment(self) -> None: html = b""" <html> <head><title>Foo</title></head> @@ -177,11 +176,11 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_comment2(self): + def test_comment2(self) -> None: html = b""" <html> <head><title>Foo</title></head> @@ -196,7 +195,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual( og, @@ -206,7 +205,7 @@ class CalcOgTestCase(unittest.TestCase): }, ) - def test_script(self): + def test_script(self) -> None: html = b""" <html> <head><title>Foo</title></head> @@ -218,11 +217,11 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_missing_title(self): + def test_missing_title(self) -> None: html = b""" <html> <body> @@ -232,11 +231,11 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) - def test_h1_as_title(self): + def test_h1_as_title(self) -> None: html = b""" <html> <meta property="og:description" content="Some text."/> @@ -247,11 +246,11 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) - def test_missing_title_and_broken_h1(self): + def test_missing_title_and_broken_h1(self) -> None: html = b""" <html> <body> @@ -262,23 +261,23 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) - def test_empty(self): + def test_empty(self) -> None: """Test a body with no data in it.""" html = b"" tree = decode_body(html, "http://example.com/test.html") self.assertIsNone(tree) - def test_no_tree(self): + def test_no_tree(self) -> None: """A valid body with no tree in it.""" html = b"\x00" tree = decode_body(html, "http://example.com/test.html") self.assertIsNone(tree) - def test_xml(self): + def test_xml(self) -> None: """Test decoding XML and ensure it works properly.""" # Note that the strip() call is important to ensure the xml tag starts # at the initial byte. @@ -290,10 +289,10 @@ class CalcOgTestCase(unittest.TestCase): <head><title>Foo</title></head><body>Some text.</body></html> """.strip() tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_invalid_encoding(self): + def test_invalid_encoding(self) -> None: """An invalid character encoding should be ignored and treated as UTF-8, if possible.""" html = b""" <html> @@ -304,10 +303,10 @@ class CalcOgTestCase(unittest.TestCase): </html> """ tree = decode_body(html, "http://example.com/test.html", "invalid-encoding") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - def test_invalid_encoding2(self): + def test_invalid_encoding2(self) -> None: """A body which doesn't match the sent character encoding.""" # Note that this contains an invalid UTF-8 sequence in the title. html = b""" @@ -319,10 +318,10 @@ class CalcOgTestCase(unittest.TestCase): </html> """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) - def test_windows_1252(self): + def test_windows_1252(self) -> None: """A body which uses cp1252, but doesn't declare that.""" html = b""" <html> @@ -333,12 +332,12 @@ class CalcOgTestCase(unittest.TestCase): </html> """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."}) class MediaEncodingTestCase(unittest.TestCase): - def test_meta_charset(self): + def test_meta_charset(self) -> None: """A character encoding is found via the meta tag.""" encodings = _get_html_media_encodings( b""" @@ -363,7 +362,7 @@ class MediaEncodingTestCase(unittest.TestCase): ) self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - def test_meta_charset_underscores(self): + def test_meta_charset_underscores(self) -> None: """A character encoding contains underscore.""" encodings = _get_html_media_encodings( b""" @@ -376,7 +375,7 @@ class MediaEncodingTestCase(unittest.TestCase): ) self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"]) - def test_xml_encoding(self): + def test_xml_encoding(self) -> None: """A character encoding is found via the meta tag.""" encodings = _get_html_media_encodings( b""" @@ -388,7 +387,7 @@ class MediaEncodingTestCase(unittest.TestCase): ) self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - def test_meta_xml_encoding(self): + def test_meta_xml_encoding(self) -> None: """Meta tags take precedence over XML encoding.""" encodings = _get_html_media_encodings( b""" @@ -402,7 +401,7 @@ class MediaEncodingTestCase(unittest.TestCase): ) self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"]) - def test_content_type(self): + def test_content_type(self) -> None: """A character encoding is found via the Content-Type header.""" # Test a few variations of the header. headers = ( @@ -417,12 +416,12 @@ class MediaEncodingTestCase(unittest.TestCase): encodings = _get_html_media_encodings(b"", header) self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - def test_fallback(self): + def test_fallback(self) -> None: """A character encoding cannot be found in the body or header.""" encodings = _get_html_media_encodings(b"", "text/html") self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - def test_duplicates(self): + def test_duplicates(self) -> None: """Ensure each encoding is only attempted once.""" encodings = _get_html_media_encodings( b""" @@ -436,7 +435,7 @@ class MediaEncodingTestCase(unittest.TestCase): ) self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - def test_unknown_invalid(self): + def test_unknown_invalid(self) -> None: """A character encoding should be ignored if it is unknown or invalid.""" encodings = _get_html_media_encodings( b""" @@ -448,34 +447,3 @@ class MediaEncodingTestCase(unittest.TestCase): 'text/html; charset="invalid"', ) self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - - -class RebaseUrlTestCase(unittest.TestCase): - def test_relative(self): - """Relative URLs should be resolved based on the context of the base URL.""" - self.assertEqual( - rebase_url("subpage", "https://example.com/foo/"), - "https://example.com/foo/subpage", - ) - self.assertEqual( - rebase_url("sibling", "https://example.com/foo"), - "https://example.com/sibling", - ) - self.assertEqual( - rebase_url("/bar", "https://example.com/foo/"), - "https://example.com/bar", - ) - - def test_absolute(self): - """Absolute URLs should not be modified.""" - self.assertEqual( - rebase_url("https://alice.com/a/", "https://example.com/foo/"), - "https://alice.com/a/", - ) - - def test_data(self): - """Data URLs should not be modified.""" - self.assertEqual( - rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"), - "data:,Hello%2C%20World%21", - ) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index cba9be17c4..7204b2dfe0 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -16,7 +16,7 @@ import shutil import tempfile from binascii import unhexlify from io import BytesIO -from typing import Optional +from typing import Any, BinaryIO, Dict, List, Optional, Union from unittest.mock import Mock from urllib import parse @@ -26,18 +26,24 @@ from PIL import Image as Image from twisted.internet import defer from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactor +from synapse.events import EventBase from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.logging.context import make_deferred_yieldable +from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client import login from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.filepath import MediaFilePaths -from synapse.rest.media.v1.media_storage import MediaStorage +from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend +from synapse.server import HomeServer +from synapse.types import RoomAlias +from synapse.util import Clock from tests import unittest -from tests.server import FakeSite, make_request +from tests.server import FakeChannel, FakeSite, make_request from tests.test_utils import SMALL_PNG from tests.utils import default_config @@ -46,7 +52,7 @@ class MediaStorageTests(unittest.HomeserverTestCase): needs_threadpool = True - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") self.addCleanup(shutil.rmtree, self.test_dir) @@ -62,7 +68,7 @@ class MediaStorageTests(unittest.HomeserverTestCase): hs, self.primary_base_path, self.filepaths, storage_providers ) - def test_ensure_media_is_in_local_cache(self): + def test_ensure_media_is_in_local_cache(self) -> None: media_id = "some_media_id" test_body = "Test\n" @@ -105,7 +111,7 @@ class MediaStorageTests(unittest.HomeserverTestCase): self.assertEqual(test_body, body) -@attr.s(slots=True, frozen=True) +@attr.s(auto_attribs=True, slots=True, frozen=True) class _TestImage: """An image for testing thumbnailing with the expected results @@ -121,18 +127,18 @@ class _TestImage: a 404 is expected. """ - data = attr.ib(type=bytes) - content_type = attr.ib(type=bytes) - extension = attr.ib(type=bytes) - expected_cropped = attr.ib(type=Optional[bytes], default=None) - expected_scaled = attr.ib(type=Optional[bytes], default=None) - expected_found = attr.ib(default=True, type=bool) + data: bytes + content_type: bytes + extension: bytes + expected_cropped: Optional[bytes] = None + expected_scaled: Optional[bytes] = None + expected_found: bool = True @parameterized_class( ("test_image",), [ - # smoll png + # small png ( _TestImage( SMALL_PNG, @@ -193,11 +199,17 @@ class MediaRepoTests(unittest.HomeserverTestCase): hijack_auth = True user_id = "@test:user" - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.fetches = [] - def get_file(destination, path, output_stream, args=None, max_size=None): + def get_file( + destination: str, + path: str, + output_stream: BinaryIO, + args: Optional[Dict[str, Union[str, List[str]]]] = None, + max_size: Optional[int] = None, + ) -> Deferred: """ Returns tuple[int,dict,str,int] of file length, response headers, absolute URI, and response code. @@ -238,7 +250,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_resource = hs.get_media_repository_resource() self.download_resource = media_resource.children[b"download"] @@ -248,8 +260,9 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.media_id = "example.com/12345" - def _req(self, content_disposition, include_content_type=True): - + def _req( + self, content_disposition: Optional[bytes], include_content_type: bool = True + ) -> FakeChannel: channel = make_request( self.reactor, FakeSite(self.download_resource, self.reactor), @@ -288,7 +301,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): return channel - def test_handle_missing_content_type(self): + def test_handle_missing_content_type(self) -> None: channel = self._req( b"inline; filename=out" + self.test_image.extension, include_content_type=False, @@ -299,7 +312,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"] ) - def test_disposition_filename_ascii(self): + def test_disposition_filename_ascii(self) -> None: """ If the filename is filename=<ascii> then Synapse will decode it as an ASCII string, and use filename= in the response. @@ -315,7 +328,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): [b"inline; filename=out" + self.test_image.extension], ) - def test_disposition_filenamestar_utf8escaped(self): + def test_disposition_filenamestar_utf8escaped(self) -> None: """ If the filename is filename=*utf8''<utf8 escaped> then Synapse will correctly decode it as the UTF-8 string, and use filename* in the @@ -335,7 +348,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): [b"inline; filename*=utf-8''" + filename + self.test_image.extension], ) - def test_disposition_none(self): + def test_disposition_none(self) -> None: """ If there is no filename, one isn't passed on in the Content-Disposition of the request. @@ -348,26 +361,26 @@ class MediaRepoTests(unittest.HomeserverTestCase): ) self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) - def test_thumbnail_crop(self): + def test_thumbnail_crop(self) -> None: """Test that a cropped remote thumbnail is available.""" self._test_thumbnail( "crop", self.test_image.expected_cropped, self.test_image.expected_found ) - def test_thumbnail_scale(self): + def test_thumbnail_scale(self) -> None: """Test that a scaled remote thumbnail is available.""" self._test_thumbnail( "scale", self.test_image.expected_scaled, self.test_image.expected_found ) - def test_invalid_type(self): + def test_invalid_type(self) -> None: """An invalid thumbnail type is never available.""" self._test_thumbnail("invalid", None, False) @unittest.override_config( {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]} ) - def test_no_thumbnail_crop(self): + def test_no_thumbnail_crop(self) -> None: """ Override the config to generate only scaled thumbnails, but request a cropped one. """ @@ -376,13 +389,13 @@ class MediaRepoTests(unittest.HomeserverTestCase): @unittest.override_config( {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]} ) - def test_no_thumbnail_scale(self): + def test_no_thumbnail_scale(self) -> None: """ Override the config to generate only cropped thumbnails, but request a scaled one. """ self._test_thumbnail("scale", None, False) - def test_thumbnail_repeated_thumbnail(self): + def test_thumbnail_repeated_thumbnail(self) -> None: """Test that fetching the same thumbnail works, and deleting the on disk thumbnail regenerates it. """ @@ -443,7 +456,9 @@ class MediaRepoTests(unittest.HomeserverTestCase): channel.result["body"], ) - def _test_thumbnail(self, method, expected_body, expected_found): + def _test_thumbnail( + self, method: str, expected_body: Optional[bytes], expected_found: bool + ) -> None: params = "?width=32&height=32&method=" + method channel = make_request( self.reactor, @@ -485,7 +500,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): ) @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)]) - def test_same_quality(self, method, desired_size): + def test_same_quality(self, method: str, desired_size: int) -> None: """Test that choosing between thumbnails with the same quality rating succeeds. We are not particular about which thumbnail is chosen.""" @@ -521,7 +536,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): ) ) - def test_x_robots_tag_header(self): + def test_x_robots_tag_header(self) -> None: """ Tests that the `X-Robots-Tag` header is present, which informs web crawlers to not index, archive, or follow links in media. @@ -540,29 +555,38 @@ class TestSpamChecker: `evil`. """ - def __init__(self, config, api): + def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None: self.config = config self.api = api - def parse_config(config): + def parse_config(config: Dict[str, Any]) -> Dict[str, Any]: return config - async def check_event_for_spam(self, foo): + async def check_event_for_spam(self, event: EventBase) -> Union[bool, str]: return False # allow all events - async def user_may_invite(self, inviter_userid, invitee_userid, room_id): + async def user_may_invite( + self, + inviter_userid: str, + invitee_userid: str, + room_id: str, + ) -> bool: return True # allow all invites - async def user_may_create_room(self, userid): + async def user_may_create_room(self, userid: str) -> bool: return True # allow all room creations - async def user_may_create_room_alias(self, userid, room_alias): + async def user_may_create_room_alias( + self, userid: str, room_alias: RoomAlias + ) -> bool: return True # allow all room aliases - async def user_may_publish_room(self, userid, room_id): + async def user_may_publish_room(self, userid: str, room_id: str) -> bool: return True # allow publishing of all rooms - async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool: + async def check_media_file_for_spam( + self, file_wrapper: ReadableFileWrapper, file_info: FileInfo + ) -> bool: buf = BytesIO() await file_wrapper.write_chunks_to(buf.write) @@ -575,7 +599,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): admin.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user = self.register_user("user", "pass") self.tok = self.login("user", "pass") @@ -586,7 +610,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): load_legacy_spam_checkers(hs) - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = default_config("test") config.update( @@ -602,13 +626,13 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): return config - def test_upload_innocent(self): + def test_upload_innocent(self) -> None: """Attempt to upload some innocent data that should be allowed.""" self.helper.upload_media( self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200 ) - def test_upload_ban(self): + def test_upload_ban(self) -> None: """Attempt to upload some data that includes bytes "evil", which should get rejected by the spam checker. """ diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py index 048d0ca44a..f38d7225f8 100644 --- a/tests/rest/media/v1/test_oembed.py +++ b/tests/rest/media/v1/test_oembed.py @@ -16,7 +16,7 @@ import json from twisted.test.proto_helpers import MemoryReactor -from synapse.rest.media.v1.oembed import OEmbedProvider +from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -25,15 +25,15 @@ from tests.unittest import HomeserverTestCase class OEmbedTests(HomeserverTestCase): - def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): - self.oembed = OEmbedProvider(homeserver) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.oembed = OEmbedProvider(hs) - def parse_response(self, response: JsonDict): + def parse_response(self, response: JsonDict) -> OEmbedResult: return self.oembed.parse_oembed_response( "https://test", json.dumps(response).encode("utf-8") ) - def test_version(self): + def test_version(self) -> None: """Accept versions that are similar to 1.0 as a string or int (or missing).""" for version in ("1.0", 1.0, 1): result = self.parse_response({"version": version, "type": "link"}) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index da2c533260..5148c39874 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -16,16 +16,21 @@ import base64 import json import os import re +from typing import Any, Dict, Optional, Sequence, Tuple, Type from urllib.parse import urlencode from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.error import DNSLookupError -from twisted.test.proto_helpers import AccumulatingProtocol +from twisted.internet.interfaces import IAddress, IResolutionReceiver +from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor from synapse.config.oembed import OEmbedEndpointConfig +from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS +from synapse.server import HomeServer from synapse.types import JsonDict +from synapse.util import Clock from synapse.util.stringutils import parse_and_validate_mxc_uri from tests import unittest @@ -52,7 +57,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"</head></html>" ) - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["url_preview_enabled"] = True @@ -113,22 +118,22 @@ class URLPreviewTests(unittest.HomeserverTestCase): return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() self.preview_url = self.media_repo.children[b"preview_url"] - self.lookups = {} + self.lookups: Dict[str, Any] = {} class Resolver: def resolveHostName( _self, - resolutionReceiver, - hostName, - portNumber=0, - addressTypes=None, - transportSemantics="TCP", - ): + resolutionReceiver: IResolutionReceiver, + hostName: str, + portNumber: int = 0, + addressTypes: Optional[Sequence[Type[IAddress]]] = None, + transportSemantics: str = "TCP", + ) -> IResolutionReceiver: resolution = HostResolution(hostName) resolutionReceiver.resolutionBegan(resolution) @@ -140,9 +145,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): resolutionReceiver.resolutionComplete() return resolutionReceiver - self.reactor.nameResolver = Resolver() + self.reactor.nameResolver = Resolver() # type: ignore[assignment] - def create_test_resource(self): + def create_test_resource(self) -> MediaRepositoryResource: return self.hs.get_media_repository_resource() def _assert_small_png(self, json_body: JsonDict) -> None: @@ -153,7 +158,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(json_body["og:image:type"], "image/png") self.assertEqual(json_body["matrix:image:size"], 67) - def test_cache_returns_correct_type(self): + def test_cache_returns_correct_type(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] channel = self.make_request( @@ -207,7 +212,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) - def test_non_ascii_preview_httpequiv(self): + def test_non_ascii_preview_httpequiv(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( @@ -243,7 +248,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") - def test_video_rejected(self): + def test_video_rejected(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = b"anything" @@ -279,7 +284,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_audio_rejected(self): + def test_audio_rejected(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = b"anything" @@ -315,7 +320,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_non_ascii_preview_content_type(self): + def test_non_ascii_preview_content_type(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( @@ -350,7 +355,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") - def test_overlong_title(self): + def test_overlong_title(self) -> None: self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( @@ -387,7 +392,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): # We should only see the `og:description` field, as `title` is too long and should be stripped out self.assertCountEqual(["og:description"], res.keys()) - def test_ipaddr(self): + def test_ipaddr(self) -> None: """ IP addresses can be previewed directly. """ @@ -417,7 +422,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) - def test_blacklisted_ip_specific(self): + def test_blacklisted_ip_specific(self) -> None: """ Blacklisted IP addresses, found via DNS, are not spidered. """ @@ -438,7 +443,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_blacklisted_ip_range(self): + def test_blacklisted_ip_range(self) -> None: """ Blacklisted IP ranges, IPs found over DNS, are not spidered. """ @@ -457,7 +462,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_blacklisted_ip_specific_direct(self): + def test_blacklisted_ip_specific_direct(self) -> None: """ Blacklisted IP addresses, accessed directly, are not spidered. """ @@ -476,7 +481,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 403) - def test_blacklisted_ip_range_direct(self): + def test_blacklisted_ip_range_direct(self) -> None: """ Blacklisted IP ranges, accessed directly, are not spidered. """ @@ -493,7 +498,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_blacklisted_ip_range_whitelisted_ip(self): + def test_blacklisted_ip_range_whitelisted_ip(self) -> None: """ Blacklisted but then subsequently whitelisted IP addresses can be spidered. @@ -526,7 +531,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) - def test_blacklisted_ip_with_external_ip(self): + def test_blacklisted_ip_with_external_ip(self) -> None: """ If a hostname resolves a blacklisted IP, even if there's a non-blacklisted one, it will be rejected. @@ -549,7 +554,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_blacklisted_ipv6_specific(self): + def test_blacklisted_ipv6_specific(self) -> None: """ Blacklisted IP addresses, found via DNS, are not spidered. """ @@ -572,7 +577,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_blacklisted_ipv6_range(self): + def test_blacklisted_ipv6_range(self) -> None: """ Blacklisted IP ranges, IPs found over DNS, are not spidered. """ @@ -591,7 +596,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_OPTIONS(self): + def test_OPTIONS(self) -> None: """ OPTIONS returns the OPTIONS. """ @@ -601,7 +606,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {}) - def test_accept_language_config_option(self): + def test_accept_language_config_option(self) -> None: """ Accept-Language header is sent to the remote server """ @@ -652,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): server.data, ) - def test_data_url(self): + def test_data_url(self) -> None: """ Requesting to preview a data URL is not supported. """ @@ -675,7 +680,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 500) - def test_inline_data_url(self): + def test_inline_data_url(self) -> None: """ An inline image (as a data URL) should be parsed properly. """ @@ -712,7 +717,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self._assert_small_png(channel.json_body) - def test_oembed_photo(self): + def test_oembed_photo(self) -> None: """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL.""" self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] @@ -771,7 +776,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345") self._assert_small_png(body) - def test_oembed_rich(self): + def test_oembed_rich(self) -> None: """Test an oEmbed endpoint which returns HTML content via the 'rich' type.""" self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] @@ -817,7 +822,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_oembed_format(self): + def test_oembed_format(self) -> None: """Test an oEmbed endpoint which requires the format in the URL.""" self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")] @@ -866,7 +871,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_oembed_autodiscovery(self): + def test_oembed_autodiscovery(self) -> None: """ Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL. 1. Request a preview of a URL which is not known to the oEmbed code. @@ -962,7 +967,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) self._assert_small_png(body) - def _download_image(self): + def _download_image(self) -> Tuple[str, str]: """Downloads an image into the URL cache. Returns: A (host, media_id) tuple representing the MXC URI of the image. @@ -995,7 +1000,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertIsNone(_port) return host, media_id - def test_storage_providers_exclude_files(self): + def test_storage_providers_exclude_files(self) -> None: """Test that files are not stored in or fetched from storage providers.""" host, media_id = self._download_image() @@ -1037,7 +1042,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): "URL cache file was unexpectedly retrieved from a storage provider", ) - def test_storage_providers_exclude_thumbnails(self): + def test_storage_providers_exclude_thumbnails(self) -> None: """Test that thumbnails are not stored in or fetched from storage providers.""" host, media_id = self._download_image() @@ -1090,7 +1095,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): "URL cache thumbnail was unexpectedly retrieved from a storage provider", ) - def test_cache_expiry(self): + def test_cache_expiry(self) -> None: """Test that URL cache files and thumbnails are cleaned up properly on expiry.""" self.preview_url.clock = MockClock() diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py index 01d48c3860..da325955f8 100644 --- a/tests/rest/test_health.py +++ b/tests/rest/test_health.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from http import HTTPStatus from synapse.rest.health import HealthResource @@ -19,12 +19,12 @@ from tests import unittest class HealthCheckTests(unittest.HomeserverTestCase): - def create_test_resource(self): + def create_test_resource(self) -> HealthResource: # replace the JsonResource with a HealthResource. return HealthResource() - def test_health(self): + def test_health(self) -> None: channel = self.make_request("GET", "/health", shorthand=False) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.result["body"], b"OK") diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 118aa93a32..11f78f52b8 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + from twisted.web.resource import Resource from synapse.rest.well_known import well_known_resource @@ -19,7 +21,7 @@ from tests import unittest class WellKnownTests(unittest.HomeserverTestCase): - def create_test_resource(self): + def create_test_resource(self) -> Resource: # replace the JsonResource with a Resource wrapping the WellKnownResource res = Resource() res.putChild(b".well-known", well_known_resource(self.hs)) @@ -31,12 +33,12 @@ class WellKnownTests(unittest.HomeserverTestCase): "default_identity_server": "https://testis", } ) - def test_client_well_known(self): + def test_client_well_known(self) -> None: channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual( channel.json_body, { @@ -50,27 +52,27 @@ class WellKnownTests(unittest.HomeserverTestCase): "public_baseurl": None, } ) - def test_client_well_known_no_public_baseurl(self): + def test_client_well_known_no_public_baseurl(self) -> None: channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(channel.code, 404) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) @unittest.override_config({"serve_server_wellknown": True}) - def test_server_well_known(self): + def test_server_well_known(self) -> None: channel = self.make_request( "GET", "/.well-known/matrix/server", shorthand=False ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual( channel.json_body, {"m.server": "test:443"}, ) - def test_server_well_known_disabled(self): + def test_server_well_known_disabled(self) -> None: channel = self.make_request( "GET", "/.well-known/matrix/server", shorthand=False ) - self.assertEqual(channel.code, 404) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) diff --git a/tests/server.py b/tests/server.py index 82990c2eb9..6ce2a17bf4 100644 --- a/tests/server.py +++ b/tests/server.py @@ -54,13 +54,18 @@ from twisted.internet.interfaces import ( ITransport, ) from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock +from twisted.test.proto_helpers import ( + AccumulatingProtocol, + MemoryReactor, + MemoryReactorClock, +) from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Request, Site from synapse.config.database import DatabaseConnectionConfig from synapse.http.site import SynapseRequest +from synapse.logging.context import ContextResourceUsage from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine @@ -88,18 +93,19 @@ class TimedOutException(Exception): """ -@attr.s +@attr.s(auto_attribs=True) class FakeChannel: """ A fake Twisted Web Channel (the part that interfaces with the wire). """ - site = attr.ib(type=Union[Site, "FakeSite"]) - _reactor = attr.ib() - result = attr.ib(type=dict, default=attr.Factory(dict)) - _ip = attr.ib(type=str, default="127.0.0.1") + site: Union[Site, "FakeSite"] + _reactor: MemoryReactor + result: dict = attr.Factory(dict) + _ip: str = "127.0.0.1" _producer: Optional[Union[IPullProducer, IPushProducer]] = None + resource_usage: Optional[ContextResourceUsage] = None @property def json_body(self): @@ -168,6 +174,8 @@ class FakeChannel: def requestDone(self, _self): self.result["done"] = True + if isinstance(_self, SynapseRequest): + self.resource_usage = _self.logcontext.get_resource_usage() def getPeer(self): # We give an address so that getClientIP returns a non null entry, diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index 272cd35402..72bf5b3d31 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -47,9 +47,18 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): expected_ignorer_user_ids, ) + def assert_ignored( + self, ignorer_user_id: str, expected_ignored_user_ids: Set[str] + ) -> None: + self.assertEqual( + self.get_success(self.store.ignored_users(ignorer_user_id)), + expected_ignored_user_ids, + ) + def test_ignoring_users(self): """Basic adding/removing of users from the ignore list.""" self._update_ignore_list("@other:test", "@another:remote") + self.assert_ignored(self.user, {"@other:test", "@another:remote"}) # Check a user which no one ignores. self.assert_ignorers("@user:test", set()) @@ -62,6 +71,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): # Add one user, remove one user, and leave one user. self._update_ignore_list("@foo:test", "@another:remote") + self.assert_ignored(self.user, {"@foo:test", "@another:remote"}) # Check the removed user. self.assert_ignorers("@other:test", set()) @@ -76,20 +86,24 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): """Ensure that caching works properly between different users.""" # The first user ignores a user. self._update_ignore_list("@other:test") + self.assert_ignored(self.user, {"@other:test"}) self.assert_ignorers("@other:test", {self.user}) # The second user ignores them. self._update_ignore_list("@other:test", ignorer_user_id="@second:test") + self.assert_ignored("@second:test", {"@other:test"}) self.assert_ignorers("@other:test", {self.user, "@second:test"}) # The first user un-ignores them. self._update_ignore_list() + self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", {"@second:test"}) def test_invalid_data(self): """Invalid data ends up clearing out the ignored users list.""" # Add some data and ensure it is there. self._update_ignore_list("@other:test") + self.assert_ignored(self.user, {"@other:test"}) self.assert_ignorers("@other:test", {self.user}) # No ignored_users key. @@ -102,10 +116,12 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): ) # No one ignores the user now. + self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", set()) # Add some data and ensure it is there. self._update_ignore_list("@other:test") + self.assert_ignored(self.user, {"@other:test"}) self.assert_ignorers("@other:test", {self.user}) # Invalid data. @@ -118,4 +134,5 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): ) # No one ignores the user now. + self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", set()) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 39dcc094bd..fd619b64d4 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -14,16 +14,23 @@ from unittest.mock import Mock +import yaml + from twisted.internet.defer import Deferred, ensureDeferred +from twisted.test.proto_helpers import MemoryReactor +from synapse.server import HomeServer from synapse.storage.background_updates import BackgroundUpdater +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable, simple_async_mock +from tests.unittest import override_config class BackgroundUpdateTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( @@ -34,50 +41,50 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.updates.register_background_update_handler( "test_update", self.update_handler ) + self.store = self.hs.get_datastores().main + + async def update(self, progress: JsonDict, count: int) -> int: + duration_ms = 10 + await self.clock.sleep((count * duration_ms) / 1000) + progress = {"my_key": progress["my_key"] + 1} + await self.store.db_pool.runInteraction( + "update_progress", + self.updates._background_update_progress_txn, + "test_update", + progress, + ) + return count - def test_do_background_update(self): + def test_do_background_update(self) -> None: # the time we claim it takes to update one item when running the update duration_ms = 10 # the target runtime for each bg update target_background_update_duration_ms = 100 - store = self.hs.get_datastores().main self.get_success( - store.db_pool.simple_insert( + self.store.db_pool.simple_insert( "background_updates", values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, ) ) - # first step: make a bit of progress - async def update(progress, count): - await self.clock.sleep((count * duration_ms) / 1000) - progress = {"my_key": progress["my_key"] + 1} - await store.db_pool.runInteraction( - "update_progress", - self.updates._background_update_progress_txn, - "test_update", - progress, - ) - return count - - self.update_handler.side_effect = update + self.update_handler.side_effect = self.update self.update_handler.reset_mock() res = self.get_success( self.updates.do_next_background_update(False), - by=0.01, + by=0.02, ) self.assertFalse(res) # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( - {"my_key": 1}, self.updates.MINIMUM_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.updates.default_background_batch_size ) # second step: complete the update # we should now get run with a much bigger number of items to update - async def update(progress, count): + async def update(progress: JsonDict, count: int) -> int: self.assertEqual(progress, {"my_key": 2}) self.assertAlmostEqual( count, @@ -99,16 +106,234 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.assertTrue(result) self.assertFalse(self.update_handler.called) + @override_config( + yaml.safe_load( + """ + background_updates: + default_batch_size: 20 + """ + ) + ) + def test_background_update_default_batch_set_by_config(self) -> None: + """ + Test that the background update is run with the default_batch_size set by the config + """ + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + res = self.get_success( + self.updates.do_next_background_update(False), + by=0.01, + ) + self.assertFalse(res) + + # on the first call, we should get run with the default background update size specified in the config + self.update_handler.assert_called_once_with({"my_key": 1}, 20) + + def test_background_update_default_sleep_behavior(self) -> None: + """ + Test default background update behavior, which is to sleep + """ + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + self.updates.start_doing_background_updates() + + # 2: advance the reactor less than the default sleep duration (1000ms) + self.reactor.pump([0.5]) + # check that an update has not been run + self.update_handler.assert_not_called() + + # advance reactor past default sleep duration + self.reactor.pump([1]) + # check that update has been run + self.update_handler.assert_called() + + @override_config( + yaml.safe_load( + """ + background_updates: + sleep_duration_ms: 500 + """ + ) + ) + def test_background_update_sleep_set_in_config(self) -> None: + """ + Test that changing the sleep time in the config changes how long it sleeps + """ + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + self.updates.start_doing_background_updates() + + # 2: advance the reactor less than the configured sleep duration (500ms) + self.reactor.pump([0.45]) + # check that an update has not been run + self.update_handler.assert_not_called() + + # advance reactor past config sleep duration but less than default duration + self.reactor.pump([0.75]) + # check that update has been run + self.update_handler.assert_called() + + @override_config( + yaml.safe_load( + """ + background_updates: + sleep_enabled: false + """ + ) + ) + def test_disabling_background_update_sleep(self) -> None: + """ + Test that disabling sleep in the config results in bg update not sleeping + """ + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + self.updates.start_doing_background_updates() + + # 2: advance the reactor very little + self.reactor.pump([0.025]) + # check that an update has run + self.update_handler.assert_called() + + @override_config( + yaml.safe_load( + """ + background_updates: + background_update_duration_ms: 500 + """ + ) + ) + def test_background_update_duration_set_in_config(self) -> None: + """ + Test that the desired duration set in the config is used in determining batch size + """ + # Duration of one background update item + duration_ms = 10 + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + self.update_handler.side_effect = self.update + self.update_handler.reset_mock() + res = self.get_success( + self.updates.do_next_background_update(False), + by=0.02, + ) + self.assertFalse(res) + + # the first update was run with the default batch size, this should be run with 500ms as the + # desired duration + async def update(progress: JsonDict, count: int) -> int: + self.assertEqual(progress, {"my_key": 2}) + self.assertAlmostEqual( + count, + 500 / duration_ms, + places=0, + ) + await self.updates._end_background_update("test_update") + return count + + self.update_handler.side_effect = update + self.get_success(self.updates.do_next_background_update(False)) + + @override_config( + yaml.safe_load( + """ + background_updates: + min_batch_size: 5 + """ + ) + ) + def test_background_update_min_batch_set_in_config(self) -> None: + """ + Test that the minimum batch size set in the config is used + """ + # a very long-running individual update + duration_ms = 50 + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + + # Run the update with the long-running update item + async def update_long(progress: JsonDict, count: int) -> int: + await self.clock.sleep((count * duration_ms) / 1000) + progress = {"my_key": progress["my_key"] + 1} + await self.store.db_pool.runInteraction( + "update_progress", + self.updates._background_update_progress_txn, + "test_update", + progress, + ) + return count + + self.update_handler.side_effect = update_long + self.update_handler.reset_mock() + res = self.get_success( + self.updates.do_next_background_update(False), + by=1, + ) + self.assertFalse(res) + + # the first update was run with the default batch size, this should be run with minimum batch size + # as the first items took a very long time + async def update_short(progress: JsonDict, count: int) -> int: + self.assertEqual(progress, {"my_key": 2}) + self.assertEqual(count, 5) + await self.updates._end_background_update("test_update") + return count + + self.update_handler.side_effect = update_short + self.get_success(self.updates.do_next_background_update(False)) + class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) ) - self.update_deferred = Deferred() + self.update_deferred: Deferred[int] = Deferred() self.update_handler = Mock(return_value=self.update_deferred) self.updates.register_background_update_handler( "test_update", self.update_handler @@ -137,7 +362,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): ), ) - def test_controller(self): + def test_controller(self) -> None: store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( @@ -147,7 +372,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): ) # Set the return value for the context manager. - enter_defer = Deferred() + enter_defer: Deferred[int] = Deferred() self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer) # Start the background update. diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index 6fbac0ab14..a40fc20ef9 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -12,25 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.database import make_tuple_comparison_clause -from synapse.storage.engines import BaseDatabaseEngine +from typing import Callable, Tuple +from unittest.mock import Mock, call -from tests import unittest +from twisted.internet import defer +from twisted.internet.defer import CancelledError, Deferred +from twisted.test.proto_helpers import MemoryReactor +from synapse.server import HomeServer +from synapse.storage.database import ( + DatabasePool, + LoggingTransaction, + make_tuple_comparison_clause, +) +from synapse.util import Clock -def _stub_db_engine(**kwargs) -> BaseDatabaseEngine: - # returns a DatabaseEngine, circumventing the abc mechanism - # any kwargs are set as attributes on the class before instantiating it - t = type( - "TestBaseDatabaseEngine", - (BaseDatabaseEngine,), - dict(BaseDatabaseEngine.__dict__), - ) - # defeat the abc mechanism - t.__abstractmethods__ = set() - for k, v in kwargs.items(): - setattr(t, k, v) - return t(None, None) +from tests import unittest class TupleComparisonClauseTestCase(unittest.TestCase): @@ -38,3 +35,150 @@ class TupleComparisonClauseTestCase(unittest.TestCase): clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)]) self.assertEqual(clause, "(a,b) > (?,?)") self.assertEqual(args, [1, 2]) + + +class CallbacksTestCase(unittest.HomeserverTestCase): + """Tests for transaction callbacks.""" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + def _run_interaction( + self, func: Callable[[LoggingTransaction], object] + ) -> Tuple[Mock, Mock]: + """Run the given function in a database transaction, with callbacks registered. + + Args: + func: The function to be run in a transaction. The transaction will be + retried if `func` raises an `OperationalError`. + + Returns: + Two mocks, which were registered as an `after_callback` and an + `exception_callback` respectively, on every transaction attempt. + """ + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + func(txn) + + try: + self.get_success_or_raise( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + except Exception: + pass + + return after_callback, exception_callback + + def test_after_callback(self) -> None: + """Test that the after callback is called when a transaction succeeds.""" + after_callback, exception_callback = self._run_interaction(lambda txn: None) + + after_callback.assert_called_once_with(123, 456, extra=789) + exception_callback.assert_not_called() + + def test_exception_callback(self) -> None: + """Test that the exception callback is called when a transaction fails.""" + _test_txn = Mock(side_effect=ZeroDivisionError) + after_callback, exception_callback = self._run_interaction(_test_txn) + + after_callback.assert_not_called() + exception_callback.assert_called_once_with(987, 654, extra=321) + + def test_failed_retry(self) -> None: + """Test that the exception callback is called for every failed attempt.""" + # Always raise an `OperationalError`. + _test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError) + after_callback, exception_callback = self._run_interaction(_test_txn) + + after_callback.assert_not_called() + exception_callback.assert_has_calls( + [ + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + ] + ) + self.assertEqual(exception_callback.call_count, 6) # no additional calls + + def test_successful_retry(self) -> None: + """Test callbacks for a failed transaction followed by a successful attempt.""" + # Raise an `OperationalError` on the first attempt only. + _test_txn = Mock( + side_effect=[self.db_pool.engine.module.OperationalError, None] + ) + after_callback, exception_callback = self._run_interaction(_test_txn) + + # Calling both `after_callback`s when the first attempt failed is rather + # surprising (#12184). Let's document the behaviour in a test. + after_callback.assert_has_calls( + [ + call(123, 456, extra=789), + call(123, 456, extra=789), + ] + ) + self.assertEqual(after_callback.call_count, 2) # no additional calls + exception_callback.assert_not_called() + + +class CancellationTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + def test_after_callback(self) -> None: + """Test that the after callback is called when a transaction succeeds.""" + d: "Deferred[None]" + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + d.cancel() + + d = defer.ensureDeferred( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + self.get_failure(d, CancelledError) + + after_callback.assert_called_once_with(123, 456, extra=789) + exception_callback.assert_not_called() + + def test_exception_callback(self) -> None: + """Test that the exception callback is called when a transaction fails.""" + d: "Deferred[None]" + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + d.cancel() + # Simulate a retryable failure on every attempt. + raise self.db_pool.engine.module.OperationalError() + + d = defer.ensureDeferred( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + self.get_failure(d, CancelledError) + + after_callback.assert_not_called() + exception_callback.assert_has_calls( + [ + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + ] + ) + self.assertEqual(exception_callback.call_count, 6) # no additional calls diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 6ac4b93f98..395396340b 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -13,9 +13,13 @@ # limitations under the License. from typing import List, Optional -from synapse.storage.database import DatabasePool +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.util import Clock from tests.unittest import HomeserverTestCase from tests.utils import USE_POSTGRES_FOR_TESTS @@ -25,13 +29,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) - def _setup_db(self, txn): + def _setup_db(self, txn: LoggingTransaction) -> None: txn.execute("CREATE SEQUENCE foobar_seq") txn.execute( """ @@ -59,12 +63,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) - def _insert_rows(self, instance_name: str, number: int): + def _insert_rows(self, instance_name: str, number: int) -> None: """Insert N rows as the given instance, inserting with stream IDs pulled from the postgres sequence. """ - def _insert(txn): + def _insert(txn: LoggingTransaction) -> None: for _ in range(number): txn.execute( "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", @@ -80,12 +84,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) - def _insert_row_with_id(self, instance_name: str, stream_id: int): + def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None: """Insert one row as the given instance with given stream_id, updating the postgres sequence position to match. """ - def _insert(txn): + def _insert(txn: LoggingTransaction) -> None: txn.execute( "INSERT INTO foobar VALUES (?, ?)", ( @@ -104,7 +108,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) - def test_empty(self): + def test_empty(self) -> None: """Test an ID generator against an empty database gives sensible current positions. """ @@ -114,7 +118,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # The table is empty so we expect an empty map for positions self.assertEqual(id_gen.get_positions(), {}) - def test_single_instance(self): + def test_single_instance(self) -> None: """Test that reads and writes from a single process are handled correctly. """ @@ -130,7 +134,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) @@ -142,7 +146,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) - def test_out_of_order_finish(self): + def test_out_of_order_finish(self) -> None: """Test that IDs persisted out of order are correctly handled""" # Prefill table with 7 rows written by 'master' @@ -191,7 +195,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 11}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) - def test_multi_instance(self): + def test_multi_instance(self) -> None: """Test that reads and writes from multiple processes are handled correctly. """ @@ -215,7 +219,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. - async def _get_next_async(): + async def _get_next_async() -> None: async with first_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) @@ -233,7 +237,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # ... but calling `get_next` on the second instance should give a unique # stream ID - async def _get_next_async(): + async def _get_next_async2() -> None: async with second_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 9) @@ -241,7 +245,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): second_id_gen.get_positions(), {"first": 3, "second": 7} ) - self.get_success(_get_next_async()) + self.get_success(_get_next_async2()) self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9}) @@ -249,7 +253,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): second_id_gen.advance("first", 8) self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) - def test_get_next_txn(self): + def test_get_next_txn(self) -> None: """Test that the `get_next_txn` function works correctly.""" # Prefill table with 7 rows written by 'master' @@ -263,7 +267,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. - def _get_next_txn(txn): + def _get_next_txn(txn: LoggingTransaction) -> None: stream_id = id_gen.get_next_txn(txn) self.assertEqual(stream_id, 8) @@ -275,7 +279,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) - def test_get_persisted_upto_position(self): + def test_get_persisted_upto_position(self) -> None: """Test that `get_persisted_upto_position` correctly tracks updates to positions. """ @@ -317,7 +321,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen.advance("second", 15) self.assertEqual(id_gen.get_persisted_upto_position(), 11) - def test_get_persisted_upto_position_get_next(self): + def test_get_persisted_upto_position_get_next(self) -> None: """Test that `get_persisted_upto_position` correctly tracks updates to positions when `get_next` is called. """ @@ -331,7 +335,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_persisted_upto_position(), 5) - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen.get_next() as stream_id: self.assertEqual(stream_id, 6) self.assertEqual(id_gen.get_persisted_upto_position(), 5) @@ -344,7 +348,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # `persisted_upto_position` in this case, then it will be correct in the # other cases that are tested above (since they'll hit the same code). - def test_restart_during_out_of_order_persistence(self): + def test_restart_during_out_of_order_persistence(self) -> None: """Test that restarting a process while another process is writing out of order updates are handled correctly. """ @@ -388,7 +392,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen_worker.advance("master", 9) self.assertEqual(id_gen_worker.get_positions(), {"master": 9}) - def test_writer_config_change(self): + def test_writer_config_change(self) -> None: """Test that changing the writer config correctly works.""" self._insert_row_with_id("first", 3) @@ -421,7 +425,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # Check that we get a sane next stream ID with this new config. - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen_3.get_next() as stream_id: self.assertEqual(stream_id, 6) @@ -435,7 +439,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6) self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6) - def test_sequence_consistency(self): + def test_sequence_consistency(self) -> None: """Test that we error out if the table and sequence diverges.""" # Prefill with some rows @@ -458,13 +462,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) - def _setup_db(self, txn): + def _setup_db(self, txn: LoggingTransaction) -> None: txn.execute("CREATE SEQUENCE foobar_seq") txn.execute( """ @@ -493,10 +497,10 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): return self.get_success(self.db_pool.runWithConnection(_create)) - def _insert_row(self, instance_name: str, stream_id: int): + def _insert_row(self, instance_name: str, stream_id: int) -> None: """Insert one row as the given instance with given stream_id.""" - def _insert(txn): + def _insert(txn: LoggingTransaction) -> None: txn.execute( "INSERT INTO foobar VALUES (?, ?)", ( @@ -514,13 +518,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) - def test_single_instance(self): + def test_single_instance(self) -> None: """Test that reads and writes from a single process are handled correctly. """ id_gen = self._create_id_generator() - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen.get_next() as stream_id: self._insert_row("master", stream_id) @@ -530,7 +534,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_current_token_for_writer("master"), -1) self.assertEqual(id_gen.get_persisted_upto_position(), -1) - async def _get_next_async2(): + async def _get_next_async2() -> None: async with id_gen.get_next_mult(3) as stream_ids: for stream_id in stream_ids: self._insert_row("master", stream_id) @@ -548,14 +552,14 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4) self.assertEqual(second_id_gen.get_persisted_upto_position(), -4) - def test_multiple_instance(self): + def test_multiple_instance(self) -> None: """Tests that having multiple instances that get advanced over federation works corretly. """ id_gen_1 = self._create_id_generator("first", writers=["first", "second"]) id_gen_2 = self._create_id_generator("second", writers=["first", "second"]) - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen_1.get_next() as stream_id: self._insert_row("first", stream_id) id_gen_2.advance("first", stream_id) @@ -567,7 +571,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) - async def _get_next_async2(): + async def _get_next_async2() -> None: async with id_gen_2.get_next() as stream_id: self._insert_row("second", stream_id) id_gen_1.advance("second", stream_id) @@ -584,13 +588,13 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) - def _setup_db(self, txn): + def _setup_db(self, txn: LoggingTransaction) -> None: txn.execute("CREATE SEQUENCE foobar_seq") txn.execute( """ @@ -642,7 +646,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): from the postgres sequence. """ - def _insert(txn): + def _insert(txn: LoggingTransaction) -> None: for _ in range(number): txn.execute( "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,), @@ -659,7 +663,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) - def test_load_existing_stream(self): + def test_load_existing_stream(self) -> None: """Test creating ID gens with multiple tables that have rows from after the position in `stream_positions` table. """ diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 6a1cf33054..eaa0d7d749 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -129,21 +129,19 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_senders(self): # Messages which second user reacted to. - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) # Messages which third user reacted to. - filter = {"io.element.relation_senders": [self.third_user_id]} + filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_2) # Messages which either user reacted to. - filter = { - "io.element.relation_senders": [self.second_user_id, self.third_user_id] - } + filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 2, chunk) self.assertCountEqual( @@ -152,20 +150,20 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_type(self): # Messages which have annotations. - filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) # Messages which have references. - filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_2) # Messages which have either annotations or references. filter = { - "io.element.relation_types": [ + "related_by_rel_types": [ RelationTypes.ANNOTATION, RelationTypes.REFERENCE, ] @@ -179,8 +177,8 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_senders_and_type(self): # Messages which second user reacted to. filter = { - "io.element.relation_senders": [self.second_user_id], - "io.element.relation_types": [RelationTypes.ANNOTATION], + "related_by_senders": [self.second_user_id], + "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) @@ -201,7 +199,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.second_tok, ) - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) diff --git a/tests/storage/test_unsafe_locale.py b/tests/storage/test_unsafe_locale.py new file mode 100644 index 0000000000..ba53c22818 --- /dev/null +++ b/tests/storage/test_unsafe_locale.py @@ -0,0 +1,46 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import MagicMock, patch + +from synapse.storage.database import make_conn +from synapse.storage.engines._base import IncorrectDatabaseSetup + +from tests.unittest import HomeserverTestCase +from tests.utils import USE_POSTGRES_FOR_TESTS + + +class UnsafeLocaleTest(HomeserverTestCase): + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + @patch("synapse.storage.engines.postgres.PostgresEngine.get_db_locale") + def test_unsafe_locale(self, mock_db_locale: MagicMock) -> None: + mock_db_locale.return_value = ("B", "B") + database = self.hs.get_datastores().databases[0] + + db_conn = make_conn(database._database_config, database.engine, "test_unsafe") + with self.assertRaises(IncorrectDatabaseSetup): + database.engine.check_database(db_conn) + with self.assertRaises(IncorrectDatabaseSetup): + database.engine.check_new_database(db_conn) + db_conn.close() + + def test_safe_locale(self) -> None: + database = self.hs.get_datastores().databases[0] + + db_conn = make_conn(database._database_config, database.engine, "test_unsafe") + with db_conn.cursor() as txn: + res = database.engine.get_db_locale(txn) + self.assertEqual(res, ("C", "C")) + db_conn.close() diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 219b5660b1..532e3fe9cd 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -13,11 +13,12 @@ # limitations under the License. import logging from typing import Optional +from unittest.mock import patch from synapse.api.room_versions import RoomVersions -from synapse.events import EventBase -from synapse.types import JsonDict -from synapse.visibility import filter_events_for_server +from synapse.events import EventBase, make_event_from_dict +from synapse.types import JsonDict, create_requester +from synapse.visibility import filter_events_for_client, filter_events_for_server from tests import unittest from tests.utils import create_room @@ -185,3 +186,72 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.get_success(self.storage.persistence.persist_event(event, context)) return event + + +class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): + def test_out_of_band_invite_rejection(self): + # this is where we have received an invite event over federation, and then + # rejected it. + invite_pdu = { + "room_id": "!room:id", + "depth": 1, + "auth_events": [], + "prev_events": [], + "origin_server_ts": 1, + "sender": "@someone:" + self.OTHER_SERVER_NAME, + "type": "m.room.member", + "state_key": "@user:test", + "content": {"membership": "invite"}, + } + self.add_hashes_and_signatures(invite_pdu) + invite_event_id = make_event_from_dict(invite_pdu, RoomVersions.V9).event_id + + self.get_success( + self.hs.get_federation_server().on_invite_request( + self.OTHER_SERVER_NAME, + invite_pdu, + "9", + ) + ) + + # stub out do_remotely_reject_invite so that we fall back to a locally- + # generated rejection + with patch.object( + self.hs.get_federation_handler(), + "do_remotely_reject_invite", + side_effect=Exception(), + ): + reject_event_id, _ = self.get_success( + self.hs.get_room_member_handler().remote_reject_invite( + invite_event_id, + txn_id=None, + requester=create_requester("@user:test"), + content={}, + ) + ) + + invite_event, reject_event = self.get_success( + self.hs.get_datastores().main.get_events_as_list( + [invite_event_id, reject_event_id] + ) + ) + + # the invited user should be able to see both the invite and the rejection + self.assertEqual( + self.get_success( + filter_events_for_client( + self.hs.get_storage(), "@user:test", [invite_event, reject_event] + ) + ), + [invite_event, reject_event], + ) + + # other users should see neither + self.assertEqual( + self.get_success( + filter_events_for_client( + self.hs.get_storage(), "@other:test", [invite_event, reject_event] + ) + ), + [], + ) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 19741ffcda..48e616ac74 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -17,7 +17,7 @@ from typing import Set from unittest import mock from twisted.internet import defer, reactor -from twisted.internet.defer import Deferred +from twisted.internet.defer import CancelledError, Deferred from synapse.api.errors import SynapseError from synapse.logging.context import ( @@ -28,7 +28,7 @@ from synapse.logging.context import ( make_deferred_yieldable, ) from synapse.util.caches import descriptors -from synapse.util.caches.descriptors import cached, lru_cache +from synapse.util.caches.descriptors import cached, cachedList, lru_cache from tests import unittest from tests.test_utils import get_awaitable_result @@ -141,6 +141,84 @@ class DescriptorTestCase(unittest.TestCase): self.assertEqual(r, "chips") obj.mock.assert_not_called() + @defer.inlineCallbacks + def test_cache_uncached_args(self): + """ + Only the arguments not named in uncached_args should matter to the cache + + Note that this is identical to test_cache_num_args, but provides the + arguments differently. + """ + + class Cls: + # Note that it is important that this is not the last argument to + # test behaviour of skipping arguments properly. + @descriptors.cached(uncached_args=("arg2",)) + def fn(self, arg1, arg2, arg3): + return self.mock(arg1, arg2, arg3) + + def __init__(self): + self.mock = mock.Mock() + + obj = Cls() + obj.mock.return_value = "fish" + r = yield obj.fn(1, 2, 3) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, 2, 3) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = yield obj.fn(2, 3, 4) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(2, 3, 4) + obj.mock.reset_mock() + + # the two values should now be cached; we should be able to vary + # the second argument and still get the cached result. + r = yield obj.fn(1, 4, 3) + self.assertEqual(r, "fish") + r = yield obj.fn(2, 5, 4) + self.assertEqual(r, "chips") + obj.mock.assert_not_called() + + @defer.inlineCallbacks + def test_cache_kwargs(self): + """Test that keyword arguments are treated properly""" + + class Cls: + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, kwarg1=2): + return self.mock(arg1, kwarg1=kwarg1) + + obj = Cls() + obj.mock.return_value = "fish" + r = yield obj.fn(1, kwarg1=2) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, kwarg1=2) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = yield obj.fn(1, kwarg1=3) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(1, kwarg1=3) + obj.mock.reset_mock() + + # the values should now be cached. + r = yield obj.fn(1, kwarg1=2) + self.assertEqual(r, "fish") + # We should be able to not provide kwarg1 and get the cached value back. + r = yield obj.fn(1) + self.assertEqual(r, "fish") + # Keyword arguments can be in any order. + r = yield obj.fn(kwarg1=2, arg1=1) + self.assertEqual(r, "fish") + obj.mock.assert_not_called() + def test_cache_with_sync_exception(self): """If the wrapped function throws synchronously, things should continue to work""" @@ -415,6 +493,74 @@ class DescriptorTestCase(unittest.TestCase): obj.invalidate() top_invalidate.assert_called_once() + def test_cancel(self): + """Test that cancelling a lookup does not cancel other lookups""" + complete_lookup: "Deferred[None]" = Deferred() + + class Cls: + @cached() + async def fn(self, arg1): + await complete_lookup + return str(arg1) + + obj = Cls() + + d1 = obj.fn(123) + d2 = obj.fn(123) + self.assertFalse(d1.called) + self.assertFalse(d2.called) + + # Cancel `d1`, which is the lookup that caused `fn` to run. + d1.cancel() + + # `d2` should complete normally. + complete_lookup.callback(None) + self.failureResultOf(d1, CancelledError) + self.assertEqual(d2.result, "123") + + def test_cancel_logcontexts(self): + """Test that cancellation does not break logcontexts. + + * The `CancelledError` must be raised with the correct logcontext. + * The inner lookup must not resume with a finished logcontext. + * The inner lookup must not restore a finished logcontext when done. + """ + complete_lookup: "Deferred[None]" = Deferred() + + class Cls: + inner_context_was_finished = False + + @cached() + async def fn(self, arg1): + await make_deferred_yieldable(complete_lookup) + self.inner_context_was_finished = current_context().finished + return str(arg1) + + obj = Cls() + + async def do_lookup(): + with LoggingContext("c1") as c1: + try: + await obj.fn(123) + self.fail("No CancelledError thrown") + except CancelledError: + self.assertEqual( + current_context(), + c1, + "CancelledError was not raised with the correct logcontext", + ) + # suppress the error and succeed + + d = defer.ensureDeferred(do_lookup()) + d.cancel() + + complete_lookup.callback(None) + self.successResultOf(d) + self.assertFalse( + obj.inner_context_was_finished, "Tried to restart a finished logcontext" + ) + self.assertEqual(current_context(), SENTINEL_CONTEXT) + class CacheDecoratorTestCase(unittest.HomeserverTestCase): """More tests for @cached @@ -656,7 +802,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") async def list_fn(self, args1, arg2): assert current_context().name == "c1" # we want this to behave like an asynchronous function @@ -715,7 +861,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") def list_fn(self, args1) -> "Deferred[dict]": return self.mock(args1) @@ -758,7 +904,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") async def list_fn(self, args1, arg2): # we want this to behave like an asynchronous function await run_on_reactor() @@ -787,3 +933,78 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj.fn.invalidate((10, 2)) invalidate0.assert_called_once() invalidate1.assert_called_once() + + def test_cancel(self): + """Test that cancelling a lookup does not cancel other lookups""" + complete_lookup: "Deferred[None]" = Deferred() + + class Cls: + @cached() + def fn(self, arg1): + pass + + @cachedList(cached_method_name="fn", list_name="args") + async def list_fn(self, args): + await complete_lookup + return {arg: str(arg) for arg in args} + + obj = Cls() + + d1 = obj.list_fn([123, 456]) + d2 = obj.list_fn([123, 456, 789]) + self.assertFalse(d1.called) + self.assertFalse(d2.called) + + d1.cancel() + + # `d2` should complete normally. + complete_lookup.callback(None) + self.failureResultOf(d1, CancelledError) + self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"}) + + def test_cancel_logcontexts(self): + """Test that cancellation does not break logcontexts. + + * The `CancelledError` must be raised with the correct logcontext. + * The inner lookup must not resume with a finished logcontext. + * The inner lookup must not restore a finished logcontext when done. + """ + complete_lookup: "Deferred[None]" = Deferred() + + class Cls: + inner_context_was_finished = False + + @cached() + def fn(self, arg1): + pass + + @cachedList(cached_method_name="fn", list_name="args") + async def list_fn(self, args): + await make_deferred_yieldable(complete_lookup) + self.inner_context_was_finished = current_context().finished + return {arg: str(arg) for arg in args} + + obj = Cls() + + async def do_lookup(): + with LoggingContext("c1") as c1: + try: + await obj.list_fn([123]) + self.fail("No CancelledError thrown") + except CancelledError: + self.assertEqual( + current_context(), + c1, + "CancelledError was not raised with the correct logcontext", + ) + # suppress the error and succeed + + d = defer.ensureDeferred(do_lookup()) + d.cancel() + + complete_lookup.callback(None) + self.successResultOf(d) + self.assertFalse( + obj.inner_context_was_finished, "Tried to restart a finished logcontext" + ) + self.assertEqual(current_context(), SENTINEL_CONTEXT) diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index 362014f4cb..e5bc416de1 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -13,6 +13,8 @@ # limitations under the License. import traceback +from parameterized import parameterized_class + from twisted.internet import defer from twisted.internet.defer import CancelledError, Deferred, ensureDeferred from twisted.internet.task import Clock @@ -23,10 +25,12 @@ from synapse.logging.context import ( LoggingContext, PreserveLoggingContext, current_context, + make_deferred_yieldable, ) from synapse.util.async_helpers import ( ObservableDeferred, concurrently_execute, + delay_cancellation, stop_cancellation, timeout_deferred, ) @@ -100,6 +104,34 @@ class ObservableDeferredTest(TestCase): self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result") self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result") + def test_cancellation(self): + """Test that cancelling an observer does not affect other observers.""" + origin_d: "Deferred[int]" = Deferred() + observable = ObservableDeferred(origin_d, consumeErrors=True) + + observer1 = observable.observe() + observer2 = observable.observe() + observer3 = observable.observe() + + self.assertFalse(observer1.called) + self.assertFalse(observer2.called) + self.assertFalse(observer3.called) + + # cancel the second observer + observer2.cancel() + self.assertFalse(observer1.called) + self.failureResultOf(observer2, CancelledError) + self.assertFalse(observer3.called) + + # other observers resolve as normal + origin_d.callback(123) + self.assertEqual(observer1.result, 123, "observer 1 callback result") + self.assertEqual(observer3.result, 123, "observer 3 callback result") + + # additional observers resolve as normal + observer4 = observable.observe() + self.assertEqual(observer4.result, 123, "observer 4 callback result") + class TimeoutDeferredTest(TestCase): def setUp(self): @@ -285,13 +317,27 @@ class ConcurrentlyExecuteTest(TestCase): self.successResultOf(d2) -class StopCancellationTests(TestCase): - """Tests for the `stop_cancellation` function.""" +@parameterized_class( + ("wrapper",), + [("stop_cancellation",), ("delay_cancellation",)], +) +class CancellationWrapperTests(TestCase): + """Common tests for the `stop_cancellation` and `delay_cancellation` functions.""" + + wrapper: str + + def wrap_deferred(self, deferred: "Deferred[str]") -> "Deferred[str]": + if self.wrapper == "stop_cancellation": + return stop_cancellation(deferred) + elif self.wrapper == "delay_cancellation": + return delay_cancellation(deferred) + else: + raise ValueError(f"Unsupported wrapper type: {self.wrapper}") def test_succeed(self): """Test that the new `Deferred` receives the result.""" deferred: "Deferred[str]" = Deferred() - wrapper_deferred = stop_cancellation(deferred) + wrapper_deferred = self.wrap_deferred(deferred) # Success should propagate through. deferred.callback("success") @@ -301,7 +347,7 @@ class StopCancellationTests(TestCase): def test_failure(self): """Test that the new `Deferred` receives the `Failure`.""" deferred: "Deferred[str]" = Deferred() - wrapper_deferred = stop_cancellation(deferred) + wrapper_deferred = self.wrap_deferred(deferred) # Failure should propagate through. deferred.errback(ValueError("abc")) @@ -309,6 +355,10 @@ class StopCancellationTests(TestCase): self.failureResultOf(wrapper_deferred, ValueError) self.assertIsNone(deferred.result, "`Failure` was not consumed") + +class StopCancellationTests(TestCase): + """Tests for the `stop_cancellation` function.""" + def test_cancellation(self): """Test that cancellation of the new `Deferred` leaves the original running.""" deferred: "Deferred[str]" = Deferred() @@ -319,11 +369,101 @@ class StopCancellationTests(TestCase): self.assertTrue(wrapper_deferred.called) self.failureResultOf(wrapper_deferred, CancelledError) self.assertFalse( - deferred.called, "Original `Deferred` was unexpectedly cancelled." + deferred.called, "Original `Deferred` was unexpectedly cancelled" + ) + + # Now make the original `Deferred` fail. + # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed + # in logs. + deferred.errback(ValueError("abc")) + self.assertIsNone(deferred.result, "`Failure` was not consumed") + + +class DelayCancellationTests(TestCase): + """Tests for the `delay_cancellation` function.""" + + def test_cancellation(self): + """Test that cancellation of the new `Deferred` waits for the original.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = delay_cancellation(deferred) + + # Cancel the new `Deferred`. + wrapper_deferred.cancel() + self.assertNoResult(wrapper_deferred) + self.assertFalse( + deferred.called, "Original `Deferred` was unexpectedly cancelled" + ) + + # Now make the original `Deferred` fail. + # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed + # in logs. + deferred.errback(ValueError("abc")) + self.assertIsNone(deferred.result, "`Failure` was not consumed") + + # Now that the original `Deferred` has failed, we should get a `CancelledError`. + self.failureResultOf(wrapper_deferred, CancelledError) + + def test_suppresses_second_cancellation(self): + """Test that a second cancellation is suppressed. + + Identical to `test_cancellation` except the new `Deferred` is cancelled twice. + """ + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = delay_cancellation(deferred) + + # Cancel the new `Deferred`, twice. + wrapper_deferred.cancel() + wrapper_deferred.cancel() + self.assertNoResult(wrapper_deferred) + self.assertFalse( + deferred.called, "Original `Deferred` was unexpectedly cancelled" ) - # Now make the inner `Deferred` fail. + # Now make the original `Deferred` fail. # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed # in logs. deferred.errback(ValueError("abc")) self.assertIsNone(deferred.result, "`Failure` was not consumed") + + # Now that the original `Deferred` has failed, we should get a `CancelledError`. + self.failureResultOf(wrapper_deferred, CancelledError) + + def test_propagates_cancelled_error(self): + """Test that a `CancelledError` from the original `Deferred` gets propagated.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = delay_cancellation(deferred) + + # Fail the original `Deferred` with a `CancelledError`. + cancelled_error = CancelledError() + deferred.errback(cancelled_error) + + # The new `Deferred` should fail with exactly the same `CancelledError`. + self.assertTrue(wrapper_deferred.called) + self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value) + + def test_preserves_logcontext(self): + """Test that logging contexts are preserved.""" + blocking_d: "Deferred[None]" = Deferred() + + async def inner(): + await make_deferred_yieldable(blocking_d) + + async def outer(): + with LoggingContext("c") as c: + try: + await delay_cancellation(defer.ensureDeferred(inner())) + self.fail("`CancelledError` was not raised") + except CancelledError: + self.assertEqual(c, current_context()) + # Succeed with no error, unless the logging context is wrong. + + # Run and block inside `inner()`. + d = defer.ensureDeferred(outer()) + self.assertEqual(SENTINEL_CONTEXT, current_context()) + + d.cancel() + + # Now unblock. `outer()` will consume the `CancelledError` and check the + # logging context. + blocking_d.callback(None) + self.successResultOf(d) diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py index 3c07252252..5d1aa025d1 100644 --- a/tests/util/test_check_dependencies.py +++ b/tests/util/test_check_dependencies.py @@ -12,7 +12,7 @@ from tests.unittest import TestCase class DummyDistribution(metadata.Distribution): - def __init__(self, version: str): + def __init__(self, version: object): self._version = version @property @@ -27,7 +27,10 @@ class DummyDistribution(metadata.Distribution): old = DummyDistribution("0.1.2") +old_release_candidate = DummyDistribution("0.1.2rc3") new = DummyDistribution("1.2.3") +new_release_candidate = DummyDistribution("1.2.3rc4") +distribution_with_no_version = DummyDistribution(None) # could probably use stdlib TestCase --- no need for twisted here @@ -65,6 +68,35 @@ class TestDependencyChecker(TestCase): # should not raise check_requirements() + def test_version_reported_as_none(self) -> None: + """Complain if importlib.metadata.version() returns None. + + This shouldn't normally happen, but it was seen in the wild (#12223). + """ + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1"], + ): + with self.mock_installed_package(distribution_with_no_version): + self.assertRaises(DependencyException, check_requirements) + + def test_checks_ignore_dev_dependencies(self) -> None: + """Bot generic and per-extra checks should ignore dev dependencies.""" + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1; extra == 'mypy'"], + ), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}): + # We're testing that none of these calls raise. + with self.mock_installed_package(None): + check_requirements() + check_requirements("cool-extra") + with self.mock_installed_package(old): + check_requirements() + check_requirements("cool-extra") + with self.mock_installed_package(new): + check_requirements() + check_requirements("cool-extra") + def test_generic_check_of_optional_dependency(self) -> None: """Complain if an optional package is old.""" with patch( @@ -85,11 +117,28 @@ class TestDependencyChecker(TestCase): with patch( "synapse.util.check_dependencies.metadata.requires", return_value=["dummypkg >= 1; extra == 'cool-extra'"], - ), patch("synapse.util.check_dependencies.EXTRAS", {"cool-extra"}): + ), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}): with self.mock_installed_package(None): self.assertRaises(DependencyException, check_requirements, "cool-extra") with self.mock_installed_package(old): self.assertRaises(DependencyException, check_requirements, "cool-extra") with self.mock_installed_package(new): # should not raise + check_requirements("cool-extra") + + def test_release_candidates_satisfy_dependency(self) -> None: + """ + Tests that release candidates count as far as satisfying a dependency + is concerned. + (Regression test, see #12176.) + """ + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1"], + ): + with self.mock_installed_package(old_release_candidate): + self.assertRaises(DependencyException, check_requirements) + + with self.mock_installed_package(new_release_candidate): + # should not raise check_requirements() diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index 0774625b85..0c84226197 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import AsyncContextManager, Callable, Sequence, Tuple + from twisted.internet import defer -from twisted.internet.defer import Deferred +from twisted.internet.defer import CancelledError, Deferred from synapse.util.async_helpers import ReadWriteLock @@ -21,87 +23,187 @@ from tests import unittest class ReadWriteLockTestCase(unittest.TestCase): - def _assert_called_before_not_after(self, lst, first_false): - for i, d in enumerate(lst[:first_false]): - self.assertTrue(d.called, msg="%d was unexpectedly false" % i) + def _start_reader_or_writer( + self, + read_or_write: Callable[[str], AsyncContextManager], + key: str, + return_value: str, + ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]: + """Starts a reader or writer which acquires the lock, blocks, then completes. + + Args: + read_or_write: A function returning a context manager for a lock. + Either a bound `ReadWriteLock.read` or `ReadWriteLock.write`. + key: The key to read or write. + return_value: A string that the reader or writer will resolve with when + done. + + Returns: + A tuple of three `Deferred`s: + * A `Deferred` that resolves with `return_value` once the reader or writer + completes successfully. + * A `Deferred` that resolves once the reader or writer acquires the lock. + * A `Deferred` that blocks the reader or writer. Must be resolved by the + caller to allow the reader or writer to release the lock and complete. + """ + acquired_d: "Deferred[None]" = Deferred() + unblock_d: "Deferred[None]" = Deferred() + + async def reader_or_writer(): + async with read_or_write(key): + acquired_d.callback(None) + await unblock_d + return return_value + + d = defer.ensureDeferred(reader_or_writer()) + return d, acquired_d, unblock_d + + def _start_blocking_reader( + self, rwlock: ReadWriteLock, key: str, return_value: str + ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]: + """Starts a reader which acquires the lock, blocks, then releases the lock. + + See the docstring for `_start_reader_or_writer` for details about the arguments + and return values. + """ + return self._start_reader_or_writer(rwlock.read, key, return_value) + + def _start_blocking_writer( + self, rwlock: ReadWriteLock, key: str, return_value: str + ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]: + """Starts a writer which acquires the lock, blocks, then releases the lock. + + See the docstring for `_start_reader_or_writer` for details about the arguments + and return values. + """ + return self._start_reader_or_writer(rwlock.write, key, return_value) + + def _start_nonblocking_reader( + self, rwlock: ReadWriteLock, key: str, return_value: str + ) -> Tuple["Deferred[str]", "Deferred[None]"]: + """Starts a reader which acquires the lock, then releases it immediately. + + See the docstring for `_start_reader_or_writer` for details about the arguments. + + Returns: + A tuple of two `Deferred`s: + * A `Deferred` that resolves with `return_value` once the reader completes + successfully. + * A `Deferred` that resolves once the reader acquires the lock. + """ + d, acquired_d, unblock_d = self._start_reader_or_writer( + rwlock.read, key, return_value + ) + unblock_d.callback(None) + return d, acquired_d + + def _start_nonblocking_writer( + self, rwlock: ReadWriteLock, key: str, return_value: str + ) -> Tuple["Deferred[str]", "Deferred[None]"]: + """Starts a writer which acquires the lock, then releases it immediately. + + See the docstring for `_start_reader_or_writer` for details about the arguments. + + Returns: + A tuple of two `Deferred`s: + * A `Deferred` that resolves with `return_value` once the writer completes + successfully. + * A `Deferred` that resolves once the writer acquires the lock. + """ + d, acquired_d, unblock_d = self._start_reader_or_writer( + rwlock.write, key, return_value + ) + unblock_d.callback(None) + return d, acquired_d + + def _assert_first_n_resolved( + self, deferreds: Sequence["defer.Deferred[None]"], n: int + ) -> None: + """Assert that exactly the first n `Deferred`s in the given list are resolved. - for i, d in enumerate(lst[first_false:]): + Args: + deferreds: The list of `Deferred`s to be checked. + n: The number of `Deferred`s at the start of `deferreds` that should be + resolved. + """ + for i, d in enumerate(deferreds[:n]): + self.assertTrue(d.called, msg="deferred %d was unexpectedly unresolved" % i) + + for i, d in enumerate(deferreds[n:]): self.assertFalse( - d.called, msg="%d was unexpectedly true" % (i + first_false) + d.called, msg="deferred %d was unexpectedly resolved" % (i + n) ) def test_rwlock(self): rwlock = ReadWriteLock() - - key = object() + key = "key" ds = [ - rwlock.read(key), # 0 - rwlock.read(key), # 1 - rwlock.write(key), # 2 - rwlock.write(key), # 3 - rwlock.read(key), # 4 - rwlock.read(key), # 5 - rwlock.write(key), # 6 + self._start_blocking_reader(rwlock, key, "0"), + self._start_blocking_reader(rwlock, key, "1"), + self._start_blocking_writer(rwlock, key, "2"), + self._start_blocking_writer(rwlock, key, "3"), + self._start_blocking_reader(rwlock, key, "4"), + self._start_blocking_reader(rwlock, key, "5"), + self._start_blocking_writer(rwlock, key, "6"), ] - ds = [defer.ensureDeferred(d) for d in ds] + # `Deferred`s that resolve when each reader or writer acquires the lock. + acquired_ds = [acquired_d for _, acquired_d, _ in ds] + # `Deferred`s that will trigger the release of locks when resolved. + release_ds = [release_d for _, _, release_d in ds] - self._assert_called_before_not_after(ds, 2) + # The first two readers should acquire their locks. + self._assert_first_n_resolved(acquired_ds, 2) - with ds[0].result: - self._assert_called_before_not_after(ds, 2) - self._assert_called_before_not_after(ds, 2) + # Release one of the read locks. The next writer should not acquire the lock, + # because there is another reader holding the lock. + self._assert_first_n_resolved(acquired_ds, 2) + release_ds[0].callback(None) + self._assert_first_n_resolved(acquired_ds, 2) - with ds[1].result: - self._assert_called_before_not_after(ds, 2) - self._assert_called_before_not_after(ds, 3) + # Release the other read lock. The next writer should acquire the lock. + self._assert_first_n_resolved(acquired_ds, 2) + release_ds[1].callback(None) + self._assert_first_n_resolved(acquired_ds, 3) - with ds[2].result: - self._assert_called_before_not_after(ds, 3) - self._assert_called_before_not_after(ds, 4) + # Release the write lock. The next writer should acquire the lock. + self._assert_first_n_resolved(acquired_ds, 3) + release_ds[2].callback(None) + self._assert_first_n_resolved(acquired_ds, 4) - with ds[3].result: - self._assert_called_before_not_after(ds, 4) - self._assert_called_before_not_after(ds, 6) + # Release the write lock. The next two readers should acquire locks. + self._assert_first_n_resolved(acquired_ds, 4) + release_ds[3].callback(None) + self._assert_first_n_resolved(acquired_ds, 6) - with ds[5].result: - self._assert_called_before_not_after(ds, 6) - self._assert_called_before_not_after(ds, 6) + # Release one of the read locks. The next writer should not acquire the lock, + # because there is another reader holding the lock. + self._assert_first_n_resolved(acquired_ds, 6) + release_ds[5].callback(None) + self._assert_first_n_resolved(acquired_ds, 6) - with ds[4].result: - self._assert_called_before_not_after(ds, 6) - self._assert_called_before_not_after(ds, 7) + # Release the other read lock. The next writer should acquire the lock. + self._assert_first_n_resolved(acquired_ds, 6) + release_ds[4].callback(None) + self._assert_first_n_resolved(acquired_ds, 7) - with ds[6].result: - pass + # Release the write lock. + release_ds[6].callback(None) - d = defer.ensureDeferred(rwlock.write(key)) - self.assertTrue(d.called) - with d.result: - pass + # Acquire and release the write and read locks one last time for good measure. + _, acquired_d = self._start_nonblocking_writer(rwlock, key, "last writer") + self.assertTrue(acquired_d.called) - d = defer.ensureDeferred(rwlock.read(key)) - self.assertTrue(d.called) - with d.result: - pass + _, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader") + self.assertTrue(acquired_d.called) def test_lock_handoff_to_nonblocking_writer(self): """Test a writer handing the lock to another writer that completes instantly.""" rwlock = ReadWriteLock() key = "key" - unblock: "Deferred[None]" = Deferred() - - async def blocking_write(): - with await rwlock.write(key): - await unblock - - async def nonblocking_write(): - with await rwlock.write(key): - pass - - d1 = defer.ensureDeferred(blocking_write()) - d2 = defer.ensureDeferred(nonblocking_write()) + d1, _, unblock = self._start_blocking_writer(rwlock, key, "write 1 completed") + d2, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed") self.assertFalse(d1.called) self.assertFalse(d2.called) @@ -111,5 +213,182 @@ class ReadWriteLockTestCase(unittest.TestCase): self.assertTrue(d2.called) # The `ReadWriteLock` should operate as normal. - d3 = defer.ensureDeferred(nonblocking_write()) + d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed") self.assertTrue(d3.called) + + def test_cancellation_while_holding_read_lock(self): + """Test cancellation while holding a read lock. + + A waiting writer should be given the lock when the reader holding the lock is + cancelled. + """ + rwlock = ReadWriteLock() + key = "key" + + # 1. A reader takes the lock and blocks. + reader_d, _, _ = self._start_blocking_reader(rwlock, key, "read completed") + + # 2. A writer waits for the reader to complete. + writer_d, _ = self._start_nonblocking_writer(rwlock, key, "write completed") + self.assertFalse(writer_d.called) + + # 3. The reader is cancelled. + reader_d.cancel() + self.failureResultOf(reader_d, CancelledError) + + # 4. The writer should take the lock and complete. + self.assertTrue( + writer_d.called, "Writer is stuck waiting for a cancelled reader" + ) + self.assertEqual("write completed", self.successResultOf(writer_d)) + + def test_cancellation_while_holding_write_lock(self): + """Test cancellation while holding a write lock. + + A waiting reader should be given the lock when the writer holding the lock is + cancelled. + """ + rwlock = ReadWriteLock() + key = "key" + + # 1. A writer takes the lock and blocks. + writer_d, _, _ = self._start_blocking_writer(rwlock, key, "write completed") + + # 2. A reader waits for the writer to complete. + reader_d, _ = self._start_nonblocking_reader(rwlock, key, "read completed") + self.assertFalse(reader_d.called) + + # 3. The writer is cancelled. + writer_d.cancel() + self.failureResultOf(writer_d, CancelledError) + + # 4. The reader should take the lock and complete. + self.assertTrue( + reader_d.called, "Reader is stuck waiting for a cancelled writer" + ) + self.assertEqual("read completed", self.successResultOf(reader_d)) + + def test_cancellation_while_waiting_for_read_lock(self): + """Test cancellation while waiting for a read lock. + + Tests that cancelling a waiting reader: + * does not cancel the writer it is waiting on + * does not cancel the next writer waiting on it + * does not allow the next writer to acquire the lock before an earlier writer + has finished + * does not keep the next writer waiting indefinitely + + These correspond to the asserts with explicit messages. + """ + rwlock = ReadWriteLock() + key = "key" + + # 1. A writer takes the lock and blocks. + writer1_d, _, unblock_writer1 = self._start_blocking_writer( + rwlock, key, "write 1 completed" + ) + + # 2. A reader waits for the first writer to complete. + # This reader will be cancelled later. + reader_d, _ = self._start_nonblocking_reader(rwlock, key, "read completed") + self.assertFalse(reader_d.called) + + # 3. A second writer waits for both the first writer and the reader to complete. + writer2_d, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed") + self.assertFalse(writer2_d.called) + + # 4. The waiting reader is cancelled. + # Neither of the writers should be cancelled. + # The second writer should still be waiting, but only on the first writer. + reader_d.cancel() + self.failureResultOf(reader_d, CancelledError) + self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled") + self.assertFalse( + writer2_d.called, + "Second writer was unexpectedly cancelled or given the lock before the " + "first writer finished", + ) + + # 5. Unblock the first writer, which should complete. + unblock_writer1.callback(None) + self.assertEqual("write 1 completed", self.successResultOf(writer1_d)) + + # 6. The second writer should take the lock and complete. + self.assertTrue( + writer2_d.called, "Second writer is stuck waiting for a cancelled reader" + ) + self.assertEqual("write 2 completed", self.successResultOf(writer2_d)) + + def test_cancellation_while_waiting_for_write_lock(self): + """Test cancellation while waiting for a write lock. + + Tests that cancelling a waiting writer: + * does not cancel the reader or writer it is waiting on + * does not cancel the next writer waiting on it + * does not allow the next writer to acquire the lock before an earlier reader + and writer have finished + * does not keep the next writer waiting indefinitely + + These correspond to the asserts with explicit messages. + """ + rwlock = ReadWriteLock() + key = "key" + + # 1. A reader takes the lock and blocks. + reader_d, _, unblock_reader = self._start_blocking_reader( + rwlock, key, "read completed" + ) + + # 2. A writer waits for the reader to complete. + writer1_d, _, unblock_writer1 = self._start_blocking_writer( + rwlock, key, "write 1 completed" + ) + + # 3. A second writer waits for both the reader and first writer to complete. + # This writer will be cancelled later. + writer2_d, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed") + self.assertFalse(writer2_d.called) + + # 4. A third writer waits for the second writer to complete. + writer3_d, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed") + self.assertFalse(writer3_d.called) + + # 5. The second writer is cancelled, but continues waiting for the lock. + # The reader, first writer and third writer should not be cancelled. + # The first writer should still be waiting on the reader. + # The third writer should still be waiting on the second writer. + writer2_d.cancel() + self.assertNoResult(writer2_d) + self.assertFalse(reader_d.called, "Reader was unexpectedly cancelled") + self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled") + self.assertFalse( + writer3_d.called, + "Third writer was unexpectedly cancelled or given the lock before the first " + "writer finished", + ) + + # 6. Unblock the reader, which should complete. + # The first writer should be given the lock and block. + # The third writer should still be waiting on the second writer. + unblock_reader.callback(None) + self.assertEqual("read completed", self.successResultOf(reader_d)) + self.assertNoResult(writer2_d) + self.assertFalse( + writer3_d.called, + "Third writer was unexpectedly given the lock before the first writer " + "finished", + ) + + # 7. Unblock the first writer, which should complete. + unblock_writer1.callback(None) + self.assertEqual("write 1 completed", self.successResultOf(writer1_d)) + + # 8. The second writer should take the lock and release it immediately, since it + # has been cancelled. + self.failureResultOf(writer2_d, CancelledError) + + # 9. The third writer should take the lock and complete. + self.assertTrue( + writer3_d.called, "Third writer is stuck waiting for a cancelled writer" + ) + self.assertEqual("write 3 completed", self.successResultOf(writer3_d)) diff --git a/tests/utils.py b/tests/utils.py index ef99c72e0b..f6b1d60371 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,13 +15,8 @@ import atexit import os -from unittest.mock import Mock, patch -from urllib import parse as urlparse - -from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.api.errors import CodeMessageException, cs_error from synapse.api.room_versions import RoomVersions from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION @@ -187,111 +182,6 @@ def mock_getRawHeaders(headers=None): return getRawHeaders -# This is a mock /resource/ not an entire server -class MockHttpResource: - def __init__(self, prefix=""): - self.callbacks = [] # 3-tuple of method/pattern/function - self.prefix = prefix - - def trigger_get(self, path): - return self.trigger(b"GET", path, None) - - @patch("twisted.web.http.Request") - @defer.inlineCallbacks - def trigger( - self, http_method, path, content, mock_request, federation_auth_origin=None - ): - """Fire an HTTP event. - - Args: - http_method : The HTTP method - path : The HTTP path - content : The HTTP body - mock_request : Mocked request to pass to the event so it can get - content. - federation_auth_origin (bytes|None): domain to authenticate as, for federation - Returns: - A tuple of (code, response) - Raises: - KeyError If no event is found which will handle the path. - """ - path = self.prefix + path - - # annoyingly we return a twisted http request which has chained calls - # to get at the http content, hence mock it here. - mock_content = Mock() - config = {"read.return_value": content} - mock_content.configure_mock(**config) - mock_request.content = mock_content - - mock_request.method = http_method.encode("ascii") - mock_request.uri = path.encode("ascii") - - mock_request.getClientIP.return_value = "-" - - headers = {} - if federation_auth_origin is not None: - headers[b"Authorization"] = [ - b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,) - ] - mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) - - # return the right path if the event requires it - mock_request.path = path - - # add in query params to the right place - try: - mock_request.args = urlparse.parse_qs(path.split("?")[1]) - mock_request.path = path.split("?")[0] - path = mock_request.path - except Exception: - pass - - if isinstance(path, bytes): - path = path.decode("utf8") - - for (method, pattern, func) in self.callbacks: - if http_method != method: - continue - - matcher = pattern.match(path) - if matcher: - try: - args = [urlparse.unquote(u) for u in matcher.groups()] - - (code, response) = yield defer.ensureDeferred( - func(mock_request, *args) - ) - return code, response - except CodeMessageException as e: - return e.code, cs_error(e.msg, code=e.errcode) - - raise KeyError("No event can handle %s" % path) - - def register_paths(self, method, path_patterns, callback, servlet_name): - for path_pattern in path_patterns: - self.callbacks.append((method, path_pattern, callback)) - - -class MockKey: - alg = "mock_alg" - version = "mock_version" - signature = b"\x9a\x87$" - - @property - def verify_key(self): - return self - - def sign(self, message): - return self - - def verify(self, message, sig): - assert sig == b"\x9a\x87$" - - def encode(self): - return b"<fake_encoded_key>" - - class MockClock: now = 1000 |