From 5090f26b636bf4439575767a2272d033fb33b2d5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 14 May 2021 11:12:36 +0100 Subject: Minor `@cachedList` enhancements (#9975) - use a tuple rather than a list for the iterable that is passed into the wrapped function, for performance - test that we can pass an iterable and that keys are correctly deduped. --- tests/util/caches/test_descriptors.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) (limited to 'tests/util') diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 178ac8a68c..bbbc276697 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -666,18 +666,20 @@ class CachedListDescriptorTestCase(unittest.TestCase): with LoggingContext("c1") as c1: obj = Cls() obj.mock.return_value = {10: "fish", 20: "chips"} + + # start the lookup off d1 = obj.list_fn([10, 20], 2) self.assertEqual(current_context(), SENTINEL_CONTEXT) r = yield d1 self.assertEqual(current_context(), c1) - obj.mock.assert_called_once_with([10, 20], 2) + obj.mock.assert_called_once_with((10, 20), 2) self.assertEqual(r, {10: "fish", 20: "chips"}) obj.mock.reset_mock() # a call with different params should call the mock again obj.mock.return_value = {30: "peas"} r = yield obj.list_fn([20, 30], 2) - obj.mock.assert_called_once_with([30], 2) + obj.mock.assert_called_once_with((30,), 2) self.assertEqual(r, {20: "chips", 30: "peas"}) obj.mock.reset_mock() @@ -692,6 +694,15 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj.mock.assert_not_called() self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"}) + # we should also be able to use a (single-use) iterable, and should + # deduplicate the keys + obj.mock.reset_mock() + obj.mock.return_value = {40: "gravy"} + iterable = (x for x in [10, 40, 40]) + r = yield obj.list_fn(iterable, 2) + obj.mock.assert_called_once_with((40,), 2) + self.assertEqual(r, {10: "fish", 40: "gravy"}) + @defer.inlineCallbacks def test_invalidate(self): """Make sure that invalidation callbacks are called.""" @@ -717,7 +728,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): # cache miss obj.mock.return_value = {10: "fish", 20: "chips"} r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0) - obj.mock.assert_called_once_with([10, 20], 2) + obj.mock.assert_called_once_with((10, 20), 2) self.assertEqual(r1, {10: "fish", 20: "chips"}) obj.mock.reset_mock() -- cgit 1.5.1 From 7958eadcd16087e6aaf7cce240c1f82856e0bcc7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 21 May 2021 11:20:51 +0100 Subject: Add a batching queue implementation. (#10017) --- changelog.d/10017.misc | 1 + synapse/util/batching_queue.py | 153 ++++++++++++++++++++++++++++++++++ tests/util/test_batching_queue.py | 169 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 323 insertions(+) create mode 100644 changelog.d/10017.misc create mode 100644 synapse/util/batching_queue.py create mode 100644 tests/util/test_batching_queue.py (limited to 'tests/util') diff --git a/changelog.d/10017.misc b/changelog.d/10017.misc new file mode 100644 index 0000000000..4777b7fb57 --- /dev/null +++ b/changelog.d/10017.misc @@ -0,0 +1 @@ +Add a batching queue implementation. diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py new file mode 100644 index 0000000000..44bbb7b1a8 --- /dev/null +++ b/synapse/util/batching_queue.py @@ -0,0 +1,153 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import ( + Awaitable, + Callable, + Dict, + Generic, + Hashable, + List, + Set, + Tuple, + TypeVar, +) + +from twisted.internet import defer + +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import Clock + +logger = logging.getLogger(__name__) + + +V = TypeVar("V") +R = TypeVar("R") + + +class BatchingQueue(Generic[V, R]): + """A queue that batches up work, calling the provided processing function + with all pending work (for a given key). + + The provided processing function will only be called once at a time for each + key. It will be called the next reactor tick after `add_to_queue` has been + called, and will keep being called until the queue has been drained (for the + given key). + + Note that the return value of `add_to_queue` will be the return value of the + processing function that processed the given item. This means that the + returned value will likely include data for other items that were in the + batch. + """ + + def __init__( + self, + name: str, + clock: Clock, + process_batch_callback: Callable[[List[V]], Awaitable[R]], + ): + self._name = name + self._clock = clock + + # The set of keys currently being processed. + self._processing_keys = set() # type: Set[Hashable] + + # The currently pending batch of values by key, with a Deferred to call + # with the result of the corresponding `_process_batch_callback` call. + self._next_values = {} # type: Dict[Hashable, List[Tuple[V, defer.Deferred]]] + + # The function to call with batches of values. + self._process_batch_callback = process_batch_callback + + LaterGauge( + "synapse_util_batching_queue_number_queued", + "The number of items waiting in the queue across all keys", + labels=("name",), + caller=lambda: sum(len(v) for v in self._next_values.values()), + ) + + LaterGauge( + "synapse_util_batching_queue_number_of_keys", + "The number of distinct keys that have items queued", + labels=("name",), + caller=lambda: len(self._next_values), + ) + + async def add_to_queue(self, value: V, key: Hashable = ()) -> R: + """Adds the value to the queue with the given key, returning the result + of the processing function for the batch that included the given value. + + The optional `key` argument allows sharding the queue by some key. The + queues will then be processed in parallel, i.e. the process batch + function will be called in parallel with batched values from a single + key. + """ + + # First we create a defer and add it and the value to the list of + # pending items. + d = defer.Deferred() + self._next_values.setdefault(key, []).append((value, d)) + + # If we're not currently processing the key fire off a background + # process to start processing. + if key not in self._processing_keys: + run_as_background_process(self._name, self._process_queue, key) + + return await make_deferred_yieldable(d) + + async def _process_queue(self, key: Hashable) -> None: + """A background task to repeatedly pull things off the queue for the + given key and call the `self._process_batch_callback` with the values. + """ + + try: + if key in self._processing_keys: + return + + self._processing_keys.add(key) + + while True: + # We purposefully wait a reactor tick to allow us to batch + # together requests that we're about to receive. A common + # pattern is to call `add_to_queue` multiple times at once, and + # deferring to the next reactor tick allows us to batch all of + # those up. + await self._clock.sleep(0) + + next_values = self._next_values.pop(key, []) + if not next_values: + # We've exhausted the queue. + break + + try: + values = [value for value, _ in next_values] + results = await self._process_batch_callback(values) + + for _, deferred in next_values: + with PreserveLoggingContext(): + deferred.callback(results) + + except Exception as e: + for _, deferred in next_values: + if deferred.called: + continue + + with PreserveLoggingContext(): + deferred.errback(e) + + finally: + self._processing_keys.discard(key) diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py new file mode 100644 index 0000000000..5def1e56c9 --- /dev/null +++ b/tests/util/test_batching_queue.py @@ -0,0 +1,169 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +from synapse.logging.context import make_deferred_yieldable +from synapse.util.batching_queue import BatchingQueue + +from tests.server import get_clock +from tests.unittest import TestCase + + +class BatchingQueueTestCase(TestCase): + def setUp(self): + self.clock, hs_clock = get_clock() + + self._pending_calls = [] + self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue) + + async def _process_queue(self, values): + d = defer.Deferred() + self._pending_calls.append((values, d)) + return await make_deferred_yieldable(d) + + def test_simple(self): + """Tests the basic case of calling `add_to_queue` once and having + `_process_queue` return. + """ + + self.assertFalse(self._pending_calls) + + queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo")) + + # The queue should wait a reactor tick before calling the processing + # function. + self.assertFalse(self._pending_calls) + self.assertFalse(queue_d.called) + + # We should see a call to `_process_queue` after a reactor tick. + self.clock.pump([0]) + + self.assertEqual(len(self._pending_calls), 1) + self.assertEqual(self._pending_calls[0][0], ["foo"]) + self.assertFalse(queue_d.called) + + # Return value of the `_process_queue` should be propagated back. + self._pending_calls.pop()[1].callback("bar") + + self.assertEqual(self.successResultOf(queue_d), "bar") + + def test_batching(self): + """Test that multiple calls at the same time get batched up into one + call to `_process_queue`. + """ + + self.assertFalse(self._pending_calls) + + queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1")) + queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2")) + + self.clock.pump([0]) + + # We should see only *one* call to `_process_queue` + self.assertEqual(len(self._pending_calls), 1) + self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"]) + self.assertFalse(queue_d1.called) + self.assertFalse(queue_d2.called) + + # Return value of the `_process_queue` should be propagated back to both. + self._pending_calls.pop()[1].callback("bar") + + self.assertEqual(self.successResultOf(queue_d1), "bar") + self.assertEqual(self.successResultOf(queue_d2), "bar") + + def test_queuing(self): + """Test that we queue up requests while a `_process_queue` is being + called. + """ + + self.assertFalse(self._pending_calls) + + queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1")) + self.clock.pump([0]) + + queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2")) + + # We should see only *one* call to `_process_queue` + self.assertEqual(len(self._pending_calls), 1) + self.assertEqual(self._pending_calls[0][0], ["foo1"]) + self.assertFalse(queue_d1.called) + self.assertFalse(queue_d2.called) + + # Return value of the `_process_queue` should be propagated back to the + # first. + self._pending_calls.pop()[1].callback("bar1") + + self.assertEqual(self.successResultOf(queue_d1), "bar1") + self.assertFalse(queue_d2.called) + + # We should now see a second call to `_process_queue` + self.clock.pump([0]) + self.assertEqual(len(self._pending_calls), 1) + self.assertEqual(self._pending_calls[0][0], ["foo2"]) + self.assertFalse(queue_d2.called) + + # Return value of the `_process_queue` should be propagated back to the + # second. + self._pending_calls.pop()[1].callback("bar2") + + self.assertEqual(self.successResultOf(queue_d2), "bar2") + + def test_different_keys(self): + """Test that calls to different keys get processed in parallel.""" + + self.assertFalse(self._pending_calls) + + queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1", key=1)) + self.clock.pump([0]) + queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2", key=2)) + self.clock.pump([0]) + + # We queue up another item with key=2 to check that we will keep taking + # things off the queue. + queue_d3 = defer.ensureDeferred(self.queue.add_to_queue("foo3", key=2)) + + # We should see two calls to `_process_queue` + self.assertEqual(len(self._pending_calls), 2) + self.assertEqual(self._pending_calls[0][0], ["foo1"]) + self.assertEqual(self._pending_calls[1][0], ["foo2"]) + self.assertFalse(queue_d1.called) + self.assertFalse(queue_d2.called) + self.assertFalse(queue_d3.called) + + # Return value of the `_process_queue` should be propagated back to the + # first. + self._pending_calls.pop(0)[1].callback("bar1") + + self.assertEqual(self.successResultOf(queue_d1), "bar1") + self.assertFalse(queue_d2.called) + self.assertFalse(queue_d3.called) + + # Return value of the `_process_queue` should be propagated back to the + # second. + self._pending_calls.pop()[1].callback("bar2") + + self.assertEqual(self.successResultOf(queue_d2), "bar2") + self.assertFalse(queue_d3.called) + + # We should now see a call `_pending_calls` for `foo3` + self.clock.pump([0]) + self.assertEqual(len(self._pending_calls), 1) + self.assertEqual(self._pending_calls[0][0], ["foo3"]) + self.assertFalse(queue_d3.called) + + # Return value of the `_process_queue` should be propagated back to the + # third deferred. + self._pending_calls.pop()[1].callback("bar4") + + self.assertEqual(self.successResultOf(queue_d3), "bar4") -- cgit 1.5.1 From 3e831f24ffc887e174f67ff7b1cfe3a429b7b5c1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 21 May 2021 17:57:08 +0100 Subject: Don't hammer the database for destination retry timings every ~5mins (#10036) --- changelog.d/10036.misc | 1 + synapse/app/generic_worker.py | 2 - synapse/federation/transport/server.py | 2 +- synapse/replication/slave/storage/transactions.py | 21 -------- synapse/storage/databases/main/__init__.py | 4 +- synapse/storage/databases/main/transactions.py | 66 +++++++++++++---------- synapse/util/retryutils.py | 8 ++- tests/handlers/test_typing.py | 8 +-- tests/storage/test_transactions.py | 8 ++- tests/util/test_retryutils.py | 18 ++++--- 10 files changed, 62 insertions(+), 76 deletions(-) create mode 100644 changelog.d/10036.misc delete mode 100644 synapse/replication/slave/storage/transactions.py (limited to 'tests/util') diff --git a/changelog.d/10036.misc b/changelog.d/10036.misc new file mode 100644 index 0000000000..d2cf1e5473 --- /dev/null +++ b/changelog.d/10036.misc @@ -0,0 +1 @@ +Properly invalidate caches for destination retry timings every (instead of expiring entries every 5 minutes). diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index f730cdbd78..91ad326f19 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -61,7 +61,6 @@ from synapse.replication.slave.storage.pushers import SlavedPusherStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.client.v1 import events, login, presence, room from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet @@ -237,7 +236,6 @@ class GenericWorkerSlavedStore( DirectoryStore, SlavedApplicationServiceStore, SlavedRegistrationStore, - SlavedTransactionStore, SlavedProfileStore, SlavedClientIpStore, SlavedFilteringStore, diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index c17a085a4f..9d50b05d01 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -160,7 +160,7 @@ class Authenticator: # If we get a valid signed request from the other side, its probably # alive retry_timings = await self.store.get_destination_retry_timings(origin) - if retry_timings and retry_timings["retry_last_ts"]: + if retry_timings and retry_timings.retry_last_ts: run_in_background(self._reset_retry_timings, origin) return origin diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py deleted file mode 100644 index a59e543924..0000000000 --- a/synapse/replication/slave/storage/transactions.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.storage.databases.main.transactions import TransactionStore - -from ._base import BaseSlavedStore - - -class SlavedTransactionStore(TransactionStore, BaseSlavedStore): - pass diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 49c7606d51..9cce62ae6c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -67,7 +67,7 @@ from .state import StateStore from .stats import StatsStore from .stream import StreamStore from .tags import TagsStore -from .transactions import TransactionStore +from .transactions import TransactionWorkerStore from .ui_auth import UIAuthStore from .user_directory import UserDirectoryStore from .user_erasure_store import UserErasureStore @@ -83,7 +83,7 @@ class DataStore( StreamStore, ProfileStore, PresenceStore, - TransactionStore, + TransactionWorkerStore, DirectoryStore, KeyStore, StateStore, diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 82335e7a9d..d211c423b2 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -16,13 +16,15 @@ import logging from collections import namedtuple from typing import Iterable, List, Optional, Tuple +import attr from canonicaljson import encode_canonical_json from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage._base import db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.types import JsonDict -from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.caches.descriptors import cached db_binary_type = memoryview @@ -38,10 +40,23 @@ _UpdateTransactionRow = namedtuple( "_TransactionRow", ("response_code", "response_json") ) -SENTINEL = object() +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DestinationRetryTimings: + """The current destination retry timing info for a remote server.""" -class TransactionWorkerStore(SQLBaseStore): + # The first time we tried and failed to reach the remote server, in ms. + failure_ts: int + + # The last time we tried and failed to reach the remote server, in ms. + retry_last_ts: int + + # How long since the last time we tried to reach the remote server before + # trying again, in ms. + retry_interval: int + + +class TransactionWorkerStore(CacheInvalidationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -60,19 +75,6 @@ class TransactionWorkerStore(SQLBaseStore): "_cleanup_transactions", _cleanup_transactions_txn ) - -class TransactionStore(TransactionWorkerStore): - """A collection of queries for handling PDUs.""" - - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - - self._destination_retry_cache = ExpiringCache( - cache_name="get_destination_retry_timings", - clock=self._clock, - expiry_ms=5 * 60 * 1000, - ) - async def get_received_txn_response( self, transaction_id: str, origin: str ) -> Optional[Tuple[int, JsonDict]]: @@ -145,7 +147,11 @@ class TransactionStore(TransactionWorkerStore): desc="set_received_txn_response", ) - async def get_destination_retry_timings(self, destination): + @cached(max_entries=10000) + async def get_destination_retry_timings( + self, + destination: str, + ) -> Optional[DestinationRetryTimings]: """Gets the current retry timings (if any) for a given destination. Args: @@ -156,34 +162,29 @@ class TransactionStore(TransactionWorkerStore): Otherwise a dict for the retry scheme """ - result = self._destination_retry_cache.get(destination, SENTINEL) - if result is not SENTINEL: - return result - result = await self.db_pool.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination, ) - # We don't hugely care about race conditions between getting and - # invalidating the cache, since we time out fairly quickly anyway. - self._destination_retry_cache[destination] = result return result - def _get_destination_retry_timings(self, txn, destination): + def _get_destination_retry_timings( + self, txn, destination: str + ) -> Optional[DestinationRetryTimings]: result = self.db_pool.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, - retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"), + retcols=("failure_ts", "retry_last_ts", "retry_interval"), allow_none=True, ) # check we have a row and retry_last_ts is not null or zero # (retry_last_ts can't be negative) if result and result["retry_last_ts"]: - return result + return DestinationRetryTimings(**result) else: return None @@ -204,7 +205,6 @@ class TransactionStore(TransactionWorkerStore): retry_interval: how long until next retry in ms """ - self._destination_retry_cache.pop(destination, None) if self.database_engine.can_native_upsert: return await self.db_pool.runInteraction( "set_destination_retry_timings", @@ -252,6 +252,10 @@ class TransactionStore(TransactionWorkerStore): txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval)) + self._invalidate_cache_and_stream( + txn, self.get_destination_retry_timings, (destination,) + ) + def _set_destination_retry_timings_emulated( self, txn, destination, failure_ts, retry_last_ts, retry_interval ): @@ -295,6 +299,10 @@ class TransactionStore(TransactionWorkerStore): }, ) + self._invalidate_cache_and_stream( + txn, self.get_destination_retry_timings, (destination,) + ) + async def store_destination_rooms_entries( self, destinations: Iterable[str], diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index f9c370a814..129b47cd49 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -82,11 +82,9 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k retry_timings = await store.get_destination_retry_timings(destination) if retry_timings: - failure_ts = retry_timings["failure_ts"] - retry_last_ts, retry_interval = ( - retry_timings["retry_last_ts"], - retry_timings["retry_interval"], - ) + failure_ts = retry_timings.failure_ts + retry_last_ts = retry_timings.retry_last_ts + retry_interval = retry_timings.retry_interval now = int(clock.time_msec()) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 0c89487eaf..f58afbc244 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -89,14 +89,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.event_source = hs.get_event_sources().sources["typing"] self.datastore = hs.get_datastore() - retry_timings_res = { - "destination": "", - "retry_last_ts": 0, - "retry_interval": 0, - "failure_ts": None, - } self.datastore.get_destination_retry_timings = Mock( - return_value=defer.succeed(retry_timings_res) + return_value=defer.succeed(None) ) self.datastore.get_device_updates_by_remote = Mock( diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py index b7f7eae8d0..bea9091d30 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.databases.main.transactions import DestinationRetryTimings from synapse.util.retryutils import MAX_RETRY_INTERVAL from tests.unittest import HomeserverTestCase @@ -36,8 +37,11 @@ class TransactionStoreTestCase(HomeserverTestCase): d = self.store.get_destination_retry_timings("example.com") r = self.get_success(d) - self.assert_dict( - {"retry_last_ts": 50, "retry_interval": 100, "failure_ts": 1000}, r + self.assertEqual( + DestinationRetryTimings( + retry_last_ts=50, retry_interval=100, failure_ts=1000 + ), + r, ) def test_initial_set_transactions(self): diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index 9b2be83a43..9e1bebdc83 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -51,10 +51,12 @@ class RetryLimiterTestCase(HomeserverTestCase): except AssertionError: pass + self.pump() + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) - self.assertEqual(new_timings["failure_ts"], failure_ts) - self.assertEqual(new_timings["retry_last_ts"], failure_ts) - self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL) + self.assertEqual(new_timings.failure_ts, failure_ts) + self.assertEqual(new_timings.retry_last_ts, failure_ts) + self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL) # now if we try again we should get a failure self.get_failure( @@ -77,14 +79,16 @@ class RetryLimiterTestCase(HomeserverTestCase): except AssertionError: pass + self.pump() + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) - self.assertEqual(new_timings["failure_ts"], failure_ts) - self.assertEqual(new_timings["retry_last_ts"], retry_ts) + self.assertEqual(new_timings.failure_ts, failure_ts) + self.assertEqual(new_timings.retry_last_ts, retry_ts) self.assertGreaterEqual( - new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5 + new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5 ) self.assertLessEqual( - new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0 + new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0 ) # -- cgit 1.5.1 From c0df6bae066fe818bb80d41af65503be7a07275d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 24 May 2021 14:02:01 +0100 Subject: Remove `keylen` from `LruCache`. (#9993) `keylen` seems to be a thing that is frequently incorrectly set, and we don't really need it. The only time it was used was to figure out if we had removed a subtree in `del_multi`, which we can do better by changing `TreeCache.pop` to return a different type (`TreeCacheNode`). Commits should be independently reviewable. --- changelog.d/9993.misc | 1 + synapse/replication/slave/storage/client_ips.py | 2 +- synapse/storage/databases/main/client_ips.py | 2 +- synapse/storage/databases/main/devices.py | 2 +- synapse/storage/databases/main/events_worker.py | 1 - synapse/util/caches/deferred_cache.py | 2 - synapse/util/caches/descriptors.py | 1 - synapse/util/caches/lrucache.py | 10 +-- synapse/util/caches/treecache.py | 104 +++++++++++++++--------- tests/util/test_lrucache.py | 4 +- tests/util/test_treecache.py | 6 +- 11 files changed, 80 insertions(+), 55 deletions(-) create mode 100644 changelog.d/9993.misc (limited to 'tests/util') diff --git a/changelog.d/9993.misc b/changelog.d/9993.misc new file mode 100644 index 0000000000..0dd9244071 --- /dev/null +++ b/changelog.d/9993.misc @@ -0,0 +1 @@ +Remove `keylen` param on `LruCache`. diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 8730966380..13ed87adc4 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -24,7 +24,7 @@ class SlavedClientIpStore(BaseSlavedStore): super().__init__(database, db_conn, hs) self.client_ip_last_seen = LruCache( - cache_name="client_ip_last_seen", keylen=4, max_size=50000 + cache_name="client_ip_last_seen", max_size=50000 ) # type: LruCache[tuple, int] async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index d60010e942..074b077bef 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -436,7 +436,7 @@ class ClientIpStore(ClientIpWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): self.client_ip_last_seen = LruCache( - cache_name="client_ip_last_seen", keylen=4, max_size=50000 + cache_name="client_ip_last_seen", max_size=50000 ) super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a1f98b7e38..fd87ba71ab 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1053,7 +1053,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. self.device_id_exists_cache = LruCache( - cache_name="device_id_exists", keylen=2, max_size=10000 + cache_name="device_id_exists", max_size=10000 ) async def store_device( diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 2c823e09cf..6963bbf7f4 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -157,7 +157,6 @@ class EventsWorkerStore(SQLBaseStore): self._get_event_cache = LruCache( cache_name="*getEvent*", - keylen=3, max_size=hs.config.caches.event_cache_size, ) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 484097a48a..371e7e4dd0 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -70,7 +70,6 @@ class DeferredCache(Generic[KT, VT]): self, name: str, max_entries: int = 1000, - keylen: int = 1, tree: bool = False, iterable: bool = False, apply_cache_factor_from_config: bool = True, @@ -101,7 +100,6 @@ class DeferredCache(Generic[KT, VT]): # a Deferred. self.cache = LruCache( max_size=max_entries, - keylen=keylen, cache_name=name, cache_type=cache_type, size_callback=(lambda d: len(d) or 1) if iterable else None, diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 3a4d027095..2ac24a2f25 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -270,7 +270,6 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): cache = DeferredCache( name=self.orig.__name__, max_entries=self.max_entries, - keylen=self.num_args, tree=self.tree, iterable=self.iterable, ) # type: DeferredCache[CacheKey, Any] diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 1be675e014..54df407ff7 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -34,7 +34,7 @@ from typing_extensions import Literal from synapse.config import cache as cache_config from synapse.util import caches from synapse.util.caches import CacheMetric, register_cache -from synapse.util.caches.treecache import TreeCache +from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry try: from pympler.asizeof import Asizer @@ -160,7 +160,6 @@ class LruCache(Generic[KT, VT]): self, max_size: int, cache_name: Optional[str] = None, - keylen: int = 1, cache_type: Type[Union[dict, TreeCache]] = dict, size_callback: Optional[Callable] = None, metrics_collection_callback: Optional[Callable[[], None]] = None, @@ -173,9 +172,6 @@ class LruCache(Generic[KT, VT]): cache_name: The name of this cache, for the prometheus metrics. If unset, no metrics will be reported on this cache. - keylen: The length of the tuple used as the cache key. Ignored unless - cache_type is `TreeCache`. - cache_type (type): type of underlying cache to be used. Typically one of dict or TreeCache. @@ -403,7 +399,9 @@ class LruCache(Generic[KT, VT]): popped = cache.pop(key) if popped is None: return - for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))): + # for each deleted node, we now need to remove it from the linked list + # and run its callbacks. + for leaf in iterate_tree_cache_entry(popped): delete_node(leaf) @synchronized diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index eb4d98f683..73502a8b06 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -1,18 +1,43 @@ -from typing import Dict +# Copyright 2016-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. SENTINEL = object() +class TreeCacheNode(dict): + """The type of nodes in our tree. + + Has its own type so we can distinguish it from real dicts that are stored at the + leaves. + """ + + pass + + class TreeCache: """ Tree-based backing store for LruCache. Allows subtrees of data to be deleted efficiently. Keys must be tuples. + + The data structure is a chain of TreeCacheNodes: + root = {key_1: {key_2: _value}} """ def __init__(self): self.size = 0 - self.root = {} # type: Dict + self.root = TreeCacheNode() def __setitem__(self, key, value): return self.set(key, value) @@ -21,10 +46,23 @@ class TreeCache: return self.get(key, SENTINEL) is not SENTINEL def set(self, key, value): + if isinstance(value, TreeCacheNode): + # this would mean we couldn't tell where our tree ended and the value + # started. + raise ValueError("Cannot store TreeCacheNodes in a TreeCache") + node = self.root for k in key[:-1]: - node = node.setdefault(k, {}) - node[key[-1]] = _Entry(value) + next_node = node.get(k, SENTINEL) + if next_node is SENTINEL: + next_node = node[k] = TreeCacheNode() + elif not isinstance(next_node, TreeCacheNode): + # this suggests that the caller is not being consistent with its key + # length. + raise ValueError("value conflicts with an existing subtree") + node = next_node + + node[key[-1]] = value self.size += 1 def get(self, key, default=None): @@ -33,25 +71,41 @@ class TreeCache: node = node.get(k, None) if node is None: return default - return node.get(key[-1], _Entry(default)).value + return node.get(key[-1], default) def clear(self): self.size = 0 - self.root = {} + self.root = TreeCacheNode() def pop(self, key, default=None): + """Remove the given key, or subkey, from the cache + + Args: + key: key or subkey to remove. + default: value to return if key is not found + + Returns: + If the key is not found, 'default'. If the key is complete, the removed + value. If the key is partial, the TreeCacheNode corresponding to the part + of the tree that was removed. + """ + # a list of the nodes we have touched on the way down the tree nodes = [] node = self.root for k in key[:-1]: node = node.get(k, None) - nodes.append(node) # don't add the root node if node is None: return default + if not isinstance(node, TreeCacheNode): + # we've gone off the end of the tree + raise ValueError("pop() key too long") + nodes.append(node) # don't add the root node popped = node.pop(key[-1], SENTINEL) if popped is SENTINEL: return default + # working back up the tree, clear out any nodes that are now empty node_and_keys = list(zip(nodes, key)) node_and_keys.reverse() node_and_keys.append((self.root, None)) @@ -61,14 +115,15 @@ class TreeCache: if n: break + # found an empty node: remove it from its parent, and loop. node_and_keys[i + 1][0].pop(k) - popped, cnt = _strip_and_count_entires(popped) + cnt = sum(1 for _ in iterate_tree_cache_entry(popped)) self.size -= cnt return popped def values(self): - return list(iterate_tree_cache_entry(self.root)) + return iterate_tree_cache_entry(self.root) def __len__(self): return self.size @@ -78,36 +133,9 @@ def iterate_tree_cache_entry(d): """Helper function to iterate over the leaves of a tree, i.e. a dict of that can contain dicts. """ - if isinstance(d, dict): + if isinstance(d, TreeCacheNode): for value_d in d.values(): for value in iterate_tree_cache_entry(value_d): yield value else: - if isinstance(d, _Entry): - yield d.value - else: - yield d - - -class _Entry: - __slots__ = ["value"] - - def __init__(self, value): - self.value = value - - -def _strip_and_count_entires(d): - """Takes an _Entry or dict with leaves of _Entry's, and either returns the - value or a dictionary with _Entry's replaced by their values. - - Also returns the count of _Entry's - """ - if isinstance(d, dict): - cnt = 0 - for key, value in d.items(): - v, n = _strip_and_count_entires(value) - d[key] = v - cnt += n - return d, cnt - else: - return d.value, 1 + yield d diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index df3e27779f..377904e72e 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -59,7 +59,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase): self.assertEquals(cache.pop("key"), None) def test_del_multi(self): - cache = LruCache(4, keylen=2, cache_type=TreeCache) + cache = LruCache(4, cache_type=TreeCache) cache[("animal", "cat")] = "mew" cache[("animal", "dog")] = "woof" cache[("vehicles", "car")] = "vroom" @@ -165,7 +165,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): m2 = Mock() m3 = Mock() m4 = Mock() - cache = LruCache(4, keylen=2, cache_type=TreeCache) + cache = LruCache(4, cache_type=TreeCache) cache.set(("a", "1"), "value", callbacks=[m1]) cache.set(("a", "2"), "value", callbacks=[m2]) diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py index 3b077af27e..6066372053 100644 --- a/tests/util/test_treecache.py +++ b/tests/util/test_treecache.py @@ -13,7 +13,7 @@ # limitations under the License. -from synapse.util.caches.treecache import TreeCache +from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from .. import unittest @@ -64,12 +64,14 @@ class TreeCacheTestCase(unittest.TestCase): cache[("a", "b")] = "AB" cache[("b", "a")] = "BA" self.assertEquals(cache.get(("a", "a")), "AA") - cache.pop(("a",)) + popped = cache.pop(("a",)) self.assertEquals(cache.get(("a", "a")), None) self.assertEquals(cache.get(("a", "b")), None) self.assertEquals(cache.get(("b", "a")), "BA") self.assertEquals(len(cache), 1) + self.assertEquals({"AA", "AB"}, set(iterate_tree_cache_entry(popped))) + def test_clear(self): cache = TreeCache() cache[("a",)] = "A" -- cgit 1.5.1 From 7adcb20fc02d614b4a2b03b128b279f25633e2bd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 24 May 2021 15:32:01 -0400 Subject: Add missing type hints to synapse.util (#9982) --- changelog.d/9982.misc | 1 + mypy.ini | 9 +++++++++ synapse/config/saml2.py | 8 +++++++- synapse/storage/databases/main/keys.py | 2 +- synapse/util/hash.py | 10 +++++----- synapse/util/iterutils.py | 11 ++++------- synapse/util/module_loader.py | 9 +++++---- synapse/util/msisdn.py | 10 +++++----- tests/util/test_itertools.py | 4 ++-- 9 files changed, 39 insertions(+), 25 deletions(-) create mode 100644 changelog.d/9982.misc (limited to 'tests/util') diff --git a/changelog.d/9982.misc b/changelog.d/9982.misc new file mode 100644 index 0000000000..f3821f61a3 --- /dev/null +++ b/changelog.d/9982.misc @@ -0,0 +1 @@ +Add missing type hints to `synapse.util` module. diff --git a/mypy.ini b/mypy.ini index 1d1d1ea0f2..062872020e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -71,8 +71,13 @@ files = synapse/types.py, synapse/util/async_helpers.py, synapse/util/caches, + synapse/util/daemonize.py, + synapse/util/hash.py, + synapse/util/iterutils.py, synapse/util/metrics.py, synapse/util/macaroons.py, + synapse/util/module_loader.py, + synapse/util/msisdn.py, synapse/util/stringutils.py, synapse/visibility.py, tests/replication, @@ -80,6 +85,7 @@ files = tests/handlers/test_password_providers.py, tests/rest/client/v1/test_login.py, tests/rest/client/v2_alpha/test_auth.py, + tests/util/test_itertools.py, tests/util/test_stream_change_cache.py [mypy-pymacaroons.*] @@ -175,5 +181,8 @@ ignore_missing_imports = True [mypy-pympler.*] ignore_missing_imports = True +[mypy-phonenumbers.*] +ignore_missing_imports = True + [mypy-ijson.*] ignore_missing_imports = True diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index 3d1218c8d1..05e983625d 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -164,7 +164,13 @@ class SAML2Config(Config): config_path = saml2_config.get("config_path", None) if config_path is not None: mod = load_python_module(config_path) - _dict_merge(merge_dict=mod.CONFIG, into_dict=saml2_config_dict) + config = getattr(mod, "CONFIG", None) + if config is None: + raise ConfigError( + "Config path specified by saml2_config.config_path does not " + "have a CONFIG property." + ) + _dict_merge(merge_dict=config, into_dict=saml2_config_dict) import saml2.config diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 0e86807834..6990f3ed1d 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -55,7 +55,7 @@ class KeyStore(SQLBaseStore): """ keys = {} - def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None: + def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: """Processes a batch of keys to fetch, and adds the result to `keys`.""" # batch_iter always returns tuples so it's safe to do len(batch) diff --git a/synapse/util/hash.py b/synapse/util/hash.py index ba676e1762..7625ca8c2c 100644 --- a/synapse/util/hash.py +++ b/synapse/util/hash.py @@ -17,15 +17,15 @@ import hashlib import unpaddedbase64 -def sha256_and_url_safe_base64(input_text): +def sha256_and_url_safe_base64(input_text: str) -> str: """SHA256 hash an input string, encode the digest as url-safe base64, and return - :param input_text: string to hash - :type input_text: str + Args: + input_text: string to hash - :returns a sha256 hashed and url-safe base64 encoded digest - :rtype: str + returns: + A sha256 hashed and url-safe base64 encoded digest """ digest = hashlib.sha256(input_text.encode()).digest() return unpaddedbase64.encode_base64(digest, urlsafe=True) diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index abfdc29832..886afa9d19 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -30,12 +30,12 @@ from typing import ( T = TypeVar("T") -def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]: +def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]: """batch an iterable up into tuples with a maximum size Args: - iterable (iterable): the iterable to slice - size (int): the maximum batch size + iterable: the iterable to slice + size: the maximum batch size Returns: an iterator over the chunks @@ -46,10 +46,7 @@ def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]: return iter(lambda: tuple(islice(sourceiter, size)), ()) -ISeq = TypeVar("ISeq", bound=Sequence, covariant=True) - - -def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]: +def chunk_seq(iseq: Sequence[T], maxlen: int) -> Iterable[Sequence[T]]: """Split the given sequence into chunks of the given size The last chunk may be shorter than the given size. diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py index 8acbe276e4..cbfbd097f9 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py @@ -15,6 +15,7 @@ import importlib import importlib.util import itertools +from types import ModuleType from typing import Any, Iterable, Tuple, Type import jsonschema @@ -44,8 +45,8 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]: # We need to import the module, and then pick the class out of # that, so we split based on the last dot. - module, clz = modulename.rsplit(".", 1) - module = importlib.import_module(module) + module_name, clz = modulename.rsplit(".", 1) + module = importlib.import_module(module_name) provider_class = getattr(module, clz) # Load the module config. If None, pass an empty dictionary instead @@ -69,11 +70,11 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]: return provider_class, provider_config -def load_python_module(location: str): +def load_python_module(location: str) -> ModuleType: """Load a python module, and return a reference to its global namespace Args: - location (str): path to the module + location: path to the module Returns: python module object diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py index bbbdebf264..1046224f15 100644 --- a/synapse/util/msisdn.py +++ b/synapse/util/msisdn.py @@ -17,19 +17,19 @@ import phonenumbers from synapse.api.errors import SynapseError -def phone_number_to_msisdn(country, number): +def phone_number_to_msisdn(country: str, number: str) -> str: """ Takes an ISO-3166-1 2 letter country code and phone number and returns an msisdn representing the canonical version of that phone number. Args: - country (str): ISO-3166-1 2 letter country code - number (str): Phone number in a national or international format + country: ISO-3166-1 2 letter country code + number: Phone number in a national or international format Returns: - (str) The canonical form of the phone number, as an msisdn + The canonical form of the phone number, as an msisdn Raises: - SynapseError if the number could not be parsed. + SynapseError if the number could not be parsed. """ try: phoneNumber = phonenumbers.parse(number, country) diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index 1bd0b45d94..e712eb42ea 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Dict, Iterable, List, Sequence from synapse.util.iterutils import chunk_seq, sorted_topologically @@ -44,7 +44,7 @@ class ChunkSeqTests(TestCase): ) def test_empty_input(self): - parts = chunk_seq([], 5) + parts = chunk_seq([], 5) # type: Iterable[Sequence] self.assertEqual( list(parts), -- cgit 1.5.1