diff options
39 files changed, 496 insertions, 130 deletions
diff --git a/changelog.d/8230.misc b/changelog.d/8230.misc new file mode 100644 index 0000000000..bf0ba76730 --- /dev/null +++ b/changelog.d/8230.misc @@ -0,0 +1 @@ +Track the latest event for every destination and room for catch-up after federation outage. diff --git a/changelog.d/8243.misc b/changelog.d/8243.misc new file mode 100644 index 0000000000..f7375d32d3 --- /dev/null +++ b/changelog.d/8243.misc @@ -0,0 +1 @@ +Remove the 'populate_stats_process_rooms_2' background job and restore functionality to 'populate_stats_process_rooms'. \ No newline at end of file diff --git a/changelog.d/8247.misc b/changelog.d/8247.misc new file mode 100644 index 0000000000..3c27803be4 --- /dev/null +++ b/changelog.d/8247.misc @@ -0,0 +1 @@ +Track the `stream_ordering` of the last successfully-sent event to every destination, so we can use this information to 'catch up' a remote server after an outage. diff --git a/changelog.d/8256.misc b/changelog.d/8256.misc new file mode 100644 index 0000000000..bf0ba76730 --- /dev/null +++ b/changelog.d/8256.misc @@ -0,0 +1 @@ +Track the latest event for every destination and room for catch-up after federation outage. diff --git a/changelog.d/8258.misc b/changelog.d/8258.misc new file mode 100644 index 0000000000..3c27803be4 --- /dev/null +++ b/changelog.d/8258.misc @@ -0,0 +1 @@ +Track the `stream_ordering` of the last successfully-sent event to every destination, so we can use this information to 'catch up' a remote server after an outage. diff --git a/changelog.d/8259.misc b/changelog.d/8259.misc new file mode 100644 index 0000000000..a26779a664 --- /dev/null +++ b/changelog.d/8259.misc @@ -0,0 +1 @@ +Switch to the JSON implementation from the standard library. diff --git a/changelog.d/8261.misc b/changelog.d/8261.misc new file mode 100644 index 0000000000..bc91e9375c --- /dev/null +++ b/changelog.d/8261.misc @@ -0,0 +1 @@ +Simplify tests that mock asynchronous functions. diff --git a/changelog.d/8262.bugfix b/changelog.d/8262.bugfix new file mode 100644 index 0000000000..2b84927de3 --- /dev/null +++ b/changelog.d/8262.bugfix @@ -0,0 +1 @@ +Upgrade canonicaljson to version 1.4.0Â to fix an unicode encoding issue. diff --git a/changelog.d/8265.bugfix b/changelog.d/8265.bugfix new file mode 100644 index 0000000000..981a836d21 --- /dev/null +++ b/changelog.d/8265.bugfix @@ -0,0 +1 @@ +Fix logstanding bug which could lead to incomplete database upgrades on SQLite. diff --git a/changelog.d/8268.bugfix b/changelog.d/8268.bugfix new file mode 100644 index 0000000000..4b15a60253 --- /dev/null +++ b/changelog.d/8268.bugfix @@ -0,0 +1 @@ +Fix stack overflow when stderr is redirected to the logging system, and the logging system encounters an error. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 2a2c9e6f13..bb33345be6 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -15,10 +15,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json from typing import List import jsonschema -from canonicaljson import json from jsonschema import FormatChecker from synapse.api.constants import EventContentFields diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index b6c9085670..7d309b1bb0 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -14,13 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import json import logging import os import sys import tempfile -from canonicaljson import json - from twisted.internet import defer, task import synapse diff --git a/synapse/config/logger.py b/synapse/config/logger.py index c96e6ef62a..13d6f6a3ea 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -17,6 +17,7 @@ import logging import logging.config import os import sys +import threading from string import Template import yaml @@ -25,6 +26,7 @@ from twisted.logger import ( ILogObserver, LogBeginner, STDLibLogObserver, + eventAsText, globalLogBeginner, ) @@ -216,8 +218,9 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner): # system. observer = STDLibLogObserver() - def _log(event): + threadlocal = threading.local() + def _log(event): if "log_text" in event: if event["log_text"].startswith("DNSDatagramProtocol starting on "): return @@ -228,7 +231,25 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner): if event["log_text"].startswith("Timing out client"): return - return observer(event) + # this is a workaround to make sure we don't get stack overflows when the + # logging system raises an error which is written to stderr which is redirected + # to the logging system, etc. + if getattr(threadlocal, "active", False): + # write the text of the event, if any, to the *real* stderr (which may + # be redirected to /dev/null, but there's not much we can do) + try: + event_text = eventAsText(event) + print("logging during logging: %s" % event_text, file=sys.__stderr__) + except Exception: + # gah. + pass + return + + try: + threadlocal.active = True + return observer(event) + finally: + threadlocal.active = False logBeginner.beginLoggingTo([_log], redirectStandardIO=not config.no_redirect_stdio) if not config.no_redirect_stdio: diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 552519e82c..41a726878d 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -209,7 +209,7 @@ class FederationSender: logger.debug("Sending %s to %r", event, destinations) if destinations: - self._send_pdu(event, destinations) + await self._send_pdu(event, destinations) now = self.clock.time_msec() ts = await self.store.get_received_ts(event.event_id) @@ -265,7 +265,7 @@ class FederationSender: finally: self._is_processing = False - def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None: + async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None: # We loop through all destinations to see whether we already have # a transaction in progress. If we do, stick it in the pending_pdus # table and we'll get back to it later. @@ -280,6 +280,13 @@ class FederationSender: sent_pdus_destination_dist_total.inc(len(destinations)) sent_pdus_destination_dist_count.inc() + # track the fact that we have a PDU for these destinations, + # to allow us to perform catch-up later on if the remote is unreachable + # for a while. + await self.store.store_destination_rooms_entries( + destinations, pdu.room_id, pdu.internal_metadata.stream_ordering, + ) + for destination in destinations: self._get_per_destination_queue(destination).send_pdu(pdu) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index defc228c23..9f0852b4a2 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -325,6 +325,17 @@ class PerDestinationQueue: self._last_device_stream_id = device_stream_id self._last_device_list_stream_id = dev_list_id + + if pending_pdus: + # we sent some PDUs and it was successful, so update our + # last_successful_stream_ordering in the destinations table. + final_pdu = pending_pdus[-1] + last_successful_stream_ordering = ( + final_pdu.internal_metadata.stream_ordering + ) + await self._store.set_destination_last_successful_stream_ordering( + self._destination, last_successful_stream_ordering + ) else: break except NotRetryingDestination as e: diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 2d995ec456..ff0c67228b 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -43,7 +43,7 @@ REQUIREMENTS = [ "jsonschema>=2.5.1", "frozendict>=1", "unpaddedbase64>=1.1.0", - "canonicaljson>=1.3.0", + "canonicaljson>=1.4.0", # we use the type definitions added in signedjson 1.1. "signedjson>=1.1.0", "pynacl>=1.2.1", diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ed8a9bffb1..79ec8f119d 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -952,7 +952,7 @@ class DatabasePool: key_names: Collection[str], key_values: Collection[Iterable[Any]], value_names: Collection[str], - value_values: Iterable[Iterable[str]], + value_values: Iterable[Iterable[Any]], ) -> None: """ Upsert, many times. @@ -981,7 +981,7 @@ class DatabasePool: key_names: Iterable[str], key_values: Collection[Iterable[Any]], value_names: Collection[str], - value_values: Iterable[Iterable[str]], + value_values: Iterable[Iterable[Any]], ) -> None: """ Upsert, many times, but without native UPSERT support or batching. diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index ea833829ae..d7a03cbf7d 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -69,6 +69,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): # room_depth # state_groups # state_groups_state + # destination_rooms # we will build a temporary table listing the events so that we don't # have to keep shovelling the list back and forth across the @@ -336,6 +337,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): # and finally, the tables with an index on room_id (or no useful index) for table in ( "current_state_events", + "destination_rooms", "event_backward_extremities", "event_forward_extremities", "event_json", diff --git a/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql b/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql new file mode 100644 index 0000000000..ebfbed7925 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql @@ -0,0 +1,42 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +-- This schema delta alters the schema to enable 'catching up' remote homeservers +-- after there has been a connectivity problem for any reason. + +-- This stores, for each (destination, room) pair, the stream_ordering of the +-- latest event for that destination. +CREATE TABLE IF NOT EXISTS destination_rooms ( + -- the destination in question. + destination TEXT NOT NULL REFERENCES destinations (destination), + -- the ID of the room in question + room_id TEXT NOT NULL REFERENCES rooms (room_id), + -- the stream_ordering of the event + stream_ordering BIGINT NOT NULL, + PRIMARY KEY (destination, room_id) + -- We don't declare a foreign key on stream_ordering here because that'd mean + -- we'd need to either maintain an index (expensive) or do a table scan of + -- destination_rooms whenever we delete an event (also potentially expensive). + -- In addition to that, a foreign key on stream_ordering would be redundant + -- as this row doesn't need to refer to a specific event; if the event gets + -- deleted then it doesn't affect the validity of the stream_ordering here. +); + +-- This index is needed to make it so that a deletion of a room (in the rooms +-- table) can be efficient, as otherwise a table scan would need to be performed +-- to check that no destination_rooms rows point to the room to be deleted. +-- Also: it makes it efficient to delete all the entries for a given room ID, +-- such as when purging a room. +CREATE INDEX IF NOT EXISTS destination_rooms_room_id + ON destination_rooms (room_id); diff --git a/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql new file mode 100644 index 0000000000..55f5d0f732 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +-- This delta file fixes a regression introduced by 58/12room_stats.sql, removing the hacky +-- populate_stats_process_rooms_2 background job and restores the functionality under the +-- original name. +-- See https://github.com/matrix-org/synapse/issues/8238 for details + +DELETE FROM background_updates WHERE update_name = 'populate_stats_process_rooms'; +UPDATE background_updates SET update_name = 'populate_stats_process_rooms' + WHERE update_name = 'populate_stats_process_rooms_2'; diff --git a/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql b/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql new file mode 100644 index 0000000000..a67aa5e500 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql @@ -0,0 +1,21 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This column tracks the stream_ordering of the event that was most recently +-- successfully transmitted to the destination. +-- A value of NULL means that we have not sent an event successfully yet +-- (at least, not since the introduction of this column). +ALTER TABLE destinations + ADD COLUMN last_successful_stream_ordering BIGINT; diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 55a250ef06..30840dbbaa 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -74,9 +74,6 @@ class StatsStore(StateDeltasStore): "populate_stats_process_rooms", self._populate_stats_process_rooms ) self.db_pool.updates.register_background_update_handler( - "populate_stats_process_rooms_2", self._populate_stats_process_rooms_2 - ) - self.db_pool.updates.register_background_update_handler( "populate_stats_process_users", self._populate_stats_process_users ) # we no longer need to perform clean-up, but we will give ourselves @@ -148,31 +145,10 @@ class StatsStore(StateDeltasStore): return len(users_to_work_on) async def _populate_stats_process_rooms(self, progress, batch_size): - """ - This was a background update which regenerated statistics for rooms. - - It has been replaced by StatsStore._populate_stats_process_rooms_2. This background - job has been scheduled to run as part of Synapse v1.0.0, and again now. To ensure - someone upgrading from <v1.0.0, this background task has been turned into a no-op - so that the potentially expensive task is not run twice. - - Further context: https://github.com/matrix-org/synapse/pull/7977 - """ - await self.db_pool.updates._end_background_update( - "populate_stats_process_rooms" - ) - return 1 - - async def _populate_stats_process_rooms_2(self, progress, batch_size): - """ - This is a background update which regenerates statistics for rooms. - - It replaces StatsStore._populate_stats_process_rooms. See its docstring for the - reasoning. - """ + """This is a background update which regenerates statistics for rooms.""" if not self.stats_enabled: await self.db_pool.updates._end_background_update( - "populate_stats_process_rooms_2" + "populate_stats_process_rooms" ) return 1 @@ -189,13 +165,13 @@ class StatsStore(StateDeltasStore): return [r for r, in txn] rooms_to_work_on = await self.db_pool.runInteraction( - "populate_stats_rooms_2_get_batch", _get_next_batch + "populate_stats_rooms_get_batch", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: await self.db_pool.updates._end_background_update( - "populate_stats_process_rooms_2" + "populate_stats_process_rooms" ) return 1 @@ -204,9 +180,9 @@ class StatsStore(StateDeltasStore): progress["last_room_id"] = room_id await self.db_pool.runInteraction( - "_populate_stats_process_rooms_2", + "_populate_stats_process_rooms", self.db_pool.updates._background_update_progress_txn, - "populate_stats_process_rooms_2", + "populate_stats_process_rooms", progress, ) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 5b31aab700..c0a958252e 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -15,13 +15,14 @@ import logging from collections import namedtuple -from typing import Optional, Tuple +from typing import Iterable, Optional, Tuple from canonicaljson import encode_canonical_json from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import JsonDict from synapse.util.caches.expiringcache import ExpiringCache @@ -164,7 +165,9 @@ class TransactionStore(SQLBaseStore): allow_none=True, ) - if result and result["retry_last_ts"] > 0: + # check we have a row and retry_last_ts is not null or zero + # (retry_last_ts can't be negative) + if result and result["retry_last_ts"]: return result else: return None @@ -273,3 +276,98 @@ class TransactionStore(SQLBaseStore): await self.db_pool.runInteraction( "_cleanup_transactions", _cleanup_transactions_txn ) + + async def store_destination_rooms_entries( + self, destinations: Iterable[str], room_id: str, stream_ordering: int, + ) -> None: + """ + Updates or creates `destination_rooms` entries in batch for a single event. + + Args: + destinations: list of destinations + room_id: the room_id of the event + stream_ordering: the stream_ordering of the event + """ + + return await self.db_pool.runInteraction( + "store_destination_rooms_entries", + self._store_destination_rooms_entries_txn, + destinations, + room_id, + stream_ordering, + ) + + def _store_destination_rooms_entries_txn( + self, + txn: LoggingTransaction, + destinations: Iterable[str], + room_id: str, + stream_ordering: int, + ) -> None: + + # ensure we have a `destinations` row for this destination, as there is + # a foreign key constraint. + if isinstance(self.database_engine, PostgresEngine): + q = """ + INSERT INTO destinations (destination) + VALUES (?) + ON CONFLICT DO NOTHING; + """ + elif isinstance(self.database_engine, Sqlite3Engine): + q = """ + INSERT OR IGNORE INTO destinations (destination) + VALUES (?); + """ + else: + raise RuntimeError("Unknown database engine") + + txn.execute_batch(q, ((destination,) for destination in destinations)) + + rows = [(destination, room_id) for destination in destinations] + + self.db_pool.simple_upsert_many_txn( + txn, + "destination_rooms", + ["destination", "room_id"], + rows, + ["stream_ordering"], + [(stream_ordering,)] * len(rows), + ) + + async def get_destination_last_successful_stream_ordering( + self, destination: str + ) -> Optional[int]: + """ + Gets the stream ordering of the PDU most-recently successfully sent + to the specified destination, or None if this information has not been + tracked yet. + + Args: + destination: the destination to query + """ + return await self.db_pool.simple_select_one_onecol( + "destinations", + {"destination": destination}, + "last_successful_stream_ordering", + allow_none=True, + desc="get_last_successful_stream_ordering", + ) + + async def set_destination_last_successful_stream_ordering( + self, destination: str, last_successful_stream_ordering: int + ) -> None: + """ + Marks that we have successfully sent the PDUs up to and including the + one specified. + + Args: + destination: the destination we have successfully sent to + last_successful_stream_ordering: the stream_ordering of the most + recent successfully-sent PDU + """ + return await self.db_pool.simple_upsert( + "destinations", + keyvalues={"destination": destination}, + values={"last_successful_stream_ordering": last_successful_stream_ordering}, + desc="set_last_successful_stream_ordering", + ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index ee60e2a718..a7f2dfb850 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -19,12 +19,15 @@ import logging import os import re from collections import Counter -from typing import TextIO +from typing import Optional, TextIO import attr +from synapse.config.homeserver import HomeServerConfig +from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.engines.postgres import PostgresEngine -from synapse.storage.types import Cursor +from synapse.storage.types import Connection, Cursor +from synapse.types import Collection logger = logging.getLogger(__name__) @@ -63,7 +66,12 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = ( ) -def prepare_database(db_conn, database_engine, config, databases=["main", "state"]): +def prepare_database( + db_conn: Connection, + database_engine: BaseDatabaseEngine, + config: Optional[HomeServerConfig], + databases: Collection[str] = ["main", "state"], +): """Prepares a physical database for usage. Will either create all necessary tables or upgrade from an older schema version. @@ -73,16 +81,24 @@ def prepare_database(db_conn, database_engine, config, databases=["main", "state Args: db_conn: database_engine: - config (synapse.config.homeserver.HomeServerConfig|None): + config : application config, or None if we are connecting to an existing database which we expect to be configured already - databases (list[str]): The name of the databases that will be used + databases: The name of the databases that will be used with this physical database. Defaults to all databases. """ try: cur = db_conn.cursor() + # sqlite does not automatically start transactions for DDL / SELECT statements, + # so we start one before running anything. This ensures that any upgrades + # are either applied completely, or not at all. + # + # (psycopg2 automatically starts a transaction as soon as we run any statements + # at all, so this is redundant but harmless there.) + cur.execute("BEGIN TRANSACTION") + logger.info("%r: Checking existing schema version", databases) version_info = _get_or_create_schema_state(cur, database_engine) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 3ad4b28fc7..b2355700ad 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import re import attr -from canonicaljson import json from twisted.internet import defer, task diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 0e445e01d7..bf094c9386 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from canonicaljson import json +import json + from frozendict import frozendict @@ -66,5 +67,5 @@ def _handle_frozendict(obj): # A JSONEncoder which is capable of encoding frozendicts without barfing. # Additionally reduce the whitespace produced by JSON encoding. frozendict_json_encoder = json.JSONEncoder( - default=_handle_frozendict, separators=(",", ":"), + allow_nan=False, separators=(",", ":"), default=_handle_frozendict, ) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 3d880c499d..1471cc1a28 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -77,11 +77,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock( - side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) - ) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -110,11 +108,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock( - side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) - ) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -150,11 +146,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(None) - ) + fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) handler.federation_handler.do_invite_join = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + return_value=make_awaitable(("", 1)) ) # Artificially raise the complexity @@ -208,11 +202,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock( - side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) - ) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -240,11 +232,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock( - side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) - ) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py new file mode 100644 index 0000000000..6cdcc378f0 --- /dev/null +++ b/tests/federation/test_federation_catch_up.py @@ -0,0 +1,158 @@ +from mock import Mock + +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.test_utils import event_injection, make_awaitable +from tests.unittest import FederatingHomeserverTestCase, override_config + + +class FederationCatchUpTestCases(FederatingHomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + federation_transport_client=Mock(spec=["send_transaction"]), + ) + + def prepare(self, reactor, clock, hs): + # stub out get_current_hosts_in_room + state_handler = hs.get_state_handler() + + # This mock is crucial for destination_rooms to be populated. + state_handler.get_current_hosts_in_room = Mock( + return_value=make_awaitable(["test", "host2"]) + ) + + # whenever send_transaction is called, record the pdu data + self.pdus = [] + self.failed_pdus = [] + self.is_online = True + self.hs.get_federation_transport_client().send_transaction.side_effect = ( + self.record_transaction + ) + + async def record_transaction(self, txn, json_cb): + if self.is_online: + data = json_cb() + self.pdus.extend(data["pdus"]) + return {} + else: + data = json_cb() + self.failed_pdus.extend(data["pdus"]) + raise IOError("Failed to connect because this is a test!") + + def get_destination_room(self, room: str, destination: str = "host2") -> dict: + """ + Gets the destination_rooms entry for a (destination, room_id) pair. + + Args: + room: room ID + destination: what destination, default is "host2" + + Returns: + Dictionary of { event_id: str, stream_ordering: int } + """ + event_id, stream_ordering = self.get_success( + self.hs.get_datastore().db_pool.execute( + "test:get_destination_rooms", + None, + """ + SELECT event_id, stream_ordering + FROM destination_rooms dr + JOIN events USING (stream_ordering) + WHERE dr.destination = ? AND dr.room_id = ? + """, + destination, + room, + ) + )[0] + return {"event_id": event_id, "stream_ordering": stream_ordering} + + @override_config({"send_federation": True}) + def test_catch_up_destination_rooms_tracking(self): + """ + Tests that we populate the `destination_rooms` table as needed. + """ + self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room = self.helper.create_room_as("u1", tok=u1_token) + + self.get_success( + event_injection.inject_member_event(self.hs, room, "@user:host2", "join") + ) + + event_id_1 = self.helper.send(room, "wombats!", tok=u1_token)["event_id"] + + row_1 = self.get_destination_room(room) + + event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"] + + row_2 = self.get_destination_room(room) + + # check: events correctly registered in order + self.assertEqual(row_1["event_id"], event_id_1) + self.assertEqual(row_2["event_id"], event_id_2) + self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1) + + @override_config({"send_federation": True}) + def test_catch_up_last_successful_stream_ordering_tracking(self): + """ + Tests that we populate the `destination_rooms` table as needed. + """ + self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room = self.helper.create_room_as("u1", tok=u1_token) + + # take the remote offline + self.is_online = False + + self.get_success( + event_injection.inject_member_event(self.hs, room, "@user:host2", "join") + ) + + self.helper.send(room, "wombats!", tok=u1_token) + self.pump() + + lsso_1 = self.get_success( + self.hs.get_datastore().get_destination_last_successful_stream_ordering( + "host2" + ) + ) + + self.assertIsNone( + lsso_1, + "There should be no last successful stream ordering for an always-offline destination", + ) + + # bring the remote online + self.is_online = True + + event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"] + + lsso_2 = self.get_success( + self.hs.get_datastore().get_destination_last_successful_stream_ordering( + "host2" + ) + ) + row_2 = self.get_destination_room(room) + + self.assertEqual( + self.pdus[0]["content"]["body"], + "rabbits!", + "Test fault: didn't receive the right PDU", + ) + self.assertEqual( + row_2["event_id"], + event_id_2, + "Test fault: destination_rooms not updated correctly", + ) + self.assertEqual( + lsso_2, + row_2["stream_ordering"], + "Send succeeded but not marked as last_successful_stream_ordering", + ) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 5f512ff8bf..917762e6b6 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -34,7 +34,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): mock_state_handler = Mock(spec=["get_current_hosts_in_room"]) # Ensure a new Awaitable is created for each call. - mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable( + mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable( ["test", "host2"] ) return self.setup_test_homeserver( diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index c7efd3822d..97877c2e42 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -143,7 +143,7 @@ class AuthTestCase(unittest.TestCase): def test_mau_limits_exceeded_large(self): self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.large_number_of_users) + return_value=make_awaitable(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): @@ -154,7 +154,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.large_number_of_users) + return_value=make_awaitable(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -169,7 +169,7 @@ class AuthTestCase(unittest.TestCase): # If not in monthly active cohort self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) + return_value=make_awaitable(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -179,7 +179,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) + return_value=make_awaitable(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -189,10 +189,10 @@ class AuthTestCase(unittest.TestCase): ) # If in monthly active cohort self.hs.get_datastore().user_last_seen_monthly_active = Mock( - side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec()) + return_value=make_awaitable(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) + return_value=make_awaitable(self.auth_blocking._max_mau_value) ) yield defer.ensureDeferred( self.auth_handler.get_access_token_for_user_id( @@ -200,10 +200,10 @@ class AuthTestCase(unittest.TestCase): ) ) self.hs.get_datastore().user_last_seen_monthly_active = Mock( - side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec()) + return_value=make_awaitable(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) + return_value=make_awaitable(self.auth_blocking._max_mau_value) ) yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id( @@ -216,7 +216,7 @@ class AuthTestCase(unittest.TestCase): self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.small_number_of_users) + return_value=make_awaitable(self.small_number_of_users) ) # Ensure does not raise exception yield defer.ensureDeferred( @@ -226,7 +226,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.small_number_of_users) + return_value=make_awaitable(self.small_number_of_users) ) yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index eddf5e2498..cb7c0ed51a 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -100,7 +100,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_get_or_create_user_mau_not_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.count_monthly_users = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value - 1) + return_value=make_awaitable(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "c", "User")) @@ -108,7 +108,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_get_or_create_user_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.lots_of_users) + return_value=make_awaitable(self.lots_of_users) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -116,7 +116,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.max_mau_value) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -126,14 +126,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_register_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.lots_of_users) + return_value=make_awaitable(self.lots_of_users) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError ) self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.max_mau_value) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index a609f148c0..312c0a0d41 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -54,7 +54,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store.db_pool.simple_insert( "background_updates", { - "update_name": "populate_stats_process_rooms_2", + "update_name": "populate_stats_process_rooms", "progress_json": "{}", "depends_on": "populate_stats_prepare", }, @@ -66,7 +66,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): { "update_name": "populate_stats_process_users", "progress_json": "{}", - "depends_on": "populate_stats_process_rooms_2", + "depends_on": "populate_stats_process_rooms", }, ) ) @@ -219,10 +219,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.get_success( self.store.db_pool.simple_insert( "background_updates", - { - "update_name": "populate_stats_process_rooms_2", - "progress_json": "{}", - }, + {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, ) ) self.get_success( @@ -231,7 +228,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): { "update_name": "populate_stats_cleanup", "progress_json": "{}", - "depends_on": "populate_stats_process_rooms_2", + "depends_on": "populate_stats_process_rooms", }, ) ) @@ -728,7 +725,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store.db_pool.simple_insert( "background_updates", { - "update_name": "populate_stats_process_rooms_2", + "update_name": "populate_stats_process_rooms", "progress_json": "{}", "depends_on": "populate_stats_prepare", }, @@ -740,7 +737,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): { "update_name": "populate_stats_process_users", "progress_json": "{}", - "depends_on": "populate_stats_process_rooms_2", + "depends_on": "populate_stats_process_rooms", }, ) ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 7bf15c4ba9..ae6bc24f4c 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -116,7 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): retry_timings_res ) - self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable( + self.datastore.get_device_updates_by_remote.return_value = make_awaitable( (0, []) ) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 8b4982ecb1..1d7edee5ba 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -45,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new event. """ mock_client = Mock(spec=["put_json"]) - mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({}) + mock_client.put_json.return_value = make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", @@ -73,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new events. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) + mock_client1.put_json.return_value = make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -85,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) + mock_client2.put_json.return_value = make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -136,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new typing EDUs. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) + mock_client1.put_json.return_value = make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -148,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) + mock_client2.put_json.return_value = make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 160c630235..b8b7758d24 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -337,7 +337,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -591,7 +591,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -631,7 +631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 973338ea71..6382b19dc3 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -67,7 +67,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): raise Exception("Failed to find reference to ResourceLimitsServerNotices") self._rlsn._store.user_last_seen_monthly_active = Mock( - side_effect=lambda user_id: make_awaitable(1000) + return_value=make_awaitable(1000) ) self._rlsn._server_notices_manager.send_notice = Mock( return_value=defer.succeed(Mock()) @@ -80,9 +80,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): return_value=defer.succeed("!something:localhost") ) self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) - self._rlsn._store.get_tags_for_room = Mock( - side_effect=lambda user_id, room_id: make_awaitable({}) - ) + self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) @override_config({"hs_disabled": True}) def test_maybe_send_server_notice_disabled_hs(self): @@ -158,7 +156,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self._rlsn._store.user_last_seen_monthly_active = Mock( - side_effect=lambda user_id: make_awaitable(None) + return_value=make_awaitable(None) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -261,12 +259,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.user_id = "@user_id:test" def test_server_notice_only_sent_once(self): - self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(1000) - ) + self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000)) self.store.user_last_seen_monthly_active = Mock( - side_effect=lambda user_id: make_awaitable(1000) + return_value=make_awaitable(1000) ) # Call the function multiple times to ensure we only send the notice once diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 370c247e16..755c70db31 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -154,7 +154,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): user_id = "@user:server" self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(lots_of_users) + return_value=make_awaitable(lots_of_users) ) self.get_success( self.store.insert_client_ip( diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 9870c74883..643072bbaf 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -231,9 +231,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ) self.get_success(d) - self.store.upsert_monthly_active_user = Mock( - side_effect=lambda user_id: make_awaitable(None) - ) + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) d = self.store.populate_monthly_active_users(user_id) self.get_success(d) @@ -241,9 +239,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_not_called() def test_populate_monthly_users_should_update(self): - self.store.upsert_monthly_active_user = Mock( - side_effect=lambda user_id: make_awaitable(None) - ) + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) self.store.is_trial_user = Mock(return_value=defer.succeed(False)) @@ -256,9 +252,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_called_once() def test_populate_monthly_users_should_not_update(self): - self.store.upsert_monthly_active_user = Mock( - side_effect=lambda user_id: make_awaitable(None) - ) + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.user_last_seen_monthly_active = Mock( @@ -344,9 +338,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self): - self.store.upsert_monthly_active_user = Mock( - side_effect=lambda user_id: make_awaitable(None) - ) + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) self.get_success(self.store.populate_monthly_active_users("@user:sever")) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 508aeba078..a298cc0fd3 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -17,6 +17,7 @@ """ Utilities for running the unit tests """ +from asyncio import Future from typing import Any, Awaitable, TypeVar TV = TypeVar("TV") @@ -38,6 +39,12 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: raise Exception("awaitable has not yet completed") -async def make_awaitable(result: Any): - """Create an awaitable that just returns a result.""" - return result +def make_awaitable(result: Any) -> Awaitable[Any]: + """ + Makes an awaitable, suitable for mocking an `async` function. + This uses Futures as they can be awaited multiple times so can be returned + to multiple callers. + """ + future = Future() # type: ignore + future.set_result(result) + return future |