From e691243e191d9dad2bcbf55f9659d007f75fd28e Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 24 Aug 2023 15:53:07 +0100 Subject: Fix typechecking with twisted trunk (#16121) --- synapse/util/caches/deferred_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/util') diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index bf7bd351e0..029eedcc6f 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -470,7 +470,7 @@ class CacheMultipleEntries(CacheEntry[KT, VT]): def deferred(self, key: KT) -> "defer.Deferred[VT]": if not self._deferred: self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) - return self._deferred.observe().addCallback(lambda res: res.get(key)) + return self._deferred.observe().addCallback(lambda res: res[key]) def add_invalidation_callback( self, key: KT, callback: Optional[Callable[[], None]] -- cgit 1.5.1 From 501da8ecd8f056fb953fbccb43fc60ba9edb91d5 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Mon, 28 Aug 2023 16:03:51 +0200 Subject: Task scheduler: add replication notify for new task to launch ASAP (#16184) --- changelog.d/16184.misc | 1 + synapse/replication/tcp/commands.py | 12 +++++ synapse/replication/tcp/handler.py | 18 ++++++++ synapse/util/task_scheduler.py | 92 +++++++++++++++++-------------------- tests/util/test_task_scheduler.py | 58 +++++++++++++++-------- 5 files changed, 114 insertions(+), 67 deletions(-) create mode 100644 changelog.d/16184.misc (limited to 'synapse/util') diff --git a/changelog.d/16184.misc b/changelog.d/16184.misc new file mode 100644 index 0000000000..3c0baddfe1 --- /dev/null +++ b/changelog.d/16184.misc @@ -0,0 +1 @@ +Task scheduler: add replication notify for new task to launch ASAP. diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 10f5c98ff8..58a871c6d9 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -452,6 +452,17 @@ class LockReleasedCommand(Command): return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key]) +class NewActiveTaskCommand(_SimpleCommand): + """Sent to inform instance handling background tasks that a new active task is available to run. + + Format:: + + NEW_ACTIVE_TASK "" + """ + + NAME = "NEW_ACTIVE_TASK" + + _COMMANDS: Tuple[Type[Command], ...] = ( ServerCommand, RdataCommand, @@ -466,6 +477,7 @@ _COMMANDS: Tuple[Type[Command], ...] = ( RemoteServerUpCommand, ClearUserSyncsCommand, LockReleasedCommand, + NewActiveTaskCommand, ) # Map of command name to command type. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 38adcbe1d0..92c5a55acc 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -40,6 +40,7 @@ from synapse.replication.tcp.commands import ( Command, FederationAckCommand, LockReleasedCommand, + NewActiveTaskCommand, PositionCommand, RdataCommand, RemoteServerUpCommand, @@ -238,6 +239,10 @@ class ReplicationCommandHandler: if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() + self._task_scheduler = None + if hs.config.worker.run_background_tasks: + self._task_scheduler = hs.get_task_scheduler() + if hs.config.redis.redis_enabled: # If we're using Redis, it's the background worker that should # receive USER_IP commands and store the relevant client IPs. @@ -663,6 +668,15 @@ class ReplicationCommandHandler: cmd.instance_name, cmd.lock_name, cmd.lock_key ) + async def on_NEW_ACTIVE_TASK( + self, conn: IReplicationConnection, cmd: NewActiveTaskCommand + ) -> None: + """Called when get a new NEW_ACTIVE_TASK command.""" + if self._task_scheduler: + task = await self._task_scheduler.get_task(cmd.data) + if task: + await self._task_scheduler._launch_task(task) + def new_connection(self, connection: IReplicationConnection) -> None: """Called when we have a new connection.""" self._connections.append(connection) @@ -776,6 +790,10 @@ class ReplicationCommandHandler: if instance_name == self._instance_name: self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key)) + def send_new_active_task(self, task_id: str) -> None: + """Called when a new task has been scheduled for immediate launch and is ACTIVE.""" + self.send_command(NewActiveTaskCommand(task_id)) + UpdateToken = TypeVar("UpdateToken") UpdateRow = TypeVar("UpdateRow") diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py index 4aea64b338..9e89aeb748 100644 --- a/synapse/util/task_scheduler.py +++ b/synapse/util/task_scheduler.py @@ -57,14 +57,13 @@ class TaskScheduler: the code launching the task. You can also specify the `result` (and/or an `error`) when returning from the function. - The reconciliation loop runs every 5 mns, so this is not a precise scheduler. When wanting - to launch now, the launch will still not happen before the next loop run. - - Tasks will be run on the worker specified with `run_background_tasks_on` config, - or the main one by default. + The reconciliation loop runs every minute, so this is not a precise scheduler. There is a limit of 10 concurrent tasks, so tasks may be delayed if the pool is already full. In this regard, please take great care that scheduled tasks can actually finished. For now there is no mechanism to stop a running task if it is stuck. + + Tasks will be run on the worker specified with `run_background_tasks_on` config, + or the main one by default. """ # Precision of the scheduler, evaluation of tasks to run will only happen @@ -85,7 +84,7 @@ class TaskScheduler: self._actions: Dict[ str, Callable[ - [ScheduledTask, bool], + [ScheduledTask], Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]], ], ] = {} @@ -98,11 +97,13 @@ class TaskScheduler: "handle_scheduled_tasks", self._handle_scheduled_tasks, ) + else: + self.replication_client = hs.get_replication_command_handler() def register_action( self, function: Callable[ - [ScheduledTask, bool], + [ScheduledTask], Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]], ], action_name: str, @@ -115,10 +116,9 @@ class TaskScheduler: calling `schedule_task` but rather in an `__init__` method. Args: - function: The function to be executed for this action. The parameters - passed to the function when launched are the `ScheduledTask` being run, - and a `first_launch` boolean to signal if it's a resumed task or the first - launch of it. The function should return a tuple of new `status`, `result` + function: The function to be executed for this action. The parameter + passed to the function when launched is the `ScheduledTask` being run. + The function should return a tuple of new `status`, `result` and `error` as specified in `ScheduledTask`. action_name: The name of the action to be associated with the function """ @@ -171,6 +171,12 @@ class TaskScheduler: ) await self._store.insert_scheduled_task(task) + if status == TaskStatus.ACTIVE: + if self._run_background_tasks: + await self._launch_task(task) + else: + self.replication_client.send_new_active_task(task.id) + return task.id async def update_task( @@ -265,21 +271,13 @@ class TaskScheduler: Args: id: id of the task to delete """ - if self.task_is_running(id): - raise Exception(f"Task {id} is currently running and can't be deleted") + task = await self.get_task(id) + if task is None: + raise Exception(f"Task {id} does not exist") + if task.status == TaskStatus.ACTIVE: + raise Exception(f"Task {id} is currently ACTIVE and can't be deleted") await self._store.delete_scheduled_task(id) - def task_is_running(self, id: str) -> bool: - """Check if a task is currently running. - - Can only be called from the worker handling the task scheduling. - - Args: - id: id of the task to check - """ - assert self._run_background_tasks - return id in self._running_tasks - async def _handle_scheduled_tasks(self) -> None: """Main loop taking care of launching tasks and cleaning up old ones.""" await self._launch_scheduled_tasks() @@ -288,29 +286,11 @@ class TaskScheduler: async def _launch_scheduled_tasks(self) -> None: """Retrieve and launch scheduled tasks that should be running at that time.""" for task in await self.get_tasks(statuses=[TaskStatus.ACTIVE]): - if not self.task_is_running(task.id): - if ( - len(self._running_tasks) - < TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS - ): - await self._launch_task(task, first_launch=False) - else: - if ( - self._clock.time_msec() - > task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS - ): - logger.warn( - f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck" - ) + await self._launch_task(task) for task in await self.get_tasks( statuses=[TaskStatus.SCHEDULED], max_timestamp=self._clock.time_msec() ): - if ( - not self.task_is_running(task.id) - and len(self._running_tasks) - < TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS - ): - await self._launch_task(task, first_launch=True) + await self._launch_task(task) running_tasks_gauge.set(len(self._running_tasks)) @@ -320,27 +300,27 @@ class TaskScheduler: statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE] ): # FAILED and COMPLETE tasks should never be running - assert not self.task_is_running(task.id) + assert task.id not in self._running_tasks if ( self._clock.time_msec() > task.timestamp + TaskScheduler.KEEP_TASKS_FOR_MS ): await self._store.delete_scheduled_task(task.id) - async def _launch_task(self, task: ScheduledTask, first_launch: bool) -> None: + async def _launch_task(self, task: ScheduledTask) -> None: """Launch a scheduled task now. Args: task: the task to launch - first_launch: `True` if it's the first time is launched, `False` otherwise """ - assert task.action in self._actions + assert self._run_background_tasks + assert task.action in self._actions function = self._actions[task.action] async def wrapper() -> None: try: - (status, result, error) = await function(task, first_launch) + (status, result, error) = await function(task) except Exception: f = Failure() logger.error( @@ -360,6 +340,20 @@ class TaskScheduler: ) self._running_tasks.remove(task.id) + if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: + return + + if ( + self._clock.time_msec() + > task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS + ): + logger.warn( + f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck" + ) + + if task.id in self._running_tasks: + return + self._running_tasks.add(task.id) await self.update_task(task.id, status=TaskStatus.ACTIVE) description = f"{task.id}-{task.action}" diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py index 3a97559bf0..8665aeb50c 100644 --- a/tests/util/test_task_scheduler.py +++ b/tests/util/test_task_scheduler.py @@ -22,10 +22,11 @@ from synapse.types import JsonMapping, ScheduledTask, TaskStatus from synapse.util import Clock from synapse.util.task_scheduler import TaskScheduler -from tests import unittest +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.unittest import HomeserverTestCase, override_config -class TestTaskScheduler(unittest.HomeserverTestCase): +class TestTaskScheduler(HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.task_scheduler = hs.get_task_scheduler() self.task_scheduler.register_action(self._test_task, "_test_task") @@ -34,7 +35,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase): self.task_scheduler.register_action(self._resumable_task, "_resumable_task") async def _test_task( - self, task: ScheduledTask, first_launch: bool + self, task: ScheduledTask ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: # This test task will copy the parameters to the result result = None @@ -77,7 +78,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase): self.assertIsNone(task) async def _sleeping_task( - self, task: ScheduledTask, first_launch: bool + self, task: ScheduledTask ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: # Sleep for a second await deferLater(self.reactor, 1, lambda: None) @@ -85,24 +86,18 @@ class TestTaskScheduler(unittest.HomeserverTestCase): def test_schedule_lot_of_tasks(self) -> None: """Schedule more than `TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS` tasks and check the behavior.""" - timestamp = self.clock.time_msec() + 30 * 1000 task_ids = [] for i in range(TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS + 1): task_ids.append( self.get_success( self.task_scheduler.schedule_task( "_sleeping_task", - timestamp=timestamp, params={"val": i}, ) ) ) - # The timestamp being 30s after now the task should been executed - # after the first scheduling loop is run - self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) - - # This is to give the time to the sleeping tasks to finish + # This is to give the time to the active tasks to finish self.reactor.advance(1) # Check that only MAX_CONCURRENT_RUNNING_TASKS tasks has run and that one @@ -120,10 +115,11 @@ class TestTaskScheduler(unittest.HomeserverTestCase): ) scheduled_tasks = [ - t for t in tasks if t is not None and t.status == TaskStatus.SCHEDULED + t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE ] self.assertEquals(len(scheduled_tasks), 1) + # We need to wait for the next run of the scheduler loop self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) self.reactor.advance(1) @@ -138,7 +134,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase): ) async def _raising_task( - self, task: ScheduledTask, first_launch: bool + self, task: ScheduledTask ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: raise Exception("raising") @@ -146,15 +142,13 @@ class TestTaskScheduler(unittest.HomeserverTestCase): """Schedule a task raising an exception and check it runs to failure and report exception content.""" task_id = self.get_success(self.task_scheduler.schedule_task("_raising_task")) - self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) - task = self.get_success(self.task_scheduler.get_task(task_id)) assert task is not None self.assertEqual(task.status, TaskStatus.FAILED) self.assertEqual(task.error, "raising") async def _resumable_task( - self, task: ScheduledTask, first_launch: bool + self, task: ScheduledTask ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: if task.result and "in_progress" in task.result: return TaskStatus.COMPLETE, {"success": True}, None @@ -169,8 +163,6 @@ class TestTaskScheduler(unittest.HomeserverTestCase): """Schedule a resumable task and check that it gets properly resumed and complete after simulating a synapse restart.""" task_id = self.get_success(self.task_scheduler.schedule_task("_resumable_task")) - self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) - task = self.get_success(self.task_scheduler.get_task(task_id)) assert task is not None self.assertEqual(task.status, TaskStatus.ACTIVE) @@ -184,3 +176,33 @@ class TestTaskScheduler(unittest.HomeserverTestCase): self.assertEqual(task.status, TaskStatus.COMPLETE) assert task.result is not None self.assertTrue(task.result.get("success")) + + +class TestTaskSchedulerWithBackgroundWorker(BaseMultiWorkerStreamTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.task_scheduler = hs.get_task_scheduler() + self.task_scheduler.register_action(self._test_task, "_test_task") + + async def _test_task( + self, task: ScheduledTask + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + return (TaskStatus.COMPLETE, None, None) + + @override_config({"run_background_tasks_on": "worker1"}) + def test_schedule_task(self) -> None: + """Check that a task scheduled to run now is launch right away on the background worker.""" + bg_worker_hs = self.make_worker_hs( + "synapse.app.generic_worker", + extra_config={"worker_name": "worker1"}, + ) + bg_worker_hs.get_task_scheduler().register_action(self._test_task, "_test_task") + + task_id = self.get_success( + self.task_scheduler.schedule_task( + "_test_task", + ) + ) + + task = self.get_success(self.task_scheduler.get_task(task_id)) + assert task is not None + self.assertEqual(task.status, TaskStatus.COMPLETE) -- cgit 1.5.1 From 9ec3da06daf70b5e799545a6e12ead4846559d80 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 29 Aug 2023 10:38:56 -0400 Subject: Bump mypy-zope & mypy. (#16188) --- changelog.d/16188.misc | 1 + poetry.lock | 72 +++++++++++----------- synapse/_scripts/synapse_port_db.py | 9 ++- synapse/logging/opentracing.py | 14 ++--- synapse/storage/database.py | 17 ++++- synapse/util/check_dependencies.py | 6 +- tests/appservice/test_api.py | 6 +- tests/federation/test_complexity.py | 24 ++++---- tests/federation/test_federation_catch_up.py | 4 +- tests/federation/test_federation_sender.py | 4 +- tests/federation/transport/test_knocking.py | 4 +- tests/handlers/test_appservice.py | 10 +-- tests/handlers/test_cas.py | 8 +-- tests/handlers/test_e2e_keys.py | 4 +- tests/handlers/test_federation.py | 4 +- tests/handlers/test_oidc.py | 4 +- tests/handlers/test_password_providers.py | 2 +- tests/handlers/test_register.py | 6 +- tests/handlers/test_saml.py | 14 ++--- tests/handlers/test_typing.py | 26 ++++---- tests/logging/test_terse_json.py | 2 +- tests/module_api/test_api.py | 4 +- tests/push/test_bulk_push_rule_evaluator.py | 2 +- tests/replication/storage/test_events.py | 2 +- tests/rest/admin/test_user.py | 4 +- tests/rest/admin/test_username_available.py | 2 +- tests/rest/client/test_account.py | 2 +- tests/rest/client/test_events.py | 2 +- tests/rest/client/test_filter.py | 4 +- tests/rest/client/test_rooms.py | 12 ++-- tests/rest/client/test_shadow_banned.py | 2 +- tests/rest/client/test_third_party_rules.py | 2 +- tests/server.py | 2 +- .../test_resource_limits_server_notices.py | 30 ++++----- tests/storage/test_appservice.py | 2 +- tests/storage/test_monthly_active_users.py | 12 ++-- tests/test_federation.py | 6 +- tests/test_state.py | 4 +- tests/unittest.py | 6 +- 39 files changed, 180 insertions(+), 161 deletions(-) create mode 100644 changelog.d/16188.misc (limited to 'synapse/util') diff --git a/changelog.d/16188.misc b/changelog.d/16188.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/16188.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/poetry.lock b/poetry.lock index 1d37c88328..6d63d71b2c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1445,43 +1445,43 @@ files = [ [[package]] name = "mypy" -version = "1.0.1" +version = "1.4.1" description = "Optional static typing for Python" optional = false python-versions = ">=3.7" files = [ - {file = "mypy-1.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:71a808334d3f41ef011faa5a5cd8153606df5fc0b56de5b2e89566c8093a0c9a"}, - {file = "mypy-1.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:920169f0184215eef19294fa86ea49ffd4635dedfdea2b57e45cb4ee85d5ccaf"}, - {file = "mypy-1.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27a0f74a298769d9fdc8498fcb4f2beb86f0564bcdb1a37b58cbbe78e55cf8c0"}, - {file = "mypy-1.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:65b122a993d9c81ea0bfde7689b3365318a88bde952e4dfa1b3a8b4ac05d168b"}, - {file = "mypy-1.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:5deb252fd42a77add936b463033a59b8e48eb2eaec2976d76b6878d031933fe4"}, - {file = "mypy-1.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2013226d17f20468f34feddd6aae4635a55f79626549099354ce641bc7d40262"}, - {file = "mypy-1.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:48525aec92b47baed9b3380371ab8ab6e63a5aab317347dfe9e55e02aaad22e8"}, - {file = "mypy-1.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c96b8a0c019fe29040d520d9257d8c8f122a7343a8307bf8d6d4a43f5c5bfcc8"}, - {file = "mypy-1.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:448de661536d270ce04f2d7dddaa49b2fdba6e3bd8a83212164d4174ff43aa65"}, - {file = "mypy-1.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:d42a98e76070a365a1d1c220fcac8aa4ada12ae0db679cb4d910fabefc88b994"}, - {file = "mypy-1.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e64f48c6176e243ad015e995de05af7f22bbe370dbb5b32bd6988438ec873919"}, - {file = "mypy-1.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fdd63e4f50e3538617887e9aee91855368d9fc1dea30da743837b0df7373bc4"}, - {file = "mypy-1.0.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:dbeb24514c4acbc78d205f85dd0e800f34062efcc1f4a4857c57e4b4b8712bff"}, - {file = "mypy-1.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a2948c40a7dd46c1c33765718936669dc1f628f134013b02ff5ac6c7ef6942bf"}, - {file = "mypy-1.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5bc8d6bd3b274dd3846597855d96d38d947aedba18776aa998a8d46fabdaed76"}, - {file = "mypy-1.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:17455cda53eeee0a4adb6371a21dd3dbf465897de82843751cf822605d152c8c"}, - {file = "mypy-1.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e831662208055b006eef68392a768ff83596035ffd6d846786578ba1714ba8f6"}, - {file = "mypy-1.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e60d0b09f62ae97a94605c3f73fd952395286cf3e3b9e7b97f60b01ddfbbda88"}, - {file = "mypy-1.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:0af4f0e20706aadf4e6f8f8dc5ab739089146b83fd53cb4a7e0e850ef3de0bb6"}, - {file = "mypy-1.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:24189f23dc66f83b839bd1cce2dfc356020dfc9a8bae03978477b15be61b062e"}, - {file = "mypy-1.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93a85495fb13dc484251b4c1fd7a5ac370cd0d812bbfc3b39c1bafefe95275d5"}, - {file = "mypy-1.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f546ac34093c6ce33f6278f7c88f0f147a4849386d3bf3ae193702f4fe31407"}, - {file = "mypy-1.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c6c2ccb7af7154673c591189c3687b013122c5a891bb5651eca3db8e6c6c55bd"}, - {file = "mypy-1.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:15b5a824b58c7c822c51bc66308e759243c32631896743f030daf449fe3677f3"}, - {file = "mypy-1.0.1-py3-none-any.whl", hash = "sha256:eda5c8b9949ed411ff752b9a01adda31afe7eae1e53e946dbdf9db23865e66c4"}, - {file = "mypy-1.0.1.tar.gz", hash = "sha256:28cea5a6392bb43d266782983b5a4216c25544cd7d80be681a155ddcdafd152d"}, -] - -[package.dependencies] -mypy-extensions = ">=0.4.3" + {file = "mypy-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:566e72b0cd6598503e48ea610e0052d1b8168e60a46e0bfd34b3acf2d57f96a8"}, + {file = "mypy-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ca637024ca67ab24a7fd6f65d280572c3794665eaf5edcc7e90a866544076878"}, + {file = "mypy-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0dde1d180cd84f0624c5dcaaa89c89775550a675aff96b5848de78fb11adabcd"}, + {file = "mypy-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8c4d8e89aa7de683e2056a581ce63c46a0c41e31bd2b6d34144e2c80f5ea53dc"}, + {file = "mypy-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:bfdca17c36ae01a21274a3c387a63aa1aafe72bff976522886869ef131b937f1"}, + {file = "mypy-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7549fbf655e5825d787bbc9ecf6028731973f78088fbca3a1f4145c39ef09462"}, + {file = "mypy-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:98324ec3ecf12296e6422939e54763faedbfcc502ea4a4c38502082711867258"}, + {file = "mypy-1.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:141dedfdbfe8a04142881ff30ce6e6653c9685b354876b12e4fe6c78598b45e2"}, + {file = "mypy-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8207b7105829eca6f3d774f64a904190bb2231de91b8b186d21ffd98005f14a7"}, + {file = "mypy-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:16f0db5b641ba159eff72cff08edc3875f2b62b2fa2bc24f68c1e7a4e8232d01"}, + {file = "mypy-1.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:470c969bb3f9a9efcedbadcd19a74ffb34a25f8e6b0e02dae7c0e71f8372f97b"}, + {file = "mypy-1.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5952d2d18b79f7dc25e62e014fe5a23eb1a3d2bc66318df8988a01b1a037c5b"}, + {file = "mypy-1.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:190b6bab0302cec4e9e6767d3eb66085aef2a1cc98fe04936d8a42ed2ba77bb7"}, + {file = "mypy-1.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9d40652cc4fe33871ad3338581dca3297ff5f2213d0df345bcfbde5162abf0c9"}, + {file = "mypy-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:01fd2e9f85622d981fd9063bfaef1aed6e336eaacca00892cd2d82801ab7c042"}, + {file = "mypy-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2460a58faeea905aeb1b9b36f5065f2dc9a9c6e4c992a6499a2360c6c74ceca3"}, + {file = "mypy-1.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2746d69a8196698146a3dbe29104f9eb6a2a4d8a27878d92169a6c0b74435b6"}, + {file = "mypy-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ae704dcfaa180ff7c4cfbad23e74321a2b774f92ca77fd94ce1049175a21c97f"}, + {file = "mypy-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:43d24f6437925ce50139a310a64b2ab048cb2d3694c84c71c3f2a1626d8101dc"}, + {file = "mypy-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c482e1246726616088532b5e964e39765b6d1520791348e6c9dc3af25b233828"}, + {file = "mypy-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:43b592511672017f5b1a483527fd2684347fdffc041c9ef53428c8dc530f79a3"}, + {file = "mypy-1.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:34a9239d5b3502c17f07fd7c0b2ae6b7dd7d7f6af35fbb5072c6208e76295816"}, + {file = "mypy-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5703097c4936bbb9e9bce41478c8d08edd2865e177dc4c52be759f81ee4dd26c"}, + {file = "mypy-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e02d700ec8d9b1859790c0475df4e4092c7bf3272a4fd2c9f33d87fac4427b8f"}, + {file = "mypy-1.4.1-py3-none-any.whl", hash = "sha256:45d32cec14e7b97af848bddd97d85ea4f0db4d5a149ed9676caa4eb2f7402bb4"}, + {file = "mypy-1.4.1.tar.gz", hash = "sha256:9bbcd9ab8ea1f2e1c8031c21445b511442cc45c89951e49bbf852cbb70755b1b"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = ">=3.10" +typing-extensions = ">=4.1.0" [package.extras] dmypy = ["psutil (>=4.0)"] @@ -1502,17 +1502,17 @@ files = [ [[package]] name = "mypy-zope" -version = "0.9.1" +version = "1.0.0" description = "Plugin for mypy to support zope interfaces" optional = false python-versions = "*" files = [ - {file = "mypy-zope-0.9.1.tar.gz", hash = "sha256:4c87dbc71fec35f6533746ecdf9d400cd9281338d71c16b5676bb5ed00a97ca2"}, - {file = "mypy_zope-0.9.1-py3-none-any.whl", hash = "sha256:733d4399affe9e61e332ce9c4049418d6775c39b473e4b9f409d51c207c1b71a"}, + {file = "mypy-zope-1.0.0.tar.gz", hash = "sha256:be815c2fcb5333aa87e8ec682029ad3214142fe2a05ea383f9ff2d77c98008b7"}, + {file = "mypy_zope-1.0.0-py3-none-any.whl", hash = "sha256:9732e9b2198f2aec3343b38a51905ff49d44dc9e39e8e8bc6fc490b232388209"}, ] [package.dependencies] -mypy = ">=1.0.0,<1.1.0" +mypy = ">=1.0.0,<1.5.0" "zope.interface" = "*" "zope.schema" = "*" diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 49242800b8..ab2b29cf1b 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -482,7 +482,10 @@ class Porter: do_backward[0] = False if forward_rows or backward_rows: - headers = [column[0] for column in txn.description] + assert txn.description is not None + headers: Optional[List[str]] = [ + column[0] for column in txn.description + ] else: headers = None @@ -544,6 +547,7 @@ class Porter: def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]: txn.execute(select, (forward_chunk, self.batch_size)) rows = txn.fetchall() + assert txn.description is not None headers = [column[0] for column in txn.description] return headers, rows @@ -919,7 +923,8 @@ class Porter: def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]: txn.execute(select) rows = txn.fetchall() - headers: List[str] = [column[0] for column in txn.description] + assert txn.description is not None + headers = [column[0] for column in txn.description] ts_ind = headers.index("ts") diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index be910128aa..5c3045e197 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -910,10 +910,10 @@ def _custom_sync_async_decorator( async def _wrapper( *args: P.args, **kwargs: P.kwargs ) -> Any: # Return type is RInner - with wrapping_logic(func, *args, **kwargs): - # type-ignore: func() returns R, but mypy doesn't know that R is - # Awaitable here. - return await func(*args, **kwargs) # type: ignore[misc] + # type-ignore: func() returns R, but mypy doesn't know that R is + # Awaitable here. + with wrapping_logic(func, *args, **kwargs): # type: ignore[arg-type] + return await func(*args, **kwargs) else: # The other case here handles sync functions including those decorated with @@ -980,8 +980,7 @@ def trace_with_opname( See the module's doc string for usage examples. """ - # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909 - @contextlib.contextmanager # type: ignore[arg-type] + @contextlib.contextmanager def _wrapping_logic( func: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> Generator[None, None, None]: @@ -1024,8 +1023,7 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: if not opentracing: return func - # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909 - @contextlib.contextmanager # type: ignore[arg-type] + @contextlib.contextmanager def _wrapping_logic( func: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> Generator[None, None, None]: diff --git a/synapse/storage/database.py b/synapse/storage/database.py index a1c8fb0f46..55ac313f33 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -31,6 +31,7 @@ from typing import ( Iterator, List, Optional, + Sequence, Tuple, Type, TypeVar, @@ -358,7 +359,21 @@ class LoggingTransaction: return self.txn.rowcount @property - def description(self) -> Any: + def description( + self, + ) -> Optional[ + Sequence[ + Tuple[ + str, + Optional[Any], + Optional[int], + Optional[int], + Optional[int], + Optional[int], + Optional[int], + ] + ] + ]: return self.txn.description def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None: diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py index 114130a08f..f7cead9e12 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py @@ -51,9 +51,9 @@ class DependencyException(Exception): DEV_EXTRAS = {"lint", "mypy", "test", "dev"} -RUNTIME_EXTRAS = ( - set(metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra")) - DEV_EXTRAS -) +ALL_EXTRAS = metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra") +assert ALL_EXTRAS is not None +RUNTIME_EXTRAS = set(ALL_EXTRAS) - DEV_EXTRAS VERSION = metadata.version(DISTRIBUTION_NAME) diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 3c635e3dcb..75fb5fae6b 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -96,7 +96,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): ) # We assign to a method, which mypy doesn't like. - self.api.get_json = Mock(side_effect=get_json) # type: ignore[assignment] + self.api.get_json = Mock(side_effect=get_json) # type: ignore[method-assign] result = self.get_success( self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]}) @@ -168,7 +168,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): ) # We assign to a method, which mypy doesn't like. - self.api.get_json = Mock(side_effect=get_json) # type: ignore[assignment] + self.api.get_json = Mock(side_effect=get_json) # type: ignore[method-assign] result = self.get_success( self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]}) @@ -215,7 +215,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): return RESPONSE # We assign to a method, which mypy doesn't like. - self.api.post_json_get_json = Mock(side_effect=post_json_get_json) # type: ignore[assignment] + self.api.post_json_get_json = Mock(side_effect=post_json_get_json) # type: ignore[method-assign] MISSING_KEYS = [ # Known user, known device, missing algorithm. diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 5b58fb13b5..73a2766baf 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -57,7 +57,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): async def get_current_state_event_counts(room_id: str) -> int: return int(500 * 1.23) - store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment] + store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[method-assign] # Get the room complexity again -- make sure it's our artificial value channel = self.make_signed_federation_request( @@ -74,8 +74,8 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment] - handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment] + fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign] + handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] return_value=("", 1) ) @@ -105,8 +105,8 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment] - handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment] + fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign] + handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] return_value=("", 1) ) @@ -142,8 +142,8 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = AsyncMock(return_value=None) # type: ignore[assignment] - handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment] + fed_transport.client.get_json = AsyncMock(return_value=None) # type: ignore[method-assign] + handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] return_value=("", 1) ) @@ -151,7 +151,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): async def get_current_state_event_counts(room_id: str) -> int: return 600 - self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment] + self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[method-assign] d = handler._remote_join( create_requester(u1), @@ -199,8 +199,8 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment] - handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment] + fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign] + handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] return_value=("", 1) ) @@ -229,8 +229,8 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment] - handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment] + fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign] + handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] return_value=("", 1) ) diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 40318aa1b6..75ae740b43 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -50,7 +50,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # This mock is crucial for destination_rooms to be populated. # TODO: this seems to no longer be the case---tests pass with this mock # commented out. - state_storage_controller.get_current_hosts_in_room = AsyncMock( # type: ignore[assignment] + state_storage_controller.get_current_hosts_in_room = AsyncMock( # type: ignore[method-assign] return_value={"test", "host2"} ) @@ -436,7 +436,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): def wake_destination_track(destination: str) -> None: woken.add(destination) - self.federation_sender.wake_destination = wake_destination_track # type: ignore[assignment] + self.federation_sender.wake_destination = wake_destination_track # type: ignore[method-assign] # We wait quite long so that all dests can be woken up, since there is a delay # between them. diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 5ea4a75a9f..7bd3d06859 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -47,11 +47,11 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): federation_transport_client=self.federation_transport_client, ) - hs.get_storage_controllers().state.get_current_hosts_in_room = AsyncMock( # type: ignore[assignment] + hs.get_storage_controllers().state.get_current_hosts_in_room = AsyncMock( # type: ignore[method-assign] return_value={"test", "host2"} ) - hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[assignment] + hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[method-assign] hs.get_storage_controllers().state.get_current_hosts_in_room ) diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 70209ab090..3f42f79f26 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -218,7 +218,7 @@ class FederationKnockingTestCase( ) -> EventBase: return pdu - homeserver.get_federation_server()._check_sigs_and_hash = ( # type: ignore[assignment] + homeserver.get_federation_server()._check_sigs_and_hash = ( # type: ignore[method-assign] approve_all_signature_checking ) @@ -229,7 +229,7 @@ class FederationKnockingTestCase( ) -> None: pass - homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] + homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[method-assign] return super().prepare(reactor, clock, homeserver) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 4bd0facd65..46d022092e 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -400,11 +400,11 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Mock the ApplicationServiceScheduler's _TransactionController's send method so that # we can track any outgoing ephemeral events self.send_mock = AsyncMock() - hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[method-assign] # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment] + self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[method-assign] return_value=self._services ) @@ -898,11 +898,11 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase) # Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that # will be sent over the wire self.put_json = AsyncMock() - hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment] + hs.get_application_service_api().put_json = self.put_json # type: ignore[method-assign] # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment] + self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[method-assign] return_value=self._services ) @@ -1004,7 +1004,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): # Mock the ApplicationServiceScheduler's _TransactionController's send method so that # we can track what's going out self.send_mock = AsyncMock() - hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method. + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[method-assign] # We assign to a method. # Define an application service for the tests self._service_token = "VERYSECRET" diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 2cb24add20..8582b1cd1e 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -60,7 +60,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] cas_response = CasResponse("test_user", {}) request = _mock_request() @@ -88,7 +88,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # Map a user via SSO. cas_response = CasResponse("test_user", {}) @@ -128,7 +128,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] cas_response = CasResponse("föö", {}) request = _mock_request() @@ -159,7 +159,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # The response doesn't have the proper userGroup or department. cas_response = CasResponse("test_user", {}) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 7917766a08..c5556f2844 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -800,7 +800,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" - self.hs.get_federation_client().query_client_keys = mock.AsyncMock( # type: ignore[assignment] + self.hs.get_federation_client().query_client_keys = mock.AsyncMock( # type: ignore[method-assign] return_value={ "device_keys": {remote_user_id: {}}, "master_keys": { @@ -876,7 +876,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" - self.hs.get_federation_client().query_user_devices = mock.AsyncMock( # type: ignore[assignment] + self.hs.get_federation_client().query_user_devices = mock.AsyncMock( # type: ignore[method-assign] return_value={ "user_id": remote_user_id, "stream_id": 1, diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index bd743b3578..21d63ab1f2 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -371,14 +371,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): # We mock out the FederationClient.backfill method, to pretend that a remote # server has returned our fake event. federation_client_backfill_mock = AsyncMock(return_value=[event]) - self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment] + self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[method-assign] # We also mock the persist method with a side effect of itself. This allows us # to track when it has been called while preserving its function. persist_events_and_notify_mock = Mock( side_effect=self.hs.get_federation_event_handler().persist_events_and_notify ) - self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment] + self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[method-assign] persist_events_and_notify_mock ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 9b2c7812cc..e797aaae00 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -157,7 +157,7 @@ class OidcHandlerTestCase(HomeserverTestCase): sso_handler = hs.get_sso_handler() # Mock the render error method. self.render_error = Mock(return_value=None) - sso_handler.render_error = self.render_error # type: ignore[assignment] + sso_handler.render_error = self.render_error # type: ignore[method-assign] # Reduce the number of attempts when generating MXIDs. sso_handler._MAP_USERNAME_RETRIES = 3 @@ -165,7 +165,7 @@ class OidcHandlerTestCase(HomeserverTestCase): auth_handler = hs.get_auth_handler() # Mock the complete SSO login method. self.complete_sso_login = AsyncMock() - auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment] + auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[method-assign] return hs diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 4496370c3f..11ec8c7f11 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -830,7 +830,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): username: The username to use for the test. registration: Whether to test with registration URLs. """ - self.hs.get_identity_handler().send_threepid_validation = AsyncMock( # type: ignore[assignment] + self.hs.get_identity_handler().send_threepid_validation = AsyncMock( # type: ignore[method-assign] return_value=0 ) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index a04234829f..e9fbf32c7c 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -202,7 +202,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_not_blocked(self) -> None: - self.store.count_monthly_users = AsyncMock( # type: ignore[assignment] + self.store.count_monthly_users = AsyncMock( # type: ignore[method-assign] return_value=self.hs.config.server.max_mau_value - 1 ) # Ensure does not throw exception @@ -299,7 +299,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None: room_alias_str = "#room:test" - self.store.count_real_users = AsyncMock(return_value=1) # type: ignore[assignment] + self.store.count_real_users = AsyncMock(return_value=1) # type: ignore[method-assign] self.store.is_real_user = AsyncMock(return_value=True) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) @@ -314,7 +314,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user( self, ) -> None: - self.store.count_real_users = AsyncMock(return_value=2) # type: ignore[assignment] + self.store.count_real_users = AsyncMock(return_value=2) # type: ignore[method-assign] self.store.is_real_user = AsyncMock(return_value=True) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 6e666d7bed..00f4e181e8 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -133,7 +133,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # send a mocked-up SAML response to the callback saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) @@ -163,7 +163,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # Map a user via SSO. saml_response = FakeAuthnResponse( @@ -205,11 +205,11 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # mock out the error renderer too sso_handler = self.hs.get_sso_handler() - sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] + sso_handler.render_error = Mock(return_value=None) # type: ignore[method-assign] saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) request = _mock_request() @@ -226,9 +226,9 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler and error renderer auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] sso_handler = self.hs.get_sso_handler() - sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] + sso_handler.render_error = Mock(return_value=None) # type: ignore[method-assign] # register a user to occupy the first-choice MXID store = self.hs.get_datastores().main @@ -311,7 +311,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # The response doesn't have the proper userGroup or department. saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index d776526bc1..2a295da3a0 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -122,15 +122,15 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_destination_retry_timings = AsyncMock(return_value=None) - self.datastore.get_device_updates_by_remote = AsyncMock( # type: ignore[assignment] + self.datastore.get_device_updates_by_remote = AsyncMock( # type: ignore[method-assign] return_value=(0, []) ) - self.datastore.get_destination_last_successful_stream_ordering = AsyncMock( # type: ignore[assignment] + self.datastore.get_destination_last_successful_stream_ordering = AsyncMock( # type: ignore[method-assign] return_value=None ) - self.datastore.get_received_txn_response = AsyncMock( # type: ignore[assignment] + self.datastore.get_received_txn_response = AsyncMock( # type: ignore[method-assign] return_value=None ) @@ -143,25 +143,25 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): raise AuthError(401, "User is not in the room") return None - hs.get_auth().check_user_in_room = Mock( # type: ignore[assignment] + hs.get_auth().check_user_in_room = Mock( # type: ignore[method-assign] side_effect=check_user_in_room ) async def check_host_in_room(room_id: str, server_name: str) -> bool: return room_id == ROOM_ID - hs.get_event_auth_handler().is_host_in_room = Mock( # type: ignore[assignment] + hs.get_event_auth_handler().is_host_in_room = Mock( # type: ignore[method-assign] side_effect=check_host_in_room ) async def get_current_hosts_in_room(room_id: str) -> Set[str]: return {member.domain for member in self.room_members} - hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment] + hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[method-assign] side_effect=get_current_hosts_in_room ) - hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock( # type: ignore[assignment] + hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock( # type: ignore[method-assign] side_effect=get_current_hosts_in_room ) @@ -170,24 +170,24 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_users_in_room = Mock(side_effect=get_users_in_room) - self.datastore.get_user_directory_stream_pos = AsyncMock( # type: ignore[assignment] + self.datastore.get_user_directory_stream_pos = AsyncMock( # type: ignore[method-assign] # we deliberately return a non-None stream pos to avoid # doing an initial_sync return_value=1 ) - self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[assignment] + self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[method-assign] - self.datastore.get_to_device_stream_token = Mock( # type: ignore[assignment] + self.datastore.get_to_device_stream_token = Mock( # type: ignore[method-assign] return_value=0 ) - self.datastore.get_new_device_msgs_for_remote = AsyncMock( # type: ignore[assignment] + self.datastore.get_new_device_msgs_for_remote = AsyncMock( # type: ignore[method-assign] return_value=([], 0) ) - self.datastore.delete_device_msgs_for_remote = AsyncMock( # type: ignore[assignment] + self.datastore.delete_device_msgs_for_remote = AsyncMock( # type: ignore[method-assign] return_value=None ) - self.datastore.set_received_txn_response = AsyncMock( # type: ignore[assignment] + self.datastore.set_received_txn_response = AsyncMock( # type: ignore[method-assign] return_value=None ) diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index fa27f1279a..c379853e20 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -164,7 +164,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): # Call requestReceived to finish instantiating the object. request.content = BytesIO() # Partially skip some internal processing of SynapseRequest. - request._started_processing = Mock() # type: ignore[assignment] + request._started_processing = Mock() # type: ignore[method-assign] request.request_metrics = Mock(spec=["name"]) with patch.object(Request, "render"): request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1") diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 9ce9326190..172fc3a736 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -233,7 +233,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): def test_sending_events_into_room(self) -> None: """Tests that a module can send events into a room""" # Mock out create_and_send_nonmember_event to check whether events are being sent - self.event_creation_handler.create_and_send_nonmember_event = Mock( # type: ignore[assignment] + self.event_creation_handler.create_and_send_nonmember_event = Mock( # type: ignore[method-assign] spec=[], side_effect=self.event_creation_handler.create_and_send_nonmember_event, ) @@ -579,7 +579,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): # Necessary to fake a remote join. fake_stream_id = 1 mocked_remote_join = AsyncMock(return_value=("fake-event-id", fake_stream_id)) - self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment] + self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[method-assign] fake_remote_host = f"{self.module_api.server_name}-remote" # Given that the join is to be faked, we expect the relevant join event not to diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index a3880ac171..7c23b77e0a 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -190,7 +190,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): # Mock the method which calculates push rules -- we do this instead of # e.g. checking the results in the database because we want to ensure # that code isn't even running. - bulk_evaluator._action_for_event_by_user = AsyncMock() # type: ignore[assignment] + bulk_evaluator._action_for_event_by_user = AsyncMock() # type: ignore[method-assign] # Ensure no actions are generated! self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) diff --git a/tests/replication/storage/test_events.py b/tests/replication/storage/test_events.py index f7c6417a09..af25815fa5 100644 --- a/tests/replication/storage/test_events.py +++ b/tests/replication/storage/test_events.py @@ -58,7 +58,7 @@ def patch__eq__(cls: object) -> Callable[[], None]: def unpatch() -> None: if eq is not None: - cls.__eq__ = eq # type: ignore[assignment] + cls.__eq__ = eq # type: ignore[method-assign] return unpatch diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 339a41c7e1..2f6bd0d74f 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -71,8 +71,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.hs.config.registration.registration_shared_secret = "shared" - self.hs.get_media_repository = Mock() # type: ignore[assignment] - self.hs.get_deactivate_account_handler = Mock() # type: ignore[assignment] + self.hs.get_media_repository = Mock() # type: ignore[method-assign] + self.hs.get_deactivate_account_handler = Mock() # type: ignore[method-assign] return self.hs diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py index 6c04e6c56c..4c69d224b8 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py @@ -50,7 +50,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): ) handler = self.hs.get_registration_handler() - handler.check_username = check_username # type: ignore[assignment] + handler.check_username = check_username # type: ignore[method-assign] def test_username_available(self) -> None: """ diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index ac19f3c6da..e9f495e206 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -1346,7 +1346,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): return {} # Register a mock that will return the expected result depending on the remote. - self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[assignment] + self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[method-assign] # Check that we've got the correct response from the client-side endpoint. self._test_status( diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index 54df2a252c..141e0f57a3 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -45,7 +45,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) - hs.get_federation_handler = Mock() # type: ignore[assignment] + hs.get_federation_handler = Mock() # type: ignore[method-assign] return hs diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index a2d5d340be..90a8df147c 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -65,14 +65,14 @@ class FilterTestCase(unittest.HomeserverTestCase): def test_add_filter_non_local_user(self) -> None: _is_mine = self.hs.is_mine - self.hs.is_mine = lambda target_user: False # type: ignore[assignment] + self.hs.is_mine = lambda target_user: False # type: ignore[method-assign] channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, ) - self.hs.is_mine = _is_mine # type: ignore[assignment] + self.hs.is_mine = _is_mine # type: ignore[method-assign] self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 53182459e4..47c1d38ad7 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -68,7 +68,7 @@ class RoomBase(unittest.HomeserverTestCase): "red", ) - self.hs.get_federation_handler = Mock() # type: ignore[assignment] + self.hs.get_federation_handler = Mock() # type: ignore[method-assign] self.hs.get_federation_handler.return_value.maybe_backfill = AsyncMock( return_value=None ) @@ -76,7 +76,7 @@ class RoomBase(unittest.HomeserverTestCase): async def _insert_client_ip(*args: Any, **kwargs: Any) -> None: return None - self.hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment] + self.hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[method-assign] return self.hs @@ -3413,8 +3413,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0)) - self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] - self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[assignment] + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[method-assign] + self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[method-assign] return_value=None, ) @@ -3477,8 +3477,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0)) - self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] - self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[assignment] + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[method-assign] + self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[method-assign] return_value=None, ) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 8d2cdf8751..9aecf88e41 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -84,7 +84,7 @@ class RoomTestCase(_ShadowBannedBase): def test_invite_3pid(self) -> None: """Ensure that a 3PID invite does not attempt to contact the identity server.""" identity_handler = self.hs.get_identity_handler() - identity_handler.lookup_3pid = Mock( # type: ignore[assignment] + identity_handler.lookup_3pid = Mock( # type: ignore[method-assign] side_effect=AssertionError("This should not get called") ) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index da37fcb045..57eb713b15 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -117,7 +117,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): async def _check_event_auth(origin: Any, event: Any, context: Any) -> None: pass - hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] + hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[method-assign] return hs diff --git a/tests/server.py b/tests/server.py index 659ccce838..08633fe640 100644 --- a/tests/server.py +++ b/tests/server.py @@ -722,7 +722,7 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None: **kwargs, ) - pool.runWithConnection = runWithConnection # type: ignore[assignment] + pool.runWithConnection = runWithConnection # type: ignore[method-assign] pool.runInteraction = runInteraction # type: ignore[assignment] # Replace the thread pool with a threadless 'thread' pool pool.threadpool = ThreadPool(clock._reactor) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 47c53a5475..17f428bfc5 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -69,7 +69,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn = rlsn self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=1000) - self._rlsn._server_notices_manager.send_notice = AsyncMock( # type: ignore[assignment] + self._rlsn._server_notices_manager.send_notice = AsyncMock( # type: ignore[method-assign] return_value=Mock() ) self._send_notice = self._rlsn._server_notices_manager.send_notice @@ -82,8 +82,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = AsyncMock( return_value="!something:localhost" ) - self._rlsn._store.add_tag_to_room = AsyncMock(return_value=None) # type: ignore[assignment] - self._rlsn._store.get_tags_for_room = AsyncMock(return_value={}) # type: ignore[assignment] + self._rlsn._store.add_tag_to_room = AsyncMock(return_value=None) # type: ignore[method-assign] + self._rlsn._store.get_tags_for_room = AsyncMock(return_value={}) # type: ignore[method-assign] @override_config({"hs_disabled": True}) def test_maybe_send_server_notice_disabled_hs(self) -> None: @@ -100,13 +100,13 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None: """Test when user has blocked notice, but should have it removed""" - self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment] + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] return_value=None ) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) - self._rlsn._store.get_events = AsyncMock( # type: ignore[assignment] + self._rlsn._store.get_events = AsyncMock( # type: ignore[method-assign] return_value={"123": mock_event} ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user has blocked notice, but notice ought to be there (NOOP) """ - self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment] + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] return_value=None, side_effect=ResourceLimitError(403, "foo"), ) @@ -130,7 +130,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) - self._rlsn._store.get_events = AsyncMock( # type: ignore[assignment] + self._rlsn._store.get_events = AsyncMock( # type: ignore[method-assign] return_value={"123": mock_event} ) @@ -142,7 +142,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, but should have one """ - self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment] + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] return_value=None, side_effect=ResourceLimitError(403, "foo"), ) @@ -155,7 +155,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, nor should they (NOOP) """ - self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment] + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] return_value=None ) @@ -168,7 +168,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment] + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] return_value=None ) self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=None) @@ -184,7 +184,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test that when server is over MAU limit and alerting is suppressed, then an alert message is not sent into the room """ - self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment] + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] return_value=None, side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER @@ -199,7 +199,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test that when a server is disabled, that MAU limit alerting is ignored. """ - self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment] + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] return_value=None, side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED @@ -218,21 +218,21 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): When the room is already in a blocked state, test that when alerting is suppressed that the room is returned to an unblocked state. """ - self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment] + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] return_value=None, side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER ), ) - self._rlsn._is_room_currently_blocked = AsyncMock( # type: ignore[assignment] + self._rlsn._is_room_currently_blocked = AsyncMock( # type: ignore[method-assign] return_value=(True, []) ) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) - self._rlsn._store.get_events = AsyncMock( # type: ignore[assignment] + self._rlsn._store.get_events = AsyncMock( # type: ignore[method-assign] return_value={"123": mock_event} ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 48f39df9fe..cbce26a725 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -338,7 +338,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): # we aren't testing store._base stuff here, so mock this out # (ignore needed because Mypy won't allow us to assign to a method otherwise) - self.store.get_events_as_list = AsyncMock(return_value=events) # type: ignore[assignment] + self.store.get_events_as_list = AsyncMock(return_value=events) # type: ignore[method-assign] self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events)) self.get_success(self._insert_txn(service.id, 10, events)) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 0bf706ba08..49366440ce 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -252,7 +252,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ) self.get_success(d) - self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment] + self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[method-assign] d = self.store.populate_monthly_active_users(user_id) self.get_success(d) @@ -260,9 +260,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_not_called() def test_populate_monthly_users_should_update(self) -> None: - self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment] + self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[method-assign] - self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[assignment] + self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[method-assign] self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) d = self.store.populate_monthly_active_users("user_id") @@ -271,9 +271,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_called_once() def test_populate_monthly_users_should_not_update(self) -> None: - self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment] + self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[method-assign] - self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[assignment] + self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[method-assign] self.store.user_last_seen_monthly_active = AsyncMock( return_value=self.hs.get_clock().time_msec() ) @@ -356,7 +356,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self) -> None: - self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment] + self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[method-assign] self.get_success(self.store.populate_monthly_active_users("@user:sever")) diff --git a/tests/test_federation.py b/tests/test_federation.py index 779f70467b..f8ade6da38 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -80,7 +80,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) -> None: pass - federation_event_handler._check_event_auth = _check_event_auth # type: ignore[assignment] + federation_event_handler._check_event_auth = _check_event_auth # type: ignore[method-assign] self.client = self.hs.get_federation_client() async def _check_sigs_and_hash_for_pulled_events_and_fetch( @@ -190,7 +190,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register the mock on the federation client. federation_client = self.hs.get_federation_client() - federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[assignment] + federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[method-assign] # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. @@ -240,7 +240,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register mock device list retrieval on the federation client. federation_client = self.hs.get_federation_client() - federation_client.query_user_devices = AsyncMock( # type: ignore[assignment] + federation_client.query_user_devices = AsyncMock( # type: ignore[method-assign] return_value={ "user_id": remote_user_id, "stream_id": 1, diff --git a/tests/test_state.py b/tests/test_state.py index eded38c766..9c8679cc1d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -714,7 +714,7 @@ class StateTestCase(unittest.TestCase): store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.dummy_store.get_events = store.get_events # type: ignore[assignment] + self.dummy_store.get_events = store.get_events # type: ignore[method-assign] context: EventContext context = yield self._get_context( @@ -773,7 +773,7 @@ class StateTestCase(unittest.TestCase): store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.dummy_store.get_events = store.get_events # type: ignore[assignment] + self.dummy_store.get_events = store.get_events # type: ignore[method-assign] context: EventContext context = yield self._get_context( diff --git a/tests/unittest.py b/tests/unittest.py index 40672a4415..5d3640d8ac 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -395,9 +395,9 @@ class HomeserverTestCase(TestCase): ) # Type ignore: mypy doesn't like us assigning to methods. - self.hs.get_auth().get_user_by_req = get_requester # type: ignore[assignment] - self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[assignment] - self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[assignment] + self.hs.get_auth().get_user_by_req = get_requester # type: ignore[method-assign] + self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[method-assign] + self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[method-assign] if self.needs_threadpool: self.reactor.threadpool = ThreadPool() # type: ignore[assignment] -- cgit 1.5.1 From 62a1a9be52f4bc79b112f9841ddb3d03b8efccba Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 30 Aug 2023 00:39:39 +0100 Subject: Describe which rate limiter was hit in logs (#16135) --- changelog.d/16135.misc | 1 + synapse/api/errors.py | 14 ++- synapse/api/ratelimiting.py | 20 +++-- synapse/config/ratelimiting.py | 132 +++++++++++++++++++---------- synapse/handlers/auth.py | 8 +- synapse/handlers/devicemessage.py | 3 +- synapse/handlers/identity.py | 6 +- synapse/handlers/room_member.py | 21 ++--- synapse/handlers/room_summary.py | 5 +- synapse/http/server.py | 8 +- synapse/rest/client/login.py | 6 +- synapse/rest/client/login_token_request.py | 10 ++- synapse/rest/client/register.py | 3 +- synapse/server.py | 3 +- synapse/util/ratelimitutils.py | 3 +- tests/api/test_errors.py | 15 +++- tests/api/test_ratelimiting.py | 67 +++++++++------ tests/config/test_ratelimiting.py | 31 +++++++ 18 files changed, 235 insertions(+), 121 deletions(-) create mode 100644 changelog.d/16135.misc (limited to 'synapse/util') diff --git a/changelog.d/16135.misc b/changelog.d/16135.misc new file mode 100644 index 0000000000..cba8733d02 --- /dev/null +++ b/changelog.d/16135.misc @@ -0,0 +1 @@ +Describe which rate limiter was hit in logs. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 578e798773..fdb2955be8 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -211,6 +211,11 @@ class SynapseError(CodeMessageException): def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": return cs_error(self.msg, self.errcode, **self._additional_fields) + @property + def debug_context(self) -> Optional[str]: + """Override this to add debugging context that shouldn't be sent to clients.""" + return None + class InvalidAPICallError(SynapseError): """You called an existing API endpoint, but fed that endpoint @@ -508,8 +513,8 @@ class LimitExceededError(SynapseError): def __init__( self, + limiter_name: str, code: int = 429, - msg: str = "Too Many Requests", retry_after_ms: Optional[int] = None, errcode: str = Codes.LIMIT_EXCEEDED, ): @@ -518,12 +523,17 @@ class LimitExceededError(SynapseError): if self.include_retry_after_header and retry_after_ms is not None else None ) - super().__init__(code, msg, errcode, headers=headers) + super().__init__(code, "Too Many Requests", errcode, headers=headers) self.retry_after_ms = retry_after_ms + self.limiter_name = limiter_name def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) + @property + def debug_context(self) -> Optional[str]: + return self.limiter_name + class RoomKeysVersionError(SynapseError): """A client has tried to upload to a non-current version of the room_keys store""" diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 511790c7c5..887b214d64 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -61,12 +61,16 @@ class Ratelimiter: """ def __init__( - self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int + self, + store: DataStore, + clock: Clock, + cfg: RatelimitSettings, ): self.clock = clock - self.rate_hz = rate_hz - self.burst_count = burst_count + self.rate_hz = cfg.per_second + self.burst_count = cfg.burst_count self.store = store + self._limiter_name = cfg.key # An ordered dictionary representing the token buckets tracked by this rate # limiter. Each entry maps a key of arbitrary type to a tuple representing: @@ -305,7 +309,8 @@ class Ratelimiter: if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now_s)) + limiter_name=self._limiter_name, + retry_after_ms=int(1000 * (time_allowed - time_now_s)), ) @@ -322,7 +327,9 @@ class RequestRatelimiter: # The rate_hz and burst_count are overridden on a per-user basis self.request_ratelimiter = Ratelimiter( - store=self.store, clock=self.clock, rate_hz=0, burst_count=0 + store=self.store, + clock=self.clock, + cfg=RatelimitSettings(key=rc_message.key, per_second=0, burst_count=0), ) self._rc_message = rc_message @@ -332,8 +339,7 @@ class RequestRatelimiter: self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=rc_admin_redaction.per_second, - burst_count=rc_admin_redaction.burst_count, + cfg=rc_admin_redaction, ) else: self.admin_redaction_ratelimiter = None diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index a5514e70a2..4efbaeac0d 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, cast import attr @@ -21,16 +21,47 @@ from synapse.types import JsonDict from ._base import Config +@attr.s(slots=True, frozen=True, auto_attribs=True) class RatelimitSettings: - def __init__( - self, - config: Dict[str, float], + key: str + per_second: float + burst_count: int + + @classmethod + def parse( + cls, + config: Dict[str, Any], + key: str, defaults: Optional[Dict[str, float]] = None, - ): + ) -> "RatelimitSettings": + """Parse config[key] as a new-style rate limiter config. + + The key may refer to a nested dictionary using a full stop (.) to separate + each nested key. For example, use the key "a.b.c" to parse the following: + + a: + b: + c: + per_second: 10 + burst_count: 200 + + If this lookup fails, we'll fallback to the defaults. + """ defaults = defaults or {"per_second": 0.17, "burst_count": 3.0} - self.per_second = config.get("per_second", defaults["per_second"]) - self.burst_count = int(config.get("burst_count", defaults["burst_count"])) + rl_config = config + for part in key.split("."): + rl_config = rl_config.get(part, {}) + + # By this point we should have hit the rate limiter parameters. + # We don't actually check this though! + rl_config = cast(Dict[str, float], rl_config) + + return cls( + key=key, + per_second=rl_config.get("per_second", defaults["per_second"]), + burst_count=int(rl_config.get("burst_count", defaults["burst_count"])), + ) @attr.s(auto_attribs=True) @@ -49,15 +80,14 @@ class RatelimitConfig(Config): # Load the new-style messages config if it exists. Otherwise fall back # to the old method. if "rc_message" in config: - self.rc_message = RatelimitSettings( - config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0} + self.rc_message = RatelimitSettings.parse( + config, "rc_message", defaults={"per_second": 0.2, "burst_count": 10.0} ) else: self.rc_message = RatelimitSettings( - { - "per_second": config.get("rc_messages_per_second", 0.2), - "burst_count": config.get("rc_message_burst_count", 10.0), - } + key="rc_messages", + per_second=config.get("rc_messages_per_second", 0.2), + burst_count=config.get("rc_message_burst_count", 10.0), ) # Load the new-style federation config, if it exists. Otherwise, fall @@ -79,51 +109,59 @@ class RatelimitConfig(Config): } ) - self.rc_registration = RatelimitSettings(config.get("rc_registration", {})) + self.rc_registration = RatelimitSettings.parse(config, "rc_registration", {}) - self.rc_registration_token_validity = RatelimitSettings( - config.get("rc_registration_token_validity", {}), + self.rc_registration_token_validity = RatelimitSettings.parse( + config, + "rc_registration_token_validity", defaults={"per_second": 0.1, "burst_count": 5}, ) # It is reasonable to login with a bunch of devices at once (i.e. when # setting up an account), but it is *not* valid to continually be # logging into new devices. - rc_login_config = config.get("rc_login", {}) - self.rc_login_address = RatelimitSettings( - rc_login_config.get("address", {}), + self.rc_login_address = RatelimitSettings.parse( + config, + "rc_login.address", defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_login_account = RatelimitSettings( - rc_login_config.get("account", {}), + self.rc_login_account = RatelimitSettings.parse( + config, + "rc_login.account", defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_login_failed_attempts = RatelimitSettings( - rc_login_config.get("failed_attempts", {}) + self.rc_login_failed_attempts = RatelimitSettings.parse( + config, + "rc_login.failed_attempts", + {}, ) self.federation_rr_transactions_per_room_per_second = config.get( "federation_rr_transactions_per_room_per_second", 50 ) - rc_admin_redaction = config.get("rc_admin_redaction") self.rc_admin_redaction = None - if rc_admin_redaction: - self.rc_admin_redaction = RatelimitSettings(rc_admin_redaction) + if "rc_admin_redaction" in config: + self.rc_admin_redaction = RatelimitSettings.parse( + config, "rc_admin_redaction", {} + ) - self.rc_joins_local = RatelimitSettings( - config.get("rc_joins", {}).get("local", {}), + self.rc_joins_local = RatelimitSettings.parse( + config, + "rc_joins.local", defaults={"per_second": 0.1, "burst_count": 10}, ) - self.rc_joins_remote = RatelimitSettings( - config.get("rc_joins", {}).get("remote", {}), + self.rc_joins_remote = RatelimitSettings.parse( + config, + "rc_joins.remote", defaults={"per_second": 0.01, "burst_count": 10}, ) # Track the rate of joins to a given room. If there are too many, temporarily # prevent local joins and remote joins via this server. - self.rc_joins_per_room = RatelimitSettings( - config.get("rc_joins_per_room", {}), + self.rc_joins_per_room = RatelimitSettings.parse( + config, + "rc_joins_per_room", defaults={"per_second": 1, "burst_count": 10}, ) @@ -132,31 +170,37 @@ class RatelimitConfig(Config): # * For requests received over federation this is keyed by the origin. # # Note that this isn't exposed in the configuration as it is obscure. - self.rc_key_requests = RatelimitSettings( - config.get("rc_key_requests", {}), + self.rc_key_requests = RatelimitSettings.parse( + config, + "rc_key_requests", defaults={"per_second": 20, "burst_count": 100}, ) - self.rc_3pid_validation = RatelimitSettings( - config.get("rc_3pid_validation") or {}, + self.rc_3pid_validation = RatelimitSettings.parse( + config, + "rc_3pid_validation", defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_invites_per_room = RatelimitSettings( - config.get("rc_invites", {}).get("per_room", {}), + self.rc_invites_per_room = RatelimitSettings.parse( + config, + "rc_invites.per_room", defaults={"per_second": 0.3, "burst_count": 10}, ) - self.rc_invites_per_user = RatelimitSettings( - config.get("rc_invites", {}).get("per_user", {}), + self.rc_invites_per_user = RatelimitSettings.parse( + config, + "rc_invites.per_user", defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_invites_per_issuer = RatelimitSettings( - config.get("rc_invites", {}).get("per_issuer", {}), + self.rc_invites_per_issuer = RatelimitSettings.parse( + config, + "rc_invites.per_issuer", defaults={"per_second": 0.3, "burst_count": 10}, ) - self.rc_third_party_invite = RatelimitSettings( - config.get("rc_third_party_invite", {}), + self.rc_third_party_invite = RatelimitSettings.parse( + config, + "rc_third_party_invite", defaults={"per_second": 0.0025, "burst_count": 5}, ) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 59ecafa6a0..2b0c505130 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -218,19 +218,17 @@ class AuthHandler: self._failed_uia_attempts_ratelimiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count, + cfg=self.hs.config.ratelimiting.rc_login_failed_attempts, ) # The number of seconds to keep a UI auth session active. self._ui_auth_session_timeout = hs.config.auth.ui_auth_session_timeout - # Ratelimitier for failed /login attempts + # Ratelimiter for failed /login attempts self._failed_login_attempts_ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count, + cfg=self.hs.config.ratelimiting.rc_login_failed_attempts, ) self._clock = self.hs.get_clock() diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 17ff8821d9..798c7039f9 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -90,8 +90,7 @@ class DeviceMessageHandler: self._ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.ratelimiting.rc_key_requests.per_second, - burst_count=hs.config.ratelimiting.rc_key_requests.burst_count, + cfg=hs.config.ratelimiting.rc_key_requests, ) async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 3031384d25..472879c964 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -66,14 +66,12 @@ class IdentityHandler: self._3pid_validation_ratelimiter_ip = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, - burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + cfg=hs.config.ratelimiting.rc_3pid_validation, ) self._3pid_validation_ratelimiter_address = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, - burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + cfg=hs.config.ratelimiting.rc_3pid_validation, ) async def ratelimit_request_token_requests( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 1d8d4a72e7..de0f04e3fe 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -112,8 +112,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._join_rate_limiter_local = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, - burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, + cfg=hs.config.ratelimiting.rc_joins_local, ) # Tracks joins from local users to rooms this server isn't a member of. # I.e. joins this server makes by requesting /make_join /send_join from @@ -121,8 +120,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._join_rate_limiter_remote = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, - burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, + cfg=hs.config.ratelimiting.rc_joins_remote, ) # TODO: find a better place to keep this Ratelimiter. # It needs to be @@ -135,8 +133,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._join_rate_per_room_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second, - burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count, + cfg=hs.config.ratelimiting.rc_joins_per_room, ) # Ratelimiter for invites, keyed by room (across all issuers, all @@ -144,8 +141,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._invites_per_room_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, - burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, + cfg=hs.config.ratelimiting.rc_invites_per_room, ) # Ratelimiter for invites, keyed by recipient (across all rooms, all @@ -153,8 +149,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._invites_per_recipient_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, - burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, + cfg=hs.config.ratelimiting.rc_invites_per_user, ) # Ratelimiter for invites, keyed by issuer (across all rooms, all @@ -162,15 +157,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._invites_per_issuer_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_invites_per_issuer.per_second, - burst_count=hs.config.ratelimiting.rc_invites_per_issuer.burst_count, + cfg=hs.config.ratelimiting.rc_invites_per_issuer, ) self._third_party_invite_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_third_party_invite.per_second, - burst_count=hs.config.ratelimiting.rc_third_party_invite.burst_count, + cfg=hs.config.ratelimiting.rc_third_party_invite, ) self.request_ratelimiter = hs.get_request_ratelimiter() diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index dad3e23470..dd559b4c45 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -35,6 +35,7 @@ from synapse.api.errors import ( UnsupportedRoomVersionError, ) from synapse.api.ratelimiting import Ratelimiter +from synapse.config.ratelimiting import RatelimitSettings from synapse.events import EventBase from synapse.types import JsonDict, Requester, StrCollection from synapse.util.caches.response_cache import ResponseCache @@ -94,7 +95,9 @@ class RoomSummaryHandler: self._server_name = hs.hostname self._federation_client = hs.get_federation_client() self._ratelimiter = Ratelimiter( - store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10 + store=self._store, + clock=hs.get_clock(), + cfg=RatelimitSettings("", per_second=5, burst_count=10), ) # If a user tries to fetch the same page multiple times in quick succession, diff --git a/synapse/http/server.py b/synapse/http/server.py index 5109cec983..3bbf91298e 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -115,7 +115,13 @@ def return_json_error( if exc.headers is not None: for header, value in exc.headers.items(): request.setHeader(header, value) - logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) + error_ctx = exc.debug_context + if error_ctx: + logger.info( + "%s SynapseError: %s - %s (%s)", request, error_code, exc.msg, error_ctx + ) + else: + logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) elif f.check(CancelledError): error_code = HTTP_STATUS_REQUEST_CANCELLED error_dict = {"error": "Request cancelled", "errcode": Codes.UNKNOWN} diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index d724c68920..7be327e26f 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -120,14 +120,12 @@ class LoginRestServlet(RestServlet): self._address_ratelimiter = Ratelimiter( store=self._main_store, clock=hs.get_clock(), - rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second, - burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count, + cfg=self.hs.config.ratelimiting.rc_login_address, ) self._account_ratelimiter = Ratelimiter( store=self._main_store, clock=hs.get_clock(), - rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second, - burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count, + cfg=self.hs.config.ratelimiting.rc_login_account, ) # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance. diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py index b1629f94a5..d189a923b5 100644 --- a/synapse/rest/client/login_token_request.py +++ b/synapse/rest/client/login_token_request.py @@ -16,6 +16,7 @@ import logging from typing import TYPE_CHECKING, Tuple from synapse.api.ratelimiting import Ratelimiter +from synapse.config.ratelimiting import RatelimitSettings from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest @@ -66,15 +67,18 @@ class LoginTokenRequestServlet(RestServlet): self.token_timeout = hs.config.auth.login_via_existing_token_timeout self._require_ui_auth = hs.config.auth.login_via_existing_require_ui_auth - # Ratelimit aggressively to a maxmimum of 1 request per minute. + # Ratelimit aggressively to a maximum of 1 request per minute. # # This endpoint can be used to spawn additional sessions and could be # abused by a malicious client to create many sessions. self._ratelimiter = Ratelimiter( store=self._main_store, clock=hs.get_clock(), - rate_hz=1 / 60, - burst_count=1, + cfg=RatelimitSettings( + key="", + per_second=1 / 60, + burst_count=1, + ), ) @interactive_auth_handler diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 77e3b91b79..132623462a 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -376,8 +376,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): self.ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second, - burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count, + cfg=hs.config.ratelimiting.rc_registration_token_validity, ) async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: diff --git a/synapse/server.py b/synapse/server.py index 7cdd3ea3c2..fd16dacd0d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -408,8 +408,7 @@ class HomeServer(metaclass=abc.ABCMeta): return Ratelimiter( store=self.get_datastores().main, clock=self.get_clock(), - rate_hz=self.config.ratelimiting.rc_registration.per_second, - burst_count=self.config.ratelimiting.rc_registration.burst_count, + cfg=self.config.ratelimiting.rc_registration, ) @cache_in_self diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index cde4a0780f..f693ba2a8c 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -291,7 +291,8 @@ class _PerHostRatelimiter: if self.metrics_name: rate_limit_reject_counter.labels(self.metrics_name).inc() raise LimitExceededError( - retry_after_ms=int(self.window_size / self.sleep_limit) + limiter_name="rc_federation", + retry_after_ms=int(self.window_size / self.sleep_limit), ) self.request_times.append(time_now) diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py index 319abfe63d..8e159029d9 100644 --- a/tests/api/test_errors.py +++ b/tests/api/test_errors.py @@ -1,6 +1,5 @@ # Copyright 2023 The Matrix.org Foundation C.I.C. # -# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,24 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json + from synapse.api.errors import LimitExceededError from tests import unittest -class ErrorsTestCase(unittest.TestCase): +class LimitExceededErrorTestCase(unittest.TestCase): + def test_key_appears_in_context_but_not_error_dict(self) -> None: + err = LimitExceededError("needle") + serialised = json.dumps(err.error_dict(None)) + self.assertIn("needle", err.debug_context) + self.assertNotIn("needle", serialised) + # Create a sub-class to avoid mutating the class-level property. class LimitExceededErrorHeaders(LimitExceededError): include_retry_after_header = True def test_limit_exceeded_header(self) -> None: - err = ErrorsTestCase.LimitExceededErrorHeaders(retry_after_ms=100) + err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=100) self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100) assert err.headers is not None self.assertEqual(err.headers.get("Retry-After"), "1") def test_limit_exceeded_rounding(self) -> None: - err = ErrorsTestCase.LimitExceededErrorHeaders(retry_after_ms=3001) + err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=3001) self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001) assert err.headers is not None self.assertEqual(err.headers.get("Retry-After"), "4") diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index fa6c1c02ce..a24638c9ef 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,5 +1,6 @@ from synapse.api.ratelimiting import LimitExceededError, Ratelimiter from synapse.appservice import ApplicationService +from synapse.config.ratelimiting import RatelimitSettings from synapse.types import create_requester from tests import unittest @@ -10,8 +11,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=0) @@ -43,8 +43,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings( + key="", + per_second=0.1, + burst_count=1, + ), ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) @@ -76,8 +79,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings( + key="", + per_second=0.1, + burst_count=1, + ), ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) @@ -101,8 +107,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), ) # Shouldn't raise @@ -128,8 +133,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), ) # First attempt should be allowed @@ -177,8 +181,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), ) # First attempt should be allowed @@ -208,8 +211,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), ) self.get_success_or_raise( limiter.can_do_action(None, key="test_id_1", _time_now_s=0) @@ -244,7 +246,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): ) ) - limiter = Ratelimiter(store=store, clock=self.clock, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=store, + clock=self.clock, + cfg=RatelimitSettings("", per_second=0.1, burst_count=1), + ) # Shouldn't raise for _ in range(20): @@ -254,8 +260,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=3, + cfg=RatelimitSettings( + key="", + per_second=0.1, + burst_count=3, + ), ) # Test that 4 actions aren't allowed with a maximum burst of 3. allowed, time_allowed = self.get_success_or_raise( @@ -321,8 +330,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=3, + cfg=RatelimitSettings("", per_second=0.1, burst_count=3), ) def consume_at(time: float) -> bool: @@ -346,8 +354,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=3, + cfg=RatelimitSettings( + "", + per_second=0.1, + burst_count=3, + ), ) # Observe two actions, leaving room in the bucket for one more. @@ -369,8 +380,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=3, + cfg=RatelimitSettings( + "", + per_second=0.1, + burst_count=3, + ), ) # Observe three actions, filling up the bucket. @@ -398,8 +412,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=3, + cfg=RatelimitSettings( + "", + per_second=0.1, + burst_count=3, + ), ) # Observe four actions, exceeding the bucket. diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py index f12147eaa0..0c27dd21e2 100644 --- a/tests/config/test_ratelimiting.py +++ b/tests/config/test_ratelimiting.py @@ -12,11 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. from synapse.config.homeserver import HomeServerConfig +from synapse.config.ratelimiting import RatelimitSettings from tests.unittest import TestCase from tests.utils import default_config +class ParseRatelimitSettingsTestcase(TestCase): + def test_depth_1(self) -> None: + cfg = { + "a": { + "per_second": 5, + "burst_count": 10, + } + } + parsed = RatelimitSettings.parse(cfg, "a") + self.assertEqual(parsed, RatelimitSettings("a", 5, 10)) + + def test_depth_2(self) -> None: + cfg = { + "a": { + "b": { + "per_second": 5, + "burst_count": 10, + }, + } + } + parsed = RatelimitSettings.parse(cfg, "a.b") + self.assertEqual(parsed, RatelimitSettings("a.b", 5, 10)) + + def test_missing(self) -> None: + parsed = RatelimitSettings.parse( + {}, "a", defaults={"per_second": 5, "burst_count": 10} + ) + self.assertEqual(parsed, RatelimitSettings("a", 5, 10)) + + class RatelimitConfigTestCase(TestCase): def test_parse_rc_federation(self) -> None: config_dict = default_config("test") -- cgit 1.5.1 From f84baecb6f57b5ddb570c43574f774fae5e8afed Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 4 Sep 2023 14:04:43 +0100 Subject: Don't reset retry timers on "valid" error codes (#16221) --- changelog.d/16221.bugfix | 1 + synapse/federation/transport/client.py | 4 +++- synapse/http/matrixfederationclient.py | 8 ++++++++ synapse/util/retryutils.py | 18 ++++++++++++++++-- tests/handlers/test_typing.py | 4 ++-- 5 files changed, 30 insertions(+), 5 deletions(-) create mode 100644 changelog.d/16221.bugfix (limited to 'synapse/util') diff --git a/changelog.d/16221.bugfix b/changelog.d/16221.bugfix new file mode 100644 index 0000000000..22678256e4 --- /dev/null +++ b/changelog.d/16221.bugfix @@ -0,0 +1 @@ +Fix long-standing bug where we did not correctly back off from servers that had "gone" if they returned 4xx series error codes. diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 5ce3f345cb..b5e4b2680e 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -249,8 +249,10 @@ class TransportLayerClient: data=json_data, json_data_callback=json_data_callback, long_retries=True, - backoff_on_404=True, # If we get a 404 the other side has gone try_trailing_slash_on_400=True, + # Sending a transaction should always succeed, if it doesn't + # then something is wrong and we should backoff. + backoff_on_all_error_codes=True, ) async def make_query( diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 11342ccac8..08c7fc1631 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -512,6 +512,7 @@ class MatrixFederationHttpClient: long_retries: bool = False, ignore_backoff: bool = False, backoff_on_404: bool = False, + backoff_on_all_error_codes: bool = False, ) -> IResponse: """ Sends a request to the given server. @@ -552,6 +553,7 @@ class MatrixFederationHttpClient: and try the request anyway. backoff_on_404: Back off if we get a 404 + backoff_on_all_error_codes: Back off if we get any error response Returns: Resolves with the HTTP response object on success. @@ -594,6 +596,7 @@ class MatrixFederationHttpClient: ignore_backoff=ignore_backoff, notifier=self.hs.get_notifier(), replication_client=self.hs.get_replication_command_handler(), + backoff_on_all_error_codes=backoff_on_all_error_codes, ) method_bytes = request.method.encode("ascii") @@ -889,6 +892,7 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, + backoff_on_all_error_codes: bool = False, ) -> JsonDict: ... @@ -906,6 +910,7 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Optional[ByteParser[T]] = None, + backoff_on_all_error_codes: bool = False, ) -> T: ... @@ -922,6 +927,7 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Optional[ByteParser[T]] = None, + backoff_on_all_error_codes: bool = False, ) -> Union[JsonDict, T]: """Sends the specified json data using PUT @@ -957,6 +963,7 @@ class MatrixFederationHttpClient: enabled. parser: The parser to use to decode the response. Defaults to parsing as JSON. + backoff_on_all_error_codes: Back off if we get any error response Returns: Succeeds when we get a 2xx HTTP response. The @@ -990,6 +997,7 @@ class MatrixFederationHttpClient: ignore_backoff=ignore_backoff, long_retries=long_retries, timeout=timeout, + backoff_on_all_error_codes=backoff_on_all_error_codes, ) if timeout is not None: diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 27e9fc976c..9d2065372c 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -128,6 +128,7 @@ class RetryDestinationLimiter: backoff_on_failure: bool = True, notifier: Optional["Notifier"] = None, replication_client: Optional["ReplicationCommandHandler"] = None, + backoff_on_all_error_codes: bool = False, ): """Marks the destination as "down" if an exception is thrown in the context, except for CodeMessageException with code < 500. @@ -147,6 +148,9 @@ class RetryDestinationLimiter: backoff_on_failure: set to False if we should not increase the retry interval on a failure. + + backoff_on_all_error_codes: Whether we should back off on any + error code. """ self.clock = clock self.store = store @@ -156,6 +160,7 @@ class RetryDestinationLimiter: self.retry_interval = retry_interval self.backoff_on_404 = backoff_on_404 self.backoff_on_failure = backoff_on_failure + self.backoff_on_all_error_codes = backoff_on_all_error_codes self.notifier = notifier self.replication_client = replication_client @@ -179,6 +184,7 @@ class RetryDestinationLimiter: exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: + success = exc_type is None valid_err_code = False if exc_type is None: valid_err_code = True @@ -195,7 +201,9 @@ class RetryDestinationLimiter: # won't accept our requests for at least a while. # 429 is us being aggressively rate limited, so lets rate limit # ourselves. - if exc_val.code == 404 and self.backoff_on_404: + if self.backoff_on_all_error_codes: + valid_err_code = False + elif exc_val.code == 404 and self.backoff_on_404: valid_err_code = False elif exc_val.code in (401, 429): valid_err_code = False @@ -204,7 +212,7 @@ class RetryDestinationLimiter: else: valid_err_code = False - if valid_err_code: + if success: # We connected successfully. if not self.retry_interval: return @@ -215,6 +223,12 @@ class RetryDestinationLimiter: self.failure_ts = None retry_last_ts = 0 self.retry_interval = 0 + elif valid_err_code: + # We got a potentially valid error code back. We don't reset the + # timers though, as the other side might actually be down anyway + # (e.g. some deprovisioned servers will always return a 404 or 403, + # and we don't want to keep resetting the retry timers for them). + return elif not self.backoff_on_failure: return else: diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 2a295da3a0..43c513b157 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -251,8 +251,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ), json_data_callback=ANY, long_retries=True, - backoff_on_404=True, try_trailing_slash_on_400=True, + backoff_on_all_error_codes=True, ) def test_started_typing_remote_recv(self) -> None: @@ -366,7 +366,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ), json_data_callback=ANY, long_retries=True, - backoff_on_404=True, + backoff_on_all_error_codes=True, try_trailing_slash_on_400=True, ) -- cgit 1.5.1 From d35bed8369514fe727b4fe1afb68f48cc8b2655a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 4 Sep 2023 17:14:09 +0100 Subject: Don't wake up destination transaction queue if they're not due for retry. (#16223) --- changelog.d/16223.feature | 1 + synapse/federation/send_queue.py | 12 +-- synapse/federation/sender/__init__.py | 86 +++++++++++++++------- synapse/federation/sender/per_destination_queue.py | 6 +- synapse/handlers/device.py | 26 +++---- synapse/handlers/devicemessage.py | 7 +- synapse/handlers/presence.py | 16 ++-- synapse/handlers/typing.py | 14 +++- synapse/module_api/__init__.py | 2 +- synapse/replication/tcp/client.py | 8 +- synapse/storage/databases/main/transactions.py | 26 ++++++- synapse/util/retryutils.py | 25 +++++++ tests/federation/test_federation_sender.py | 27 ++++--- tests/handlers/test_presence.py | 60 ++++++++++++--- tests/handlers/test_typing.py | 2 - 15 files changed, 228 insertions(+), 90 deletions(-) create mode 100644 changelog.d/16223.feature (limited to 'synapse/util') diff --git a/changelog.d/16223.feature b/changelog.d/16223.feature new file mode 100644 index 0000000000..a52d66658b --- /dev/null +++ b/changelog.d/16223.feature @@ -0,0 +1 @@ +Improve resource usage when sending data to a large number of remote hosts that are marked as "down". diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index fb448f2155..6520795635 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -49,7 +49,7 @@ from synapse.api.presence import UserPresenceState from synapse.federation.sender import AbstractFederationSender, FederationSender from synapse.metrics import LaterGauge from synapse.replication.tcp.streams.federation import FederationStream -from synapse.types import JsonDict, ReadReceipt, RoomStreamToken +from synapse.types import JsonDict, ReadReceipt, RoomStreamToken, StrCollection from synapse.util.metrics import Measure from .units import Edu @@ -229,7 +229,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): """ # nothing to do here: the replication listener will handle it. - def send_presence_to_destinations( + async def send_presence_to_destinations( self, states: Iterable[UserPresenceState], destinations: Iterable[str] ) -> None: """As per FederationSender @@ -245,7 +245,9 @@ class FederationRemoteSendQueue(AbstractFederationSender): self.notifier.on_new_replication_data() - def send_device_messages(self, destination: str, immediate: bool = True) -> None: + async def send_device_messages( + self, destinations: StrCollection, immediate: bool = True + ) -> None: """As per FederationSender""" # We don't need to replicate this as it gets sent down a different # stream. @@ -463,7 +465,7 @@ class ParsedFederationStreamData: edus: Dict[str, List[Edu]] -def process_rows_for_federation( +async def process_rows_for_federation( transaction_queue: FederationSender, rows: List[FederationStream.FederationStreamRow], ) -> None: @@ -496,7 +498,7 @@ def process_rows_for_federation( parsed_row.add_to_buffer(buff) for state, destinations in buff.presence_destinations: - transaction_queue.send_presence_to_destinations( + await transaction_queue.send_presence_to_destinations( states=[state], destinations=destinations ) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 97abbdee18..fb20fd8a10 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -147,7 +147,10 @@ from twisted.internet import defer import synapse.metrics from synapse.api.presence import UserPresenceState from synapse.events import EventBase -from synapse.federation.sender.per_destination_queue import PerDestinationQueue +from synapse.federation.sender.per_destination_queue import ( + CATCHUP_RETRY_INTERVAL, + PerDestinationQueue, +) from synapse.federation.sender.transaction_manager import TransactionManager from synapse.federation.units import Edu from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -161,9 +164,10 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.types import JsonDict, ReadReceipt, RoomStreamToken +from synapse.types import JsonDict, ReadReceipt, RoomStreamToken, StrCollection from synapse.util import Clock from synapse.util.metrics import Measure +from synapse.util.retryutils import filter_destinations_by_retry_limiter if TYPE_CHECKING: from synapse.events.presence_router import PresenceRouter @@ -213,7 +217,7 @@ class AbstractFederationSender(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod - def send_presence_to_destinations( + async def send_presence_to_destinations( self, states: Iterable[UserPresenceState], destinations: Iterable[str] ) -> None: """Send the given presence states to the given destinations. @@ -242,9 +246,11 @@ class AbstractFederationSender(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod - def send_device_messages(self, destination: str, immediate: bool = True) -> None: + async def send_device_messages( + self, destinations: StrCollection, immediate: bool = True + ) -> None: """Tells the sender that a new device message is ready to be sent to the - destination. The `immediate` flag specifies whether the messages should + destinations. The `immediate` flag specifies whether the messages should be tried to be sent immediately, or whether it can be delayed for a short while (to aid performance). """ @@ -716,6 +722,13 @@ class FederationSender(AbstractFederationSender): pdu.internal_metadata.stream_ordering, ) + destinations = await filter_destinations_by_retry_limiter( + destinations, + clock=self.clock, + store=self.store, + retry_due_within_ms=CATCHUP_RETRY_INTERVAL, + ) + for destination in destinations: self._get_per_destination_queue(destination).send_pdu(pdu) @@ -763,12 +776,20 @@ class FederationSender(AbstractFederationSender): domains_set = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation( room_id ) - domains = [ + domains: StrCollection = [ d for d in domains_set if not self.is_mine_server_name(d) and self._federation_shard_config.should_handle(self._instance_name, d) ] + + domains = await filter_destinations_by_retry_limiter( + domains, + clock=self.clock, + store=self.store, + retry_due_within_ms=CATCHUP_RETRY_INTERVAL, + ) + if not domains: return @@ -816,7 +837,7 @@ class FederationSender(AbstractFederationSender): for queue in queues: queue.flush_read_receipts_for_room(room_id) - def send_presence_to_destinations( + async def send_presence_to_destinations( self, states: Iterable[UserPresenceState], destinations: Iterable[str] ) -> None: """Send the given presence states to the given destinations. @@ -831,13 +852,20 @@ class FederationSender(AbstractFederationSender): for state in states: assert self.is_mine_id(state.user_id) + destinations = await filter_destinations_by_retry_limiter( + [ + d + for d in destinations + if self._federation_shard_config.should_handle(self._instance_name, d) + ], + clock=self.clock, + store=self.store, + retry_due_within_ms=CATCHUP_RETRY_INTERVAL, + ) + for destination in destinations: if self.is_mine_server_name(destination): continue - if not self._federation_shard_config.should_handle( - self._instance_name, destination - ): - continue self._get_per_destination_queue(destination).send_presence( states, start_loop=False @@ -896,21 +924,29 @@ class FederationSender(AbstractFederationSender): else: queue.send_edu(edu) - def send_device_messages(self, destination: str, immediate: bool = True) -> None: - if self.is_mine_server_name(destination): - logger.warning("Not sending device update to ourselves") - return - - if not self._federation_shard_config.should_handle( - self._instance_name, destination - ): - return + async def send_device_messages( + self, destinations: StrCollection, immediate: bool = True + ) -> None: + destinations = await filter_destinations_by_retry_limiter( + [ + destination + for destination in destinations + if self._federation_shard_config.should_handle( + self._instance_name, destination + ) + and not self.is_mine_server_name(destination) + ], + clock=self.clock, + store=self.store, + retry_due_within_ms=CATCHUP_RETRY_INTERVAL, + ) - if immediate: - self._get_per_destination_queue(destination).attempt_new_transaction() - else: - self._get_per_destination_queue(destination).mark_new_data() - self._destination_wakeup_queue.add_to_queue(destination) + for destination in destinations: + if immediate: + self._get_per_destination_queue(destination).attempt_new_transaction() + else: + self._get_per_destination_queue(destination).mark_new_data() + self._destination_wakeup_queue.add_to_queue(destination) def wake_destination(self, destination: str) -> None: """Called when we want to retry sending transactions to a remote. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 31c5c2b7de..9105ba664c 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -59,6 +59,10 @@ sent_edus_by_type = Counter( ) +# If the retry interval is larger than this then we enter "catchup" mode +CATCHUP_RETRY_INTERVAL = 60 * 60 * 1000 + + class PerDestinationQueue: """ Manages the per-destination transmission queues. @@ -370,7 +374,7 @@ class PerDestinationQueue: ), ) - if e.retry_interval > 60 * 60 * 1000: + if e.retry_interval > CATCHUP_RETRY_INTERVAL: # we won't retry for another hour! # (this suggests a significant outage) # We drop pending EDUs because otherwise they will diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 5ae427d52c..763f56dfc1 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -836,17 +836,16 @@ class DeviceHandler(DeviceWorkerHandler): user_id, hosts, ) - for host in hosts: - self.federation_sender.send_device_messages( - host, immediate=False - ) - # TODO: when called, this isn't in a logging context. - # This leads to log spam, sentry event spam, and massive - # memory usage. - # See https://github.com/matrix-org/synapse/issues/12552. - # log_kv( - # {"message": "sent device update to host", "host": host} - # ) + await self.federation_sender.send_device_messages( + hosts, immediate=False + ) + # TODO: when called, this isn't in a logging context. + # This leads to log spam, sentry event spam, and massive + # memory usage. + # See https://github.com/matrix-org/synapse/issues/12552. + # log_kv( + # {"message": "sent device update to host", "host": host} + # ) if current_stream_id != stream_id: # Clear the set of hosts we've already sent to as we're @@ -951,8 +950,9 @@ class DeviceHandler(DeviceWorkerHandler): # Notify things that device lists need to be sent out. self.notifier.notify_replication() - for host in potentially_changed_hosts: - self.federation_sender.send_device_messages(host, immediate=False) + await self.federation_sender.send_device_messages( + potentially_changed_hosts, immediate=False + ) def _update_device_from_client_ips( diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 798c7039f9..1c79f7a61e 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -302,10 +302,9 @@ class DeviceMessageHandler: ) if self.federation_sender: - for destination in remote_messages.keys(): - # Enqueue a new federation transaction to send the new - # device messages to each remote destination. - self.federation_sender.send_device_messages(destination) + # Enqueue a new federation transaction to send the new + # device messages to each remote destination. + await self.federation_sender.send_device_messages(remote_messages.keys()) async def get_events_for_dehydrated_device( self, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 2f841863ae..f31e18328b 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -354,7 +354,9 @@ class BasePresenceHandler(abc.ABC): ) for destination, host_states in hosts_to_states.items(): - self._federation.send_presence_to_destinations(host_states, [destination]) + await self._federation.send_presence_to_destinations( + host_states, [destination] + ) async def send_full_presence_to_users(self, user_ids: StrCollection) -> None: """ @@ -936,7 +938,7 @@ class PresenceHandler(BasePresenceHandler): ) for destination, states in hosts_to_states.items(): - self._federation_queue.send_presence_to_destinations( + await self._federation_queue.send_presence_to_destinations( states, [destination] ) @@ -1508,7 +1510,7 @@ class PresenceHandler(BasePresenceHandler): or state.status_msg is not None ] - self._federation_queue.send_presence_to_destinations( + await self._federation_queue.send_presence_to_destinations( destinations=newly_joined_remote_hosts, states=states, ) @@ -1519,7 +1521,7 @@ class PresenceHandler(BasePresenceHandler): prev_remote_hosts or newly_joined_remote_hosts ): local_states = await self.current_state_for_users(newly_joined_local_users) - self._federation_queue.send_presence_to_destinations( + await self._federation_queue.send_presence_to_destinations( destinations=prev_remote_hosts | newly_joined_remote_hosts, states=list(local_states.values()), ) @@ -2182,7 +2184,7 @@ class PresenceFederationQueue: index = bisect(self._queue, (clear_before,)) self._queue = self._queue[index:] - def send_presence_to_destinations( + async def send_presence_to_destinations( self, states: Collection[UserPresenceState], destinations: StrCollection ) -> None: """Send the presence states to the given destinations. @@ -2202,7 +2204,7 @@ class PresenceFederationQueue: return if self._federation: - self._federation.send_presence_to_destinations( + await self._federation.send_presence_to_destinations( states=states, destinations=destinations, ) @@ -2325,7 +2327,7 @@ class PresenceFederationQueue: for host, user_ids in hosts_to_users.items(): states = await self._presence_handler.current_state_for_users(user_ids) - self._federation.send_presence_to_destinations( + await self._federation.send_presence_to_destinations( states=states.values(), destinations=[host], ) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 7aeae5319c..4b4227003d 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -26,9 +26,10 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.replication.tcp.streams import TypingStream from synapse.streams import EventSource -from synapse.types import JsonDict, Requester, StreamKeyType, UserID +from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure +from synapse.util.retryutils import filter_destinations_by_retry_limiter from synapse.util.wheel_timer import WheelTimer if TYPE_CHECKING: @@ -150,8 +151,15 @@ class FollowerTypingHandler: now=now, obj=member, then=now + FEDERATION_PING_INTERVAL ) - hosts = await self._storage_controllers.state.get_current_hosts_in_room( - member.room_id + hosts: StrCollection = ( + await self._storage_controllers.state.get_current_hosts_in_room( + member.room_id + ) + ) + hosts = await filter_destinations_by_retry_limiter( + hosts, + clock=self.clock, + store=self.store, ) for domain in hosts: if not self.is_mine_server_name(domain): diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 9ad8e038ae..2f00a7ba20 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -1180,7 +1180,7 @@ class ModuleApi: # Send to remote destinations. destination = UserID.from_string(user).domain - presence_handler.get_federation_queue().send_presence_to_destinations( + await presence_handler.get_federation_queue().send_presence_to_destinations( presence_events, [destination] ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 3b88dc68ea..51285e6d33 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -422,7 +422,7 @@ class FederationSenderHandler: # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. if stream_name == "federation": - send_queue.process_rows_for_federation(self.federation_sender, rows) + await send_queue.process_rows_for_federation(self.federation_sender, rows) await self.update_token(token) # ... and when new receipts happen @@ -439,16 +439,14 @@ class FederationSenderHandler: for row in rows if not row.entity.startswith("@") and not row.is_signature } - for host in hosts: - self.federation_sender.send_device_messages(host, immediate=False) + await self.federation_sender.send_device_messages(hosts, immediate=False) elif stream_name == ToDeviceStream.NAME: # The to_device stream includes stuff to be pushed to both local # clients and remote servers, so we ignore entities that start with # '@' (since they'll be local users rather than destinations). hosts = {row.entity for row in rows if not row.entity.startswith("@")} - for host in hosts: - self.federation_sender.send_device_messages(host) + await self.federation_sender.send_device_messages(hosts) async def _on_new_receipts( self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow] diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 860bbf7c0f..efd21b5bfc 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -14,7 +14,7 @@ import logging from enum import Enum -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, cast import attr from canonicaljson import encode_canonical_json @@ -28,8 +28,8 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore -from synapse.types import JsonDict -from synapse.util.caches.descriptors import cached +from synapse.types import JsonDict, StrCollection +from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: from synapse.server import HomeServer @@ -205,6 +205,26 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): else: return None + @cachedList( + cached_method_name="get_destination_retry_timings", list_name="destinations" + ) + async def get_destination_retry_timings_batch( + self, destinations: StrCollection + ) -> Dict[str, Optional[DestinationRetryTimings]]: + rows = await self.db_pool.simple_select_many_batch( + table="destinations", + iterable=destinations, + column="destination", + retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"), + desc="get_destination_retry_timings_batch", + ) + + return { + row.pop("destination"): DestinationRetryTimings(**row) + for row in rows + if row["retry_last_ts"] and row["failure_ts"] and row["retry_interval"] + } + async def set_destination_retry_timings( self, destination: str, diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 9d2065372c..0e1f907667 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Optional, Type from synapse.api.errors import CodeMessageException from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage import DataStore +from synapse.types import StrCollection from synapse.util import Clock if TYPE_CHECKING: @@ -116,6 +117,30 @@ async def get_retry_limiter( ) +async def filter_destinations_by_retry_limiter( + destinations: StrCollection, + clock: Clock, + store: DataStore, + retry_due_within_ms: int = 0, +) -> StrCollection: + """Filter down the list of destinations to only those that will are either + alive or due for a retry (within `retry_due_within_ms`) + """ + if not destinations: + return destinations + + retry_timings = await store.get_destination_retry_timings_batch(destinations) + + now = int(clock.time_msec()) + + return [ + destination + for destination, timings in retry_timings.items() + if timings is None + or timings.retry_last_ts + timings.retry_interval <= now + retry_due_within_ms + ] + + class RetryDestinationLimiter: def __init__( self, diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 7bd3d06859..caf04b54cb 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -75,7 +75,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): thread_id=None, data={"ts": 1234}, ) - self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) + self.get_success(sender.send_read_receipt(receipt)) self.pump() @@ -111,6 +111,9 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): # * The same room / user on multiple threads. # * A different user in the same room. sender = self.hs.get_federation_sender() + # Hack so that we have a txn in-flight so we batch up read receipts + # below + sender.wake_destination("host2") for user, thread in ( ("alice", None), ("alice", "thread"), @@ -125,9 +128,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): thread_id=thread, data={"ts": 1234}, ) - self.successResultOf( - defer.ensureDeferred(sender.send_read_receipt(receipt)) - ) + defer.ensureDeferred(sender.send_read_receipt(receipt)) self.pump() @@ -191,7 +192,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): thread_id=None, data={"ts": 1234}, ) - self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) + self.get_success(sender.send_read_receipt(receipt)) self.pump() @@ -342,7 +343,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.reactor.advance(1) # a second call should produce no new device EDUs - self.hs.get_federation_sender().send_device_messages("host2") + self.get_success( + self.hs.get_federation_sender().send_device_messages(["host2"]) + ) self.assertEqual(self.edus, []) # a second device @@ -550,7 +553,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # recover the server mock_send_txn.side_effect = self.record_transaction - self.hs.get_federation_sender().send_device_messages("host2") + self.get_success( + self.hs.get_federation_sender().send_device_messages(["host2"]) + ) # We queue up device list updates to be sent over federation, so we # advance to clear the queue. @@ -601,7 +606,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # recover the server mock_send_txn.side_effect = self.record_transaction - self.hs.get_federation_sender().send_device_messages("host2") + self.get_success( + self.hs.get_federation_sender().send_device_messages(["host2"]) + ) # We queue up device list updates to be sent over federation, so we # advance to clear the queue. @@ -656,7 +663,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # recover the server mock_send_txn.side_effect = self.record_transaction - self.hs.get_federation_sender().send_device_messages("host2") + self.get_success( + self.hs.get_federation_sender().send_device_messages(["host2"]) + ) # We queue up device list updates to be sent over federation, so we # advance to clear the queue. diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index a987267308..88a16193a3 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -909,8 +909,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): prev_token = self.queue.get_current_token(self.instance_name) - self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) - self.queue.send_presence_to_destinations((state3,), ("dest3",)) + self.get_success( + self.queue.send_presence_to_destinations( + (state1, state2), ("dest1", "dest2") + ) + ) + self.get_success( + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + ) now_token = self.queue.get_current_token(self.instance_name) @@ -946,11 +952,17 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): prev_token = self.queue.get_current_token(self.instance_name) - self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) + self.get_success( + self.queue.send_presence_to_destinations( + (state1, state2), ("dest1", "dest2") + ) + ) now_token = self.queue.get_current_token(self.instance_name) - self.queue.send_presence_to_destinations((state3,), ("dest3",)) + self.get_success( + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + ) rows, upto_token, limited = self.get_success( self.queue.get_replication_rows("master", prev_token, now_token, 10) @@ -989,8 +1001,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): prev_token = self.queue.get_current_token(self.instance_name) - self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) - self.queue.send_presence_to_destinations((state3,), ("dest3",)) + self.get_success( + self.queue.send_presence_to_destinations( + (state1, state2), ("dest1", "dest2") + ) + ) + self.get_success( + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + ) self.reactor.advance(10 * 60 * 1000) @@ -1005,8 +1023,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): prev_token = self.queue.get_current_token(self.instance_name) - self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) - self.queue.send_presence_to_destinations((state3,), ("dest3",)) + self.get_success( + self.queue.send_presence_to_destinations( + (state1, state2), ("dest1", "dest2") + ) + ) + self.get_success( + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + ) now_token = self.queue.get_current_token(self.instance_name) @@ -1033,11 +1057,17 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): prev_token = self.queue.get_current_token(self.instance_name) - self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) + self.get_success( + self.queue.send_presence_to_destinations( + (state1, state2), ("dest1", "dest2") + ) + ) self.reactor.advance(2 * 60 * 1000) - self.queue.send_presence_to_destinations((state3,), ("dest3",)) + self.get_success( + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + ) self.reactor.advance(4 * 60 * 1000) @@ -1053,8 +1083,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): prev_token = self.queue.get_current_token(self.instance_name) - self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) - self.queue.send_presence_to_destinations((state3,), ("dest3",)) + self.get_success( + self.queue.send_presence_to_destinations( + (state1, state2), ("dest1", "dest2") + ) + ) + self.get_success( + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + ) now_token = self.queue.get_current_token(self.instance_name) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 43c513b157..95106ec8f3 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -120,8 +120,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore = hs.get_datastores().main - self.datastore.get_destination_retry_timings = AsyncMock(return_value=None) - self.datastore.get_device_updates_by_remote = AsyncMock( # type: ignore[method-assign] return_value=(0, []) ) -- cgit 1.5.1