summary refs log tree commit diff
path: root/tests/util
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util')
-rw-r--r--tests/util/test_async_helpers.py118
-rw-r--r--tests/util/test_batching_queue.py30
-rw-r--r--tests/util/test_check_dependencies.py29
-rw-r--r--tests/util/test_dict_cache.py20
-rw-r--r--tests/util/test_expiring_cache.py26
-rw-r--r--tests/util/test_file_consumer.py103
-rw-r--r--tests/util/test_itertools.py24
-rw-r--r--tests/util/test_logcontext.py86
-rw-r--r--tests/util/test_logformatter.py2
-rw-r--r--tests/util/test_lrucache.py80
-rw-r--r--tests/util/test_macaroons.py8
-rw-r--r--tests/util/test_ratelimitutils.py15
-rw-r--r--tests/util/test_retryutils.py4
-rw-r--r--tests/util/test_rwlock.py14
-rw-r--r--tests/util/test_stream_change_cache.py14
-rw-r--r--tests/util/test_stringutils.py4
-rw-r--r--tests/util/test_threepids.py16
-rw-r--r--tests/util/test_treecache.py14
-rw-r--r--tests/util/test_wheel_timer.py16
19 files changed, 357 insertions, 266 deletions
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index 9d5010bf92..91cac9822a 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import traceback
+from typing import Generator, List, NoReturn, Optional
 
 from parameterized import parameterized_class
 
@@ -41,8 +42,8 @@ from tests.unittest import TestCase
 
 
 class ObservableDeferredTest(TestCase):
-    def test_succeed(self):
-        origin_d = Deferred()
+    def test_succeed(self) -> None:
+        origin_d: "Deferred[int]" = Deferred()
         observable = ObservableDeferred(origin_d)
 
         observer1 = observable.observe()
@@ -52,16 +53,18 @@ class ObservableDeferredTest(TestCase):
         self.assertFalse(observer2.called)
 
         # check the first observer is called first
-        def check_called_first(res):
+        def check_called_first(res: int) -> int:
             self.assertFalse(observer2.called)
             return res
 
         observer1.addBoth(check_called_first)
 
         # store the results
-        results = [None, None]
+        results: List[Optional[ObservableDeferred[int]]] = [None, None]
 
-        def check_val(res, idx):
+        def check_val(
+            res: ObservableDeferred[int], idx: int
+        ) -> ObservableDeferred[int]:
             results[idx] = res
             return res
 
@@ -72,8 +75,8 @@ class ObservableDeferredTest(TestCase):
         self.assertEqual(results[0], 123, "observer 1 callback result")
         self.assertEqual(results[1], 123, "observer 2 callback result")
 
-    def test_failure(self):
-        origin_d = Deferred()
+    def test_failure(self) -> None:
+        origin_d: Deferred = Deferred()
         observable = ObservableDeferred(origin_d, consumeErrors=True)
 
         observer1 = observable.observe()
@@ -83,16 +86,16 @@ class ObservableDeferredTest(TestCase):
         self.assertFalse(observer2.called)
 
         # check the first observer is called first
-        def check_called_first(res):
+        def check_called_first(res: int) -> int:
             self.assertFalse(observer2.called)
             return res
 
         observer1.addBoth(check_called_first)
 
         # store the results
-        results = [None, None]
+        results: List[Optional[ObservableDeferred[str]]] = [None, None]
 
-        def check_val(res, idx):
+        def check_val(res: ObservableDeferred[str], idx: int) -> None:
             results[idx] = res
             return None
 
@@ -103,10 +106,12 @@ class ObservableDeferredTest(TestCase):
             raise Exception("gah!")
         except Exception as e:
             origin_d.errback(e)
+        assert results[0] is not None
         self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
+        assert results[1] is not None
         self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
 
-    def test_cancellation(self):
+    def test_cancellation(self) -> None:
         """Test that cancelling an observer does not affect other observers."""
         origin_d: "Deferred[int]" = Deferred()
         observable = ObservableDeferred(origin_d, consumeErrors=True)
@@ -136,37 +141,38 @@ class ObservableDeferredTest(TestCase):
 
 
 class TimeoutDeferredTest(TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.clock = Clock()
 
-    def test_times_out(self):
+    def test_times_out(self) -> None:
         """Basic test case that checks that the original deferred is cancelled and that
         the timing-out deferred is errbacked
         """
-        cancelled = [False]
+        cancelled = False
 
-        def canceller(_d):
-            cancelled[0] = True
+        def canceller(_d: Deferred) -> None:
+            nonlocal cancelled
+            cancelled = True
 
-        non_completing_d = Deferred(canceller)
+        non_completing_d: Deferred = Deferred(canceller)
         timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
 
         self.assertNoResult(timing_out_d)
-        self.assertFalse(cancelled[0], "deferred was cancelled prematurely")
+        self.assertFalse(cancelled, "deferred was cancelled prematurely")
 
         self.clock.pump((1.0,))
 
-        self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
+        self.assertTrue(cancelled, "deferred was not cancelled by timeout")
         self.failureResultOf(timing_out_d, defer.TimeoutError)
 
-    def test_times_out_when_canceller_throws(self):
+    def test_times_out_when_canceller_throws(self) -> None:
         """Test that we have successfully worked around
         https://twistedmatrix.com/trac/ticket/9534"""
 
-        def canceller(_d):
+        def canceller(_d: Deferred) -> None:
             raise Exception("can't cancel this deferred")
 
-        non_completing_d = Deferred(canceller)
+        non_completing_d: Deferred = Deferred(canceller)
         timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
 
         self.assertNoResult(timing_out_d)
@@ -175,22 +181,24 @@ class TimeoutDeferredTest(TestCase):
 
         self.failureResultOf(timing_out_d, defer.TimeoutError)
 
-    def test_logcontext_is_preserved_on_cancellation(self):
-        blocking_was_cancelled = [False]
+    def test_logcontext_is_preserved_on_cancellation(self) -> None:
+        blocking_was_cancelled = False
 
         @defer.inlineCallbacks
-        def blocking():
-            non_completing_d = Deferred()
+        def blocking() -> Generator["Deferred[object]", object, None]:
+            nonlocal blocking_was_cancelled
+
+            non_completing_d: Deferred = Deferred()
             with PreserveLoggingContext():
                 try:
                     yield non_completing_d
                 except CancelledError:
-                    blocking_was_cancelled[0] = True
+                    blocking_was_cancelled = True
                     raise
 
         with LoggingContext("one") as context_one:
             # the errbacks should be run in the test logcontext
-            def errback(res, deferred_name):
+            def errback(res: Failure, deferred_name: str) -> Failure:
                 self.assertIs(
                     current_context(),
                     context_one,
@@ -209,7 +217,7 @@ class TimeoutDeferredTest(TestCase):
             self.clock.pump((1.0,))
 
             self.assertTrue(
-                blocking_was_cancelled[0], "non-completing deferred was not cancelled"
+                blocking_was_cancelled, "non-completing deferred was not cancelled"
             )
             self.failureResultOf(timing_out_d, defer.TimeoutError)
             self.assertIs(current_context(), context_one)
@@ -220,13 +228,13 @@ class _TestException(Exception):
 
 
 class ConcurrentlyExecuteTest(TestCase):
-    def test_limits_runners(self):
+    def test_limits_runners(self) -> None:
         """If we have more tasks than runners, we should get the limit of runners"""
         started = 0
         waiters = []
         processed = []
 
-        async def callback(v):
+        async def callback(v: int) -> None:
             # when we first enter, bump the start count
             nonlocal started
             started += 1
@@ -235,7 +243,7 @@ class ConcurrentlyExecuteTest(TestCase):
             processed.append(v)
 
             # wait for the goahead before returning
-            d2 = Deferred()
+            d2: "Deferred[int]" = Deferred()
             waiters.append(d2)
             await d2
 
@@ -265,16 +273,16 @@ class ConcurrentlyExecuteTest(TestCase):
         self.assertCountEqual(processed, [1, 2, 3, 4, 5])
         self.successResultOf(d2)
 
-    def test_preserves_stacktraces(self):
+    def test_preserves_stacktraces(self) -> None:
         """Test that the stacktrace from an exception thrown in the callback is preserved"""
-        d1 = Deferred()
+        d1: "Deferred[int]" = Deferred()
 
-        async def callback(v):
+        async def callback(v: int) -> None:
             # alas, this doesn't work at all without an await here
             await d1
             raise _TestException("bah")
 
-        async def caller():
+        async def caller() -> None:
             try:
                 await concurrently_execute(callback, [1], 2)
             except _TestException as e:
@@ -290,17 +298,17 @@ class ConcurrentlyExecuteTest(TestCase):
         d1.callback(0)
         self.successResultOf(d2)
 
-    def test_preserves_stacktraces_on_preformed_failure(self):
+    def test_preserves_stacktraces_on_preformed_failure(self) -> None:
         """Test that the stacktrace on a Failure returned by the callback is preserved"""
-        d1 = Deferred()
+        d1: "Deferred[int]" = Deferred()
         f = Failure(_TestException("bah"))
 
-        async def callback(v):
+        async def callback(v: int) -> None:
             # alas, this doesn't work at all without an await here
             await d1
             await defer.fail(f)
 
-        async def caller():
+        async def caller() -> None:
             try:
                 await concurrently_execute(callback, [1], 2)
             except _TestException as e:
@@ -336,7 +344,7 @@ class CancellationWrapperTests(TestCase):
         else:
             raise ValueError(f"Unsupported wrapper type: {self.wrapper}")
 
-    def test_succeed(self):
+    def test_succeed(self) -> None:
         """Test that the new `Deferred` receives the result."""
         deferred: "Deferred[str]" = Deferred()
         wrapper_deferred = self.wrap_deferred(deferred)
@@ -346,7 +354,7 @@ class CancellationWrapperTests(TestCase):
         self.assertTrue(wrapper_deferred.called)
         self.assertEqual("success", self.successResultOf(wrapper_deferred))
 
-    def test_failure(self):
+    def test_failure(self) -> None:
         """Test that the new `Deferred` receives the `Failure`."""
         deferred: "Deferred[str]" = Deferred()
         wrapper_deferred = self.wrap_deferred(deferred)
@@ -361,7 +369,7 @@ class CancellationWrapperTests(TestCase):
 class StopCancellationTests(TestCase):
     """Tests for the `stop_cancellation` function."""
 
-    def test_cancellation(self):
+    def test_cancellation(self) -> None:
         """Test that cancellation of the new `Deferred` leaves the original running."""
         deferred: "Deferred[str]" = Deferred()
         wrapper_deferred = stop_cancellation(deferred)
@@ -384,7 +392,7 @@ class StopCancellationTests(TestCase):
 class DelayCancellationTests(TestCase):
     """Tests for the `delay_cancellation` function."""
 
-    def test_deferred_cancellation(self):
+    def test_deferred_cancellation(self) -> None:
         """Test that cancellation of the new `Deferred` waits for the original."""
         deferred: "Deferred[str]" = Deferred()
         wrapper_deferred = delay_cancellation(deferred)
@@ -405,12 +413,12 @@ class DelayCancellationTests(TestCase):
         # Now that the original `Deferred` has failed, we should get a `CancelledError`.
         self.failureResultOf(wrapper_deferred, CancelledError)
 
-    def test_coroutine_cancellation(self):
+    def test_coroutine_cancellation(self) -> None:
         """Test that cancellation of the new `Deferred` waits for the original."""
         blocking_deferred: "Deferred[None]" = Deferred()
         completion_deferred: "Deferred[None]" = Deferred()
 
-        async def task():
+        async def task() -> NoReturn:
             await blocking_deferred
             completion_deferred.callback(None)
             # Raise an exception. Twisted should consume it, otherwise unwanted
@@ -434,7 +442,7 @@ class DelayCancellationTests(TestCase):
         # Now that the original coroutine has failed, we should get a `CancelledError`.
         self.failureResultOf(wrapper_deferred, CancelledError)
 
-    def test_suppresses_second_cancellation(self):
+    def test_suppresses_second_cancellation(self) -> None:
         """Test that a second cancellation is suppressed.
 
         Identical to `test_cancellation` except the new `Deferred` is cancelled twice.
@@ -459,7 +467,7 @@ class DelayCancellationTests(TestCase):
         # Now that the original `Deferred` has failed, we should get a `CancelledError`.
         self.failureResultOf(wrapper_deferred, CancelledError)
 
-    def test_propagates_cancelled_error(self):
+    def test_propagates_cancelled_error(self) -> None:
         """Test that a `CancelledError` from the original `Deferred` gets propagated."""
         deferred: "Deferred[str]" = Deferred()
         wrapper_deferred = delay_cancellation(deferred)
@@ -472,14 +480,14 @@ class DelayCancellationTests(TestCase):
         self.assertTrue(wrapper_deferred.called)
         self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value)
 
-    def test_preserves_logcontext(self):
+    def test_preserves_logcontext(self) -> None:
         """Test that logging contexts are preserved."""
         blocking_d: "Deferred[None]" = Deferred()
 
-        async def inner():
+        async def inner() -> None:
             await make_deferred_yieldable(blocking_d)
 
-        async def outer():
+        async def outer() -> None:
             with LoggingContext("c") as c:
                 try:
                     await delay_cancellation(inner())
@@ -503,7 +511,7 @@ class DelayCancellationTests(TestCase):
 class AwakenableSleeperTests(TestCase):
     "Tests AwakenableSleeper"
 
-    def test_sleep(self):
+    def test_sleep(self) -> None:
         reactor, _ = get_clock()
         sleeper = AwakenableSleeper(reactor)
 
@@ -518,7 +526,7 @@ class AwakenableSleeperTests(TestCase):
         reactor.advance(0.6)
         self.assertTrue(d.called)
 
-    def test_explicit_wake(self):
+    def test_explicit_wake(self) -> None:
         reactor, _ = get_clock()
         sleeper = AwakenableSleeper(reactor)
 
@@ -535,7 +543,7 @@ class AwakenableSleeperTests(TestCase):
 
         reactor.advance(0.6)
 
-    def test_multiple_sleepers_timeout(self):
+    def test_multiple_sleepers_timeout(self) -> None:
         reactor, _ = get_clock()
         sleeper = AwakenableSleeper(reactor)
 
@@ -555,7 +563,7 @@ class AwakenableSleeperTests(TestCase):
         reactor.advance(0.6)
         self.assertTrue(d2.called)
 
-    def test_multiple_sleepers_wake(self):
+    def test_multiple_sleepers_wake(self) -> None:
         reactor, _ = get_clock()
         sleeper = AwakenableSleeper(reactor)
 
diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py
index 07be57d72c..94ef91f645 100644
--- a/tests/util/test_batching_queue.py
+++ b/tests/util/test_batching_queue.py
@@ -11,6 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List, Tuple
+
+from prometheus_client import Gauge
+
 from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable
@@ -26,7 +30,7 @@ from tests.unittest import TestCase
 
 
 class BatchingQueueTestCase(TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.clock, hs_clock = get_clock()
 
         # We ensure that we remove any existing metrics for "test_queue".
@@ -37,25 +41,27 @@ class BatchingQueueTestCase(TestCase):
         except KeyError:
             pass
 
-        self._pending_calls = []
-        self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
+        self._pending_calls: List[Tuple[List[str], defer.Deferred]] = []
+        self.queue: BatchingQueue[str, str] = BatchingQueue(
+            "test_queue", hs_clock, self._process_queue
+        )
 
-    async def _process_queue(self, values):
-        d = defer.Deferred()
+    async def _process_queue(self, values: List[str]) -> str:
+        d: "defer.Deferred[str]" = defer.Deferred()
         self._pending_calls.append((values, d))
         return await make_deferred_yieldable(d)
 
-    def _get_sample_with_name(self, metric, name) -> int:
+    def _get_sample_with_name(self, metric: Gauge, name: str) -> float:
         """For a prometheus metric get the value of the sample that has a
         matching "name" label.
         """
-        for sample in metric.collect()[0].samples:
+        for sample in next(iter(metric.collect())).samples:
             if sample.labels.get("name") == name:
                 return sample.value
 
         self.fail("Found no matching sample")
 
-    def _assert_metrics(self, queued, keys, in_flight):
+    def _assert_metrics(self, queued: int, keys: int, in_flight: int) -> None:
         """Assert that the metrics are correct"""
 
         sample = self._get_sample_with_name(number_queued, self.queue._name)
@@ -75,7 +81,7 @@ class BatchingQueueTestCase(TestCase):
             "number_in_flight",
         )
 
-    def test_simple(self):
+    def test_simple(self) -> None:
         """Tests the basic case of calling `add_to_queue` once and having
         `_process_queue` return.
         """
@@ -106,7 +112,7 @@ class BatchingQueueTestCase(TestCase):
 
         self._assert_metrics(queued=0, keys=0, in_flight=0)
 
-    def test_batching(self):
+    def test_batching(self) -> None:
         """Test that multiple calls at the same time get batched up into one
         call to `_process_queue`.
         """
@@ -134,7 +140,7 @@ class BatchingQueueTestCase(TestCase):
         self.assertEqual(self.successResultOf(queue_d2), "bar")
         self._assert_metrics(queued=0, keys=0, in_flight=0)
 
-    def test_queuing(self):
+    def test_queuing(self) -> None:
         """Test that we queue up requests while a `_process_queue` is being
         called.
         """
@@ -184,7 +190,7 @@ class BatchingQueueTestCase(TestCase):
         self.assertEqual(self.successResultOf(queue_d3), "bar2")
         self._assert_metrics(queued=0, keys=0, in_flight=0)
 
-    def test_different_keys(self):
+    def test_different_keys(self) -> None:
         """Test that calls to different keys get processed in parallel."""
 
         self.assertFalse(self._pending_calls)
diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
index 6913de24b9..aa20fe6780 100644
--- a/tests/util/test_check_dependencies.py
+++ b/tests/util/test_check_dependencies.py
@@ -1,5 +1,20 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from contextlib import contextmanager
-from typing import Generator, Optional
+from os import PathLike
+from typing import Generator, Optional, Union
 from unittest.mock import patch
 
 from synapse.util.check_dependencies import (
@@ -12,17 +27,17 @@ from tests.unittest import TestCase
 
 
 class DummyDistribution(metadata.Distribution):
-    def __init__(self, version: object):
+    def __init__(self, version: str):
         self._version = version
 
     @property
-    def version(self):
+    def version(self) -> str:
         return self._version
 
-    def locate_file(self, path):
+    def locate_file(self, path: Union[str, PathLike]) -> PathLike:
         raise NotImplementedError()
 
-    def read_text(self, filename):
+    def read_text(self, filename: str) -> None:
         raise NotImplementedError()
 
 
@@ -30,7 +45,7 @@ old = DummyDistribution("0.1.2")
 old_release_candidate = DummyDistribution("0.1.2rc3")
 new = DummyDistribution("1.2.3")
 new_release_candidate = DummyDistribution("1.2.3rc4")
-distribution_with_no_version = DummyDistribution(None)
+distribution_with_no_version = DummyDistribution(None)  # type: ignore[arg-type]
 
 # could probably use stdlib TestCase --- no need for twisted here
 
@@ -45,7 +60,7 @@ class TestDependencyChecker(TestCase):
         If `distribution = None`, we pretend that the package is not installed.
         """
 
-        def mock_distribution(name: str):
+        def mock_distribution(name: str) -> DummyDistribution:
             if distribution is None:
                 raise metadata.PackageNotFoundError
             else:
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index e8b6246ab5..acb251bfea 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -19,10 +19,12 @@ from tests import unittest
 
 
 class DictCacheTestCase(unittest.TestCase):
-    def setUp(self):
-        self.cache = DictionaryCache("foobar", max_entries=10)
+    def setUp(self) -> None:
+        self.cache: DictionaryCache[str, str, str] = DictionaryCache(
+            "foobar", max_entries=10
+        )
 
-    def test_simple_cache_hit_full(self):
+    def test_simple_cache_hit_full(self) -> None:
         key = "test_simple_cache_hit_full"
 
         v = self.cache.get(key)
@@ -37,7 +39,7 @@ class DictCacheTestCase(unittest.TestCase):
         c = self.cache.get(key)
         self.assertEqual(test_value, c.value)
 
-    def test_simple_cache_hit_partial(self):
+    def test_simple_cache_hit_partial(self) -> None:
         key = "test_simple_cache_hit_partial"
 
         seq = self.cache.sequence
@@ -47,7 +49,7 @@ class DictCacheTestCase(unittest.TestCase):
         c = self.cache.get(key, ["test"])
         self.assertEqual(test_value, c.value)
 
-    def test_simple_cache_miss_partial(self):
+    def test_simple_cache_miss_partial(self) -> None:
         key = "test_simple_cache_miss_partial"
 
         seq = self.cache.sequence
@@ -57,7 +59,7 @@ class DictCacheTestCase(unittest.TestCase):
         c = self.cache.get(key, ["test2"])
         self.assertEqual({}, c.value)
 
-    def test_simple_cache_hit_miss_partial(self):
+    def test_simple_cache_hit_miss_partial(self) -> None:
         key = "test_simple_cache_hit_miss_partial"
 
         seq = self.cache.sequence
@@ -71,7 +73,7 @@ class DictCacheTestCase(unittest.TestCase):
         c = self.cache.get(key, ["test2"])
         self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
 
-    def test_multi_insert(self):
+    def test_multi_insert(self) -> None:
         key = "test_simple_cache_hit_miss_partial"
 
         seq = self.cache.sequence
@@ -92,7 +94,7 @@ class DictCacheTestCase(unittest.TestCase):
         )
         self.assertEqual(c.full, False)
 
-    def test_invalidation(self):
+    def test_invalidation(self) -> None:
         """Test that the partial dict and full dicts get invalidated
         separately.
         """
@@ -106,7 +108,7 @@ class DictCacheTestCase(unittest.TestCase):
         # entry for "a" warm.
         for i in range(20):
             self.cache.get(key, ["a"])
-            self.cache.update(seq, f"key{i}", {1: 2})
+            self.cache.update(seq, f"key{i}", {"1": "2"})
 
         # We should have evicted the full dict...
         r = self.cache.get(key)
diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py
index 7f60aae5ba..9cf920daf8 100644
--- a/tests/util/test_expiring_cache.py
+++ b/tests/util/test_expiring_cache.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List, cast
 
+from synapse.util import Clock
 from synapse.util.caches.expiringcache import ExpiringCache
 
 from tests.utils import MockClock
@@ -21,17 +23,21 @@ from .. import unittest
 
 
 class ExpiringCacheTestCase(unittest.HomeserverTestCase):
-    def test_get_set(self):
+    def test_get_set(self) -> None:
         clock = MockClock()
-        cache = ExpiringCache("test", clock, max_len=1)
+        cache: ExpiringCache[str, str] = ExpiringCache(
+            "test", cast(Clock, clock), max_len=1
+        )
 
         cache["key"] = "value"
         self.assertEqual(cache.get("key"), "value")
         self.assertEqual(cache["key"], "value")
 
-    def test_eviction(self):
+    def test_eviction(self) -> None:
         clock = MockClock()
-        cache = ExpiringCache("test", clock, max_len=2)
+        cache: ExpiringCache[str, str] = ExpiringCache(
+            "test", cast(Clock, clock), max_len=2
+        )
 
         cache["key"] = "value"
         cache["key2"] = "value2"
@@ -43,9 +49,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
         self.assertEqual(cache.get("key2"), "value2")
         self.assertEqual(cache.get("key3"), "value3")
 
-    def test_iterable_eviction(self):
+    def test_iterable_eviction(self) -> None:
         clock = MockClock()
-        cache = ExpiringCache("test", clock, max_len=5, iterable=True)
+        cache: ExpiringCache[str, List[int]] = ExpiringCache(
+            "test", cast(Clock, clock), max_len=5, iterable=True
+        )
 
         cache["key"] = [1]
         cache["key2"] = [2, 3]
@@ -61,9 +69,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
         self.assertEqual(cache.get("key3"), [4, 5])
         self.assertEqual(cache.get("key4"), [6, 7])
 
-    def test_time_eviction(self):
+    def test_time_eviction(self) -> None:
         clock = MockClock()
-        cache = ExpiringCache("test", clock, expiry_ms=1000)
+        cache: ExpiringCache[str, int] = ExpiringCache(
+            "test", cast(Clock, clock), expiry_ms=1000
+        )
 
         cache["key"] = 1
         clock.advance_time(0.5)
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index 3bb4695405..4f3c983c15 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -12,22 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import threading
-from io import StringIO
+from io import BytesIO
+from typing import BinaryIO, Generator, Optional, cast
 from unittest.mock import NonCallableMock
 
-from twisted.internet import defer, reactor
+from zope.interface import implementer
+
+from twisted.internet import defer, reactor as _reactor
+from twisted.internet.interfaces import IPullProducer
 
+from synapse.types import ISynapseReactor
 from synapse.util.file_consumer import BackgroundFileConsumer
 
 from tests import unittest
 
+reactor = cast(ISynapseReactor, _reactor)
+
 
 class FileConsumerTests(unittest.TestCase):
     @defer.inlineCallbacks
-    def test_pull_consumer(self):
-        string_file = StringIO()
+    def test_pull_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
+        string_file = BytesIO()
         consumer = BackgroundFileConsumer(string_file, reactor=reactor)
 
         try:
@@ -35,55 +41,57 @@ class FileConsumerTests(unittest.TestCase):
 
             yield producer.register_with_consumer(consumer)
 
-            yield producer.write_and_wait("Foo")
+            yield producer.write_and_wait(b"Foo")
 
-            self.assertEqual(string_file.getvalue(), "Foo")
+            self.assertEqual(string_file.getvalue(), b"Foo")
 
-            yield producer.write_and_wait("Bar")
+            yield producer.write_and_wait(b"Bar")
 
-            self.assertEqual(string_file.getvalue(), "FooBar")
+            self.assertEqual(string_file.getvalue(), b"FooBar")
         finally:
             consumer.unregisterProducer()
 
-        yield consumer.wait()
+        yield consumer.wait()  # type: ignore[misc]
 
         self.assertTrue(string_file.closed)
 
     @defer.inlineCallbacks
-    def test_push_consumer(self):
-        string_file = BlockingStringWrite()
-        consumer = BackgroundFileConsumer(string_file, reactor=reactor)
+    def test_push_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
+        string_file = BlockingBytesWrite()
+        consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
 
         try:
             producer = NonCallableMock(spec_set=[])
 
             consumer.registerProducer(producer, True)
 
-            consumer.write("Foo")
-            yield string_file.wait_for_n_writes(1)
+            consumer.write(b"Foo")
+            yield string_file.wait_for_n_writes(1)  # type: ignore[misc]
 
-            self.assertEqual(string_file.buffer, "Foo")
+            self.assertEqual(string_file.buffer, b"Foo")
 
-            consumer.write("Bar")
-            yield string_file.wait_for_n_writes(2)
+            consumer.write(b"Bar")
+            yield string_file.wait_for_n_writes(2)  # type: ignore[misc]
 
-            self.assertEqual(string_file.buffer, "FooBar")
+            self.assertEqual(string_file.buffer, b"FooBar")
         finally:
             consumer.unregisterProducer()
 
-        yield consumer.wait()
+        yield consumer.wait()  # type: ignore[misc]
 
         self.assertTrue(string_file.closed)
 
     @defer.inlineCallbacks
-    def test_push_producer_feedback(self):
-        string_file = BlockingStringWrite()
-        consumer = BackgroundFileConsumer(string_file, reactor=reactor)
+    def test_push_producer_feedback(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        string_file = BlockingBytesWrite()
+        consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
 
         try:
             producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
 
-            resume_deferred = defer.Deferred()
+            resume_deferred: defer.Deferred = defer.Deferred()
             producer.resumeProducing.side_effect = lambda: resume_deferred.callback(
                 None
             )
@@ -93,65 +101,72 @@ class FileConsumerTests(unittest.TestCase):
             number_writes = 0
             with string_file.write_lock:
                 for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
-                    consumer.write("Foo")
+                    consumer.write(b"Foo")
                     number_writes += 1
 
                 producer.pauseProducing.assert_called_once()
 
-            yield string_file.wait_for_n_writes(number_writes)
+            yield string_file.wait_for_n_writes(number_writes)  # type: ignore[misc]
 
             yield resume_deferred
             producer.resumeProducing.assert_called_once()
         finally:
             consumer.unregisterProducer()
 
-        yield consumer.wait()
+        yield consumer.wait()  # type: ignore[misc]
 
         self.assertTrue(string_file.closed)
 
 
+@implementer(IPullProducer)
 class DummyPullProducer:
-    def __init__(self):
-        self.consumer = None
-        self.deferred = defer.Deferred()
+    def __init__(self) -> None:
+        self.consumer: Optional[BackgroundFileConsumer] = None
+        self.deferred: "defer.Deferred[object]" = defer.Deferred()
 
-    def resumeProducing(self):
+    def resumeProducing(self) -> None:
         d = self.deferred
         self.deferred = defer.Deferred()
         d.callback(None)
 
-    def write_and_wait(self, bytes):
+    def stopProducing(self) -> None:
+        raise RuntimeError("Unexpected call")
+
+    def write_and_wait(self, write_bytes: bytes) -> "defer.Deferred[object]":
+        assert self.consumer is not None
         d = self.deferred
-        self.consumer.write(bytes)
+        self.consumer.write(write_bytes)
         return d
 
-    def register_with_consumer(self, consumer):
+    def register_with_consumer(
+        self, consumer: BackgroundFileConsumer
+    ) -> "defer.Deferred[object]":
         d = self.deferred
         self.consumer = consumer
         self.consumer.registerProducer(self, False)
         return d
 
 
-class BlockingStringWrite:
-    def __init__(self):
-        self.buffer = ""
+class BlockingBytesWrite:
+    def __init__(self) -> None:
+        self.buffer = b""
         self.closed = False
         self.write_lock = threading.Lock()
 
-        self._notify_write_deferred = None
+        self._notify_write_deferred: Optional[defer.Deferred] = None
         self._number_of_writes = 0
 
-    def write(self, bytes):
+    def write(self, write_bytes: bytes) -> None:
         with self.write_lock:
-            self.buffer += bytes
+            self.buffer += write_bytes
             self._number_of_writes += 1
 
         reactor.callFromThread(self._notify_write)
 
-    def close(self):
+    def close(self) -> None:
         self.closed = True
 
-    def _notify_write(self):
+    def _notify_write(self) -> None:
         "Called by write to indicate a write happened"
         with self.write_lock:
             if not self._notify_write_deferred:
@@ -161,7 +176,9 @@ class BlockingStringWrite:
         d.callback(None)
 
     @defer.inlineCallbacks
-    def wait_for_n_writes(self, n):
+    def wait_for_n_writes(
+        self, n: int
+    ) -> Generator["defer.Deferred[object]", object, None]:
         "Wait for n writes to have happened"
         while True:
             with self.write_lock:
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 3c0ddd4f18..406c16cdcf 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -19,7 +19,7 @@ from tests.unittest import TestCase
 
 
 class ChunkSeqTests(TestCase):
-    def test_short_seq(self):
+    def test_short_seq(self) -> None:
         parts = chunk_seq("123", 8)
 
         self.assertEqual(
@@ -27,7 +27,7 @@ class ChunkSeqTests(TestCase):
             ["123"],
         )
 
-    def test_long_seq(self):
+    def test_long_seq(self) -> None:
         parts = chunk_seq("abcdefghijklmnop", 8)
 
         self.assertEqual(
@@ -35,7 +35,7 @@ class ChunkSeqTests(TestCase):
             ["abcdefgh", "ijklmnop"],
         )
 
-    def test_uneven_parts(self):
+    def test_uneven_parts(self) -> None:
         parts = chunk_seq("abcdefghijklmnop", 5)
 
         self.assertEqual(
@@ -43,7 +43,7 @@ class ChunkSeqTests(TestCase):
             ["abcde", "fghij", "klmno", "p"],
         )
 
-    def test_empty_input(self):
+    def test_empty_input(self) -> None:
         parts: Iterable[Sequence] = chunk_seq([], 5)
 
         self.assertEqual(
@@ -53,13 +53,13 @@ class ChunkSeqTests(TestCase):
 
 
 class SortTopologically(TestCase):
-    def test_empty(self):
+    def test_empty(self) -> None:
         "Test that an empty graph works correctly"
 
         graph: Dict[int, List[int]] = {}
         self.assertEqual(list(sorted_topologically([], graph)), [])
 
-    def test_handle_empty_graph(self):
+    def test_handle_empty_graph(self) -> None:
         "Test that a graph where a node doesn't have an entry is treated as empty"
 
         graph: Dict[int, List[int]] = {}
@@ -67,7 +67,7 @@ class SortTopologically(TestCase):
         # For disconnected nodes the output is simply sorted.
         self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
 
-    def test_disconnected(self):
+    def test_disconnected(self) -> None:
         "Test that a graph with no edges work"
 
         graph: Dict[int, List[int]] = {1: [], 2: []}
@@ -75,20 +75,20 @@ class SortTopologically(TestCase):
         # For disconnected nodes the output is simply sorted.
         self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
 
-    def test_linear(self):
+    def test_linear(self) -> None:
         "Test that a simple `4 -> 3 -> 2 -> 1` graph works"
 
         graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
 
         self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
 
-    def test_subset(self):
+    def test_subset(self) -> None:
         "Test that only sorting a subset of the graph works"
         graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
 
         self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
 
-    def test_fork(self):
+    def test_fork(self) -> None:
         "Test that a forked graph works"
         graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]}
 
@@ -96,13 +96,13 @@ class SortTopologically(TestCase):
         # always get the same one.
         self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
 
-    def test_duplicates(self):
+    def test_duplicates(self) -> None:
         "Test that a graph with duplicate edges work"
         graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]}
 
         self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
 
-    def test_multiple_paths(self):
+    def test_multiple_paths(self) -> None:
         "Test that a graph with multiple paths between two nodes work"
         graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}
 
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 2ad321e184..d64c162e1d 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -1,5 +1,21 @@
+# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Generator, cast
+
 import twisted.python.failure
-from twisted.internet import defer, reactor
+from twisted.internet import defer, reactor as _reactor
 
 from synapse.logging.context import (
     SENTINEL_CONTEXT,
@@ -10,25 +26,30 @@ from synapse.logging.context import (
     nested_logging_context,
     run_in_background,
 )
+from synapse.types import ISynapseReactor
 from synapse.util import Clock
 
 from .. import unittest
 
+reactor = cast(ISynapseReactor, _reactor)
+
 
 class LoggingContextTestCase(unittest.TestCase):
-    def _check_test_key(self, value):
-        self.assertEqual(current_context().name, value)
+    def _check_test_key(self, value: str) -> None:
+        context = current_context()
+        assert isinstance(context, LoggingContext)
+        self.assertEqual(context.name, value)
 
-    def test_with_context(self):
+    def test_with_context(self) -> None:
         with LoggingContext("test"):
             self._check_test_key("test")
 
     @defer.inlineCallbacks
-    def test_sleep(self):
+    def test_sleep(self) -> Generator["defer.Deferred[object]", object, None]:
         clock = Clock(reactor)
 
         @defer.inlineCallbacks
-        def competing_callback():
+        def competing_callback() -> Generator["defer.Deferred[object]", object, None]:
             with LoggingContext("competing"):
                 yield clock.sleep(0)
                 self._check_test_key("competing")
@@ -39,17 +60,18 @@ class LoggingContextTestCase(unittest.TestCase):
             yield clock.sleep(0)
             self._check_test_key("one")
 
-    def _test_run_in_background(self, function):
+    def _test_run_in_background(self, function: Callable[[], object]) -> defer.Deferred:
         sentinel_context = current_context()
 
-        callback_completed = [False]
+        callback_completed = False
 
         with LoggingContext("one"):
             # fire off function, but don't wait on it.
             d2 = run_in_background(function)
 
-            def cb(res):
-                callback_completed[0] = True
+            def cb(res: object) -> object:
+                nonlocal callback_completed
+                callback_completed = True
                 return res
 
             d2.addCallback(cb)
@@ -60,8 +82,8 @@ class LoggingContextTestCase(unittest.TestCase):
         # the logcontext is left in a sane state.
         d2 = defer.Deferred()
 
-        def check_logcontext():
-            if not callback_completed[0]:
+        def check_logcontext() -> None:
+            if not callback_completed:
                 reactor.callLater(0.01, check_logcontext)
                 return
 
@@ -78,31 +100,31 @@ class LoggingContextTestCase(unittest.TestCase):
         # test is done once d2 finishes
         return d2
 
-    def test_run_in_background_with_blocking_fn(self):
+    def test_run_in_background_with_blocking_fn(self) -> defer.Deferred:
         @defer.inlineCallbacks
-        def blocking_function():
+        def blocking_function() -> Generator["defer.Deferred[object]", object, None]:
             yield Clock(reactor).sleep(0)
 
         return self._test_run_in_background(blocking_function)
 
-    def test_run_in_background_with_non_blocking_fn(self):
+    def test_run_in_background_with_non_blocking_fn(self) -> defer.Deferred:
         @defer.inlineCallbacks
-        def nonblocking_function():
+        def nonblocking_function() -> Generator["defer.Deferred[object]", object, None]:
             with PreserveLoggingContext():
                 yield defer.succeed(None)
 
         return self._test_run_in_background(nonblocking_function)
 
-    def test_run_in_background_with_chained_deferred(self):
+    def test_run_in_background_with_chained_deferred(self) -> defer.Deferred:
         # a function which returns a deferred which looks like it has been
         # called, but is actually paused
-        def testfunc():
+        def testfunc() -> defer.Deferred:
             return make_deferred_yieldable(_chained_deferred_function())
 
         return self._test_run_in_background(testfunc)
 
-    def test_run_in_background_with_coroutine(self):
-        async def testfunc():
+    def test_run_in_background_with_coroutine(self) -> defer.Deferred:
+        async def testfunc() -> None:
             self._check_test_key("one")
             d = Clock(reactor).sleep(0)
             self.assertIs(current_context(), SENTINEL_CONTEXT)
@@ -111,18 +133,20 @@ class LoggingContextTestCase(unittest.TestCase):
 
         return self._test_run_in_background(testfunc)
 
-    def test_run_in_background_with_nonblocking_coroutine(self):
-        async def testfunc():
+    def test_run_in_background_with_nonblocking_coroutine(self) -> defer.Deferred:
+        async def testfunc() -> None:
             self._check_test_key("one")
 
         return self._test_run_in_background(testfunc)
 
     @defer.inlineCallbacks
-    def test_make_deferred_yieldable(self):
+    def test_make_deferred_yieldable(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
         # a function which returns an incomplete deferred, but doesn't follow
         # the synapse rules.
-        def blocking_function():
-            d = defer.Deferred()
+        def blocking_function() -> defer.Deferred:
+            d: defer.Deferred = defer.Deferred()
             reactor.callLater(0, d.callback, None)
             return d
 
@@ -139,7 +163,9 @@ class LoggingContextTestCase(unittest.TestCase):
             self._check_test_key("one")
 
     @defer.inlineCallbacks
-    def test_make_deferred_yieldable_with_chained_deferreds(self):
+    def test_make_deferred_yieldable_with_chained_deferreds(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
         sentinel_context = current_context()
 
         with LoggingContext("one"):
@@ -152,7 +178,7 @@ class LoggingContextTestCase(unittest.TestCase):
             # now it should be restored
             self._check_test_key("one")
 
-    def test_nested_logging_context(self):
+    def test_nested_logging_context(self) -> None:
         with LoggingContext("foo"):
             nested_context = nested_logging_context(suffix="bar")
             self.assertEqual(nested_context.name, "foo-bar")
@@ -161,11 +187,11 @@ class LoggingContextTestCase(unittest.TestCase):
 # a function which returns a deferred which has been "called", but
 # which had a function which returned another incomplete deferred on
 # its callback list, so won't yet call any other new callbacks.
-def _chained_deferred_function():
+def _chained_deferred_function() -> defer.Deferred:
     d = defer.succeed(None)
 
-    def cb(res):
-        d2 = defer.Deferred()
+    def cb(res: object) -> defer.Deferred:
+        d2: defer.Deferred = defer.Deferred()
         reactor.callLater(0, d2.callback, res)
         return d2
 
diff --git a/tests/util/test_logformatter.py b/tests/util/test_logformatter.py
index a2e08281e6..0dee69a6fe 100644
--- a/tests/util/test_logformatter.py
+++ b/tests/util/test_logformatter.py
@@ -23,7 +23,7 @@ class TestException(Exception):
 
 
 class LogFormatterTestCase(unittest.TestCase):
-    def test_formatter(self):
+    def test_formatter(self) -> None:
         formatter = LogFormatter()
 
         try:
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 67173a4f5b..1fc5a473f0 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -13,10 +13,11 @@
 # limitations under the License.
 
 
-from typing import List
+from typing import List, Tuple
 from unittest.mock import Mock, patch
 
 from synapse.metrics.jemalloc import JemallocStats
+from synapse.types import JsonDict
 from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries
 from synapse.util.caches.treecache import TreeCache
 
@@ -25,14 +26,14 @@ from tests.unittest import override_config
 
 
 class LruCacheTestCase(unittest.HomeserverTestCase):
-    def test_get_set(self):
-        cache = LruCache(1)
+    def test_get_set(self) -> None:
+        cache: LruCache[str, str] = LruCache(1)
         cache["key"] = "value"
         self.assertEqual(cache.get("key"), "value")
         self.assertEqual(cache["key"], "value")
 
-    def test_eviction(self):
-        cache = LruCache(2)
+    def test_eviction(self) -> None:
+        cache: LruCache[int, int] = LruCache(2)
         cache[1] = 1
         cache[2] = 2
 
@@ -45,8 +46,8 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
         self.assertEqual(cache.get(2), 2)
         self.assertEqual(cache.get(3), 3)
 
-    def test_setdefault(self):
-        cache = LruCache(1)
+    def test_setdefault(self) -> None:
+        cache: LruCache[str, int] = LruCache(1)
         self.assertEqual(cache.setdefault("key", 1), 1)
         self.assertEqual(cache.get("key"), 1)
         self.assertEqual(cache.setdefault("key", 2), 1)
@@ -54,14 +55,15 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
         cache["key"] = 2  # Make sure overriding works.
         self.assertEqual(cache.get("key"), 2)
 
-    def test_pop(self):
-        cache = LruCache(1)
+    def test_pop(self) -> None:
+        cache: LruCache[str, int] = LruCache(1)
         cache["key"] = 1
         self.assertEqual(cache.pop("key"), 1)
         self.assertEqual(cache.pop("key"), None)
 
-    def test_del_multi(self):
-        cache = LruCache(4, cache_type=TreeCache)
+    def test_del_multi(self) -> None:
+        # The type here isn't quite correct as they don't handle TreeCache well.
+        cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache)
         cache[("animal", "cat")] = "mew"
         cache[("animal", "dog")] = "woof"
         cache[("vehicles", "car")] = "vroom"
@@ -71,7 +73,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(cache.get(("animal", "cat")), "mew")
         self.assertEqual(cache.get(("vehicles", "car")), "vroom")
-        cache.del_multi(("animal",))
+        cache.del_multi(("animal",))  # type: ignore[arg-type]
         self.assertEqual(len(cache), 2)
         self.assertEqual(cache.get(("animal", "cat")), None)
         self.assertEqual(cache.get(("animal", "dog")), None)
@@ -79,22 +81,22 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
         self.assertEqual(cache.get(("vehicles", "train")), "chuff")
         # Man from del_multi say "Yes".
 
-    def test_clear(self):
-        cache = LruCache(1)
+    def test_clear(self) -> None:
+        cache: LruCache[str, int] = LruCache(1)
         cache["key"] = 1
         cache.clear()
         self.assertEqual(len(cache), 0)
 
     @override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
-    def test_special_size(self):
-        cache = LruCache(10, "mycache")
+    def test_special_size(self) -> None:
+        cache: LruCache = LruCache(10, "mycache")
         self.assertEqual(cache.max_size, 100)
 
 
 class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
-    def test_get(self):
+    def test_get(self) -> None:
         m = Mock()
-        cache = LruCache(1)
+        cache: LruCache[str, str] = LruCache(1)
 
         cache.set("key", "value")
         self.assertFalse(m.called)
@@ -111,9 +113,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         cache.set("key", "value")
         self.assertEqual(m.call_count, 1)
 
-    def test_multi_get(self):
+    def test_multi_get(self) -> None:
         m = Mock()
-        cache = LruCache(1)
+        cache: LruCache[str, str] = LruCache(1)
 
         cache.set("key", "value")
         self.assertFalse(m.called)
@@ -130,9 +132,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         cache.set("key", "value")
         self.assertEqual(m.call_count, 1)
 
-    def test_set(self):
+    def test_set(self) -> None:
         m = Mock()
-        cache = LruCache(1)
+        cache: LruCache[str, str] = LruCache(1)
 
         cache.set("key", "value", callbacks=[m])
         self.assertFalse(m.called)
@@ -146,9 +148,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         cache.set("key", "value")
         self.assertEqual(m.call_count, 1)
 
-    def test_pop(self):
+    def test_pop(self) -> None:
         m = Mock()
-        cache = LruCache(1)
+        cache: LruCache[str, str] = LruCache(1)
 
         cache.set("key", "value", callbacks=[m])
         self.assertFalse(m.called)
@@ -162,12 +164,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         cache.pop("key")
         self.assertEqual(m.call_count, 1)
 
-    def test_del_multi(self):
+    def test_del_multi(self) -> None:
         m1 = Mock()
         m2 = Mock()
         m3 = Mock()
         m4 = Mock()
-        cache = LruCache(4, cache_type=TreeCache)
+        # The type here isn't quite correct as they don't handle TreeCache well.
+        cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache)
 
         cache.set(("a", "1"), "value", callbacks=[m1])
         cache.set(("a", "2"), "value", callbacks=[m2])
@@ -179,17 +182,17 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         self.assertEqual(m3.call_count, 0)
         self.assertEqual(m4.call_count, 0)
 
-        cache.del_multi(("a",))
+        cache.del_multi(("a",))  # type: ignore[arg-type]
 
         self.assertEqual(m1.call_count, 1)
         self.assertEqual(m2.call_count, 1)
         self.assertEqual(m3.call_count, 0)
         self.assertEqual(m4.call_count, 0)
 
-    def test_clear(self):
+    def test_clear(self) -> None:
         m1 = Mock()
         m2 = Mock()
-        cache = LruCache(5)
+        cache: LruCache[str, str] = LruCache(5)
 
         cache.set("key1", "value", callbacks=[m1])
         cache.set("key2", "value", callbacks=[m2])
@@ -202,11 +205,11 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         self.assertEqual(m1.call_count, 1)
         self.assertEqual(m2.call_count, 1)
 
-    def test_eviction(self):
+    def test_eviction(self) -> None:
         m1 = Mock(name="m1")
         m2 = Mock(name="m2")
         m3 = Mock(name="m3")
-        cache = LruCache(2)
+        cache: LruCache[str, str] = LruCache(2)
 
         cache.set("key1", "value", callbacks=[m1])
         cache.set("key2", "value", callbacks=[m2])
@@ -241,8 +244,8 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
 
 
 class LruCacheSizedTestCase(unittest.HomeserverTestCase):
-    def test_evict(self):
-        cache = LruCache(5, size_callback=len)
+    def test_evict(self) -> None:
+        cache: LruCache[str, List[int]] = LruCache(5, size_callback=len)
         cache["key1"] = [0]
         cache["key2"] = [1, 2]
         cache["key3"] = [3]
@@ -269,6 +272,7 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
         cache["key1"] = []
 
         self.assertEqual(len(cache), 0)
+        assert isinstance(cache.cache, dict)
         cache.cache["key1"].drop_from_cache()
         self.assertIsNone(
             cache.pop("key1"), "Cache entry should have been evicted but wasn't"
@@ -278,17 +282,17 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
 class TimeEvictionTestCase(unittest.HomeserverTestCase):
     """Test that time based eviction works correctly."""
 
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         config = super().default_config()
 
         config.setdefault("caches", {})["expiry_time"] = "30m"
 
         return config
 
-    def test_evict(self):
+    def test_evict(self) -> None:
         setup_expire_lru_cache_entries(self.hs)
 
-        cache = LruCache(5, clock=self.hs.get_clock())
+        cache: LruCache[str, int] = LruCache(5, clock=self.hs.get_clock())
 
         # Check that we evict entries we haven't accessed for 30 minutes.
         cache["key1"] = 1
@@ -332,7 +336,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
         }
     )
     @patch("synapse.util.caches.lrucache.get_jemalloc_stats")
-    def test_evict_memory(self, jemalloc_interface) -> None:
+    def test_evict_memory(self, jemalloc_interface: Mock) -> None:
         mock_jemalloc_class = Mock(spec=JemallocStats)
         jemalloc_interface.return_value = mock_jemalloc_class
 
@@ -340,7 +344,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
         mock_jemalloc_class.get_stat.return_value = 924288000
 
         setup_expire_lru_cache_entries(self.hs)
-        cache = LruCache(4, clock=self.hs.get_clock())
+        cache: LruCache[str, int] = LruCache(4, clock=self.hs.get_clock())
 
         cache["key1"] = 1
         cache["key2"] = 2
diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py
index 40754a4711..f68377a05a 100644
--- a/tests/util/test_macaroons.py
+++ b/tests/util/test_macaroons.py
@@ -21,14 +21,14 @@ from tests.unittest import TestCase
 
 
 class MacaroonGeneratorTestCase(TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.reactor, hs_clock = get_clock()
         self.macaroon_generator = MacaroonGenerator(hs_clock, "tesths", b"verysecret")
         self.other_macaroon_generator = MacaroonGenerator(
             hs_clock, "tesths", b"anothersecretkey"
         )
 
-    def test_guest_access_token(self):
+    def test_guest_access_token(self) -> None:
         """Test the generation and verification of guest access tokens"""
         token = self.macaroon_generator.generate_guest_access_token("@user:tesths")
         user_id = self.macaroon_generator.verify_guest_token(token)
@@ -47,7 +47,7 @@ class MacaroonGeneratorTestCase(TestCase):
         with self.assertRaises(MacaroonVerificationFailedException):
             self.macaroon_generator.verify_guest_token(token)
 
-    def test_delete_pusher_token(self):
+    def test_delete_pusher_token(self) -> None:
         """Test the generation and verification of delete_pusher tokens"""
         token = self.macaroon_generator.generate_delete_pusher_token(
             "@user:tesths", "m.mail", "john@example.com"
@@ -84,7 +84,7 @@ class MacaroonGeneratorTestCase(TestCase):
         )
         self.assertEqual(user_id, "@user:tesths")
 
-    def test_oidc_session_token(self):
+    def test_oidc_session_token(self) -> None:
         """Test the generation and verification of OIDC session cookies"""
         state = "arandomstate"
         session_data = OidcSessionData(
diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py
index 89d8656634..5b327b390e 100644
--- a/tests/util/test_ratelimitutils.py
+++ b/tests/util/test_ratelimitutils.py
@@ -13,16 +13,19 @@
 # limitations under the License.
 from typing import Optional
 
+from twisted.internet.defer import Deferred
+
 from synapse.config.homeserver import HomeServerConfig
+from synapse.config.ratelimiting import FederationRatelimitSettings
 from synapse.util.ratelimitutils import FederationRateLimiter
 
-from tests.server import get_clock
+from tests.server import ThreadedMemoryReactorClock, get_clock
 from tests.unittest import TestCase
 from tests.utils import default_config
 
 
 class FederationRateLimiterTestCase(TestCase):
-    def test_ratelimit(self):
+    def test_ratelimit(self) -> None:
         """A simple test with the default values"""
         reactor, clock = get_clock()
         rc_config = build_rc_config()
@@ -32,7 +35,7 @@ class FederationRateLimiterTestCase(TestCase):
             # shouldn't block
             self.successResultOf(d1)
 
-    def test_concurrent_limit(self):
+    def test_concurrent_limit(self) -> None:
         """Test what happens when we hit the concurrent limit"""
         reactor, clock = get_clock()
         rc_config = build_rc_config({"rc_federation": {"concurrent": 2}})
@@ -56,7 +59,7 @@ class FederationRateLimiterTestCase(TestCase):
             cm2.__exit__(None, None, None)
             self.successResultOf(d3)
 
-    def test_sleep_limit(self):
+    def test_sleep_limit(self) -> None:
         """Test what happens when we hit the sleep limit"""
         reactor, clock = get_clock()
         rc_config = build_rc_config(
@@ -79,7 +82,7 @@ class FederationRateLimiterTestCase(TestCase):
             self.assertAlmostEqual(sleep_time, 500, places=3)
 
 
-def _await_resolution(reactor, d):
+def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float:
     """advance the clock until the deferred completes.
 
     Returns the number of milliseconds it took to complete.
@@ -90,7 +93,7 @@ def _await_resolution(reactor, d):
     return (reactor.seconds() - start_time) * 1000
 
 
-def build_rc_config(settings: Optional[dict] = None):
+def build_rc_config(settings: Optional[dict] = None) -> FederationRatelimitSettings:
     config_dict = default_config("test")
     config_dict.update(settings or {})
     config = HomeServerConfig()
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 26cb71c640..9529ee53c8 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -22,7 +22,7 @@ from tests.unittest import HomeserverTestCase
 
 
 class RetryLimiterTestCase(HomeserverTestCase):
-    def test_new_destination(self):
+    def test_new_destination(self) -> None:
         """A happy-path case with a new destination and a successful operation"""
         store = self.hs.get_datastores().main
         limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
@@ -36,7 +36,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
         new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
         self.assertIsNone(new_timings)
 
-    def test_limiter(self):
+    def test_limiter(self) -> None:
         """General test case which walks through the process of a failing request"""
         store = self.hs.get_datastores().main
 
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index 5da04362a9..bc93de62eb 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -49,7 +49,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
         acquired_d: "Deferred[None]" = Deferred()
         unblock_d: "Deferred[None]" = Deferred()
 
-        async def reader_or_writer():
+        async def reader_or_writer() -> str:
             async with read_or_write(key):
                 acquired_d.callback(None)
                 await unblock_d
@@ -134,7 +134,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
                 d.called, msg="deferred %d was unexpectedly resolved" % (i + n)
             )
 
-    def test_rwlock(self):
+    def test_rwlock(self) -> None:
         rwlock = ReadWriteLock()
         key = "key"
 
@@ -197,7 +197,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
         _, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader")
         self.assertTrue(acquired_d.called)
 
-    def test_lock_handoff_to_nonblocking_writer(self):
+    def test_lock_handoff_to_nonblocking_writer(self) -> None:
         """Test a writer handing the lock to another writer that completes instantly."""
         rwlock = ReadWriteLock()
         key = "key"
@@ -216,7 +216,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
         d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed")
         self.assertTrue(d3.called)
 
-    def test_cancellation_while_holding_read_lock(self):
+    def test_cancellation_while_holding_read_lock(self) -> None:
         """Test cancellation while holding a read lock.
 
         A waiting writer should be given the lock when the reader holding the lock is
@@ -242,7 +242,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
         )
         self.assertEqual("write completed", self.successResultOf(writer_d))
 
-    def test_cancellation_while_holding_write_lock(self):
+    def test_cancellation_while_holding_write_lock(self) -> None:
         """Test cancellation while holding a write lock.
 
         A waiting reader should be given the lock when the writer holding the lock is
@@ -268,7 +268,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
         )
         self.assertEqual("read completed", self.successResultOf(reader_d))
 
-    def test_cancellation_while_waiting_for_read_lock(self):
+    def test_cancellation_while_waiting_for_read_lock(self) -> None:
         """Test cancellation while waiting for a read lock.
 
         Tests that cancelling a waiting reader:
@@ -319,7 +319,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
         )
         self.assertEqual("write 2 completed", self.successResultOf(writer2_d))
 
-    def test_cancellation_while_waiting_for_write_lock(self):
+    def test_cancellation_while_waiting_for_write_lock(self) -> None:
         """Test cancellation while waiting for a write lock.
 
         Tests that cancelling a waiting writer:
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 9ed01f7e0c..1b0fa52ad1 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -8,7 +8,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
     Tests for StreamChangeCache.
     """
 
-    def test_prefilled_cache(self):
+    def test_prefilled_cache(self) -> None:
         """
         Providing a prefilled cache to StreamChangeCache will result in a cache
         with the prefilled-cache entered in.
@@ -16,7 +16,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2})
         self.assertTrue(cache.has_entity_changed("user@foo.com", 1))
 
-    def test_has_entity_changed(self):
+    def test_has_entity_changed(self) -> None:
         """
         StreamChangeCache.entity_has_changed will mark entities as changed, and
         has_entity_changed will observe the changed entities.
@@ -52,7 +52,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
         self.assertTrue(cache.has_entity_changed("not@here.website", 0))
 
-    def test_entity_has_changed_pops_off_start(self):
+    def test_entity_has_changed_pops_off_start(self) -> None:
         """
         StreamChangeCache.entity_has_changed will respect the max size and
         purge the oldest items upon reaching that max size.
@@ -86,7 +86,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         )
         self.assertIsNone(cache.get_all_entities_changed(1))
 
-    def test_get_all_entities_changed(self):
+    def test_get_all_entities_changed(self) -> None:
         """
         StreamChangeCache.get_all_entities_changed will return all changed
         entities since the given position.  If the position is before the start
@@ -142,7 +142,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         r = cache.get_all_entities_changed(3)
         self.assertTrue(r == ok1 or r == ok2)
 
-    def test_has_any_entity_changed(self):
+    def test_has_any_entity_changed(self) -> None:
         """
         StreamChangeCache.has_any_entity_changed will return True if any
         entities have been changed since the provided stream position, and
@@ -168,7 +168,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         self.assertFalse(cache.has_any_entity_changed(2))
         self.assertFalse(cache.has_any_entity_changed(3))
 
-    def test_get_entities_changed(self):
+    def test_get_entities_changed(self) -> None:
         """
         StreamChangeCache.get_entities_changed will return the entities in the
         given list that have changed since the provided stream ID.  If the
@@ -228,7 +228,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
             {"bar@baz.net"},
         )
 
-    def test_max_pos(self):
+    def test_max_pos(self) -> None:
         """
         StreamChangeCache.get_max_pos_of_last_change will return the most
         recent point where the entity could have changed.  If the entity is not
diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py
index ad4dd7f007..f137e05191 100644
--- a/tests/util/test_stringutils.py
+++ b/tests/util/test_stringutils.py
@@ -19,7 +19,7 @@ from .. import unittest
 
 
 class StringUtilsTestCase(unittest.TestCase):
-    def test_client_secret_regex(self):
+    def test_client_secret_regex(self) -> None:
         """Ensure that client_secret does not contain illegal characters"""
         good = [
             "abcde12345",
@@ -46,7 +46,7 @@ class StringUtilsTestCase(unittest.TestCase):
             with self.assertRaises(SynapseError):
                 assert_valid_client_secret(client_secret)
 
-    def test_base62_encode(self):
+    def test_base62_encode(self) -> None:
         self.assertEqual("0", base62_encode(0))
         self.assertEqual("10", base62_encode(62))
         self.assertEqual("1c", base62_encode(100))
diff --git a/tests/util/test_threepids.py b/tests/util/test_threepids.py
index d957b953bb..3b35b8e4ec 100644
--- a/tests/util/test_threepids.py
+++ b/tests/util/test_threepids.py
@@ -18,31 +18,31 @@ from tests.unittest import HomeserverTestCase
 
 
 class CanonicaliseEmailTests(HomeserverTestCase):
-    def test_no_at(self):
+    def test_no_at(self) -> None:
         with self.assertRaises(ValueError):
             canonicalise_email("address-without-at.bar")
 
-    def test_two_at(self):
+    def test_two_at(self) -> None:
         with self.assertRaises(ValueError):
             canonicalise_email("foo@foo@test.bar")
 
-    def test_bad_format(self):
+    def test_bad_format(self) -> None:
         with self.assertRaises(ValueError):
             canonicalise_email("user@bad.example.net@good.example.com")
 
-    def test_valid_format(self):
+    def test_valid_format(self) -> None:
         self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar")
 
-    def test_domain_to_lower(self):
+    def test_domain_to_lower(self) -> None:
         self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar")
 
-    def test_domain_with_umlaut(self):
+    def test_domain_with_umlaut(self) -> None:
         self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com")
 
-    def test_address_casefold(self):
+    def test_address_casefold(self) -> None:
         self.assertEqual(
             canonicalise_email("Strauß@Example.com"), "strauss@example.com"
         )
 
-    def test_address_trim(self):
+    def test_address_trim(self) -> None:
         self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar")
diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py
index 567cb18468..fe3b4dc6a4 100644
--- a/tests/util/test_treecache.py
+++ b/tests/util/test_treecache.py
@@ -19,7 +19,7 @@ from .. import unittest
 
 
 class TreeCacheTestCase(unittest.TestCase):
-    def test_get_set_onelevel(self):
+    def test_get_set_onelevel(self) -> None:
         cache = TreeCache()
         cache[("a",)] = "A"
         cache[("b",)] = "B"
@@ -27,7 +27,7 @@ class TreeCacheTestCase(unittest.TestCase):
         self.assertEqual(cache.get(("b",)), "B")
         self.assertEqual(len(cache), 2)
 
-    def test_pop_onelevel(self):
+    def test_pop_onelevel(self) -> None:
         cache = TreeCache()
         cache[("a",)] = "A"
         cache[("b",)] = "B"
@@ -36,7 +36,7 @@ class TreeCacheTestCase(unittest.TestCase):
         self.assertEqual(cache.get(("b",)), "B")
         self.assertEqual(len(cache), 1)
 
-    def test_get_set_twolevel(self):
+    def test_get_set_twolevel(self) -> None:
         cache = TreeCache()
         cache[("a", "a")] = "AA"
         cache[("a", "b")] = "AB"
@@ -46,7 +46,7 @@ class TreeCacheTestCase(unittest.TestCase):
         self.assertEqual(cache.get(("b", "a")), "BA")
         self.assertEqual(len(cache), 3)
 
-    def test_pop_twolevel(self):
+    def test_pop_twolevel(self) -> None:
         cache = TreeCache()
         cache[("a", "a")] = "AA"
         cache[("a", "b")] = "AB"
@@ -58,7 +58,7 @@ class TreeCacheTestCase(unittest.TestCase):
         self.assertEqual(cache.pop(("b", "a")), None)
         self.assertEqual(len(cache), 1)
 
-    def test_pop_mixedlevel(self):
+    def test_pop_mixedlevel(self) -> None:
         cache = TreeCache()
         cache[("a", "a")] = "AA"
         cache[("a", "b")] = "AB"
@@ -72,14 +72,14 @@ class TreeCacheTestCase(unittest.TestCase):
 
         self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped)))
 
-    def test_clear(self):
+    def test_clear(self) -> None:
         cache = TreeCache()
         cache[("a",)] = "A"
         cache[("b",)] = "B"
         cache.clear()
         self.assertEqual(len(cache), 0)
 
-    def test_contains(self):
+    def test_contains(self) -> None:
         cache = TreeCache()
         cache[("a",)] = "A"
         self.assertTrue(("a",) in cache)
diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py
index 0d5039de04..c9d22b6d8c 100644
--- a/tests/util/test_wheel_timer.py
+++ b/tests/util/test_wheel_timer.py
@@ -18,8 +18,8 @@ from .. import unittest
 
 
 class WheelTimerTestCase(unittest.TestCase):
-    def test_single_insert_fetch(self):
-        wheel = WheelTimer(bucket_size=5)
+    def test_single_insert_fetch(self) -> None:
+        wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
 
         obj = object()
         wheel.insert(100, obj, 150)
@@ -32,8 +32,8 @@ class WheelTimerTestCase(unittest.TestCase):
         self.assertListEqual(wheel.fetch(156), [obj])
         self.assertListEqual(wheel.fetch(170), [])
 
-    def test_multi_insert(self):
-        wheel = WheelTimer(bucket_size=5)
+    def test_multi_insert(self) -> None:
+        wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
 
         obj1 = object()
         obj2 = object()
@@ -50,15 +50,15 @@ class WheelTimerTestCase(unittest.TestCase):
         self.assertListEqual(wheel.fetch(200), [obj3])
         self.assertListEqual(wheel.fetch(210), [])
 
-    def test_insert_past(self):
-        wheel = WheelTimer(bucket_size=5)
+    def test_insert_past(self) -> None:
+        wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
 
         obj = object()
         wheel.insert(100, obj, 50)
         self.assertListEqual(wheel.fetch(120), [obj])
 
-    def test_insert_past_multi(self):
-        wheel = WheelTimer(bucket_size=5)
+    def test_insert_past_multi(self) -> None:
+        wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
 
         obj1 = object()
         obj2 = object()