summary refs log tree commit diff
path: root/tests/util/caches/test_descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util/caches/test_descriptors.py')
-rw-r--r--tests/util/caches/test_descriptors.py22
1 files changed, 14 insertions, 8 deletions
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 43475a307f..13f1edd533 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -13,11 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Iterable, Set, Tuple
+from typing import Iterable, Set, Tuple, cast
 from unittest import mock
 
 from twisted.internet import defer, reactor
 from twisted.internet.defer import CancelledError, Deferred
+from twisted.internet.interfaces import IReactorTime
 
 from synapse.api.errors import SynapseError
 from synapse.logging.context import (
@@ -37,8 +38,8 @@ logger = logging.getLogger(__name__)
 
 
 def run_on_reactor():
-    d = defer.Deferred()
-    reactor.callLater(0, d.callback, 0)
+    d: "Deferred[int]" = defer.Deferred()
+    cast(IReactorTime, reactor).callLater(0, d.callback, 0)
     return make_deferred_yieldable(d)
 
 
@@ -224,7 +225,8 @@ class DescriptorTestCase(unittest.TestCase):
         callbacks: Set[str] = set()
 
         # set off an asynchronous request
-        obj.result = origin_d = defer.Deferred()
+        origin_d: Deferred = defer.Deferred()
+        obj.result = origin_d
 
         d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
         self.assertFalse(d1.called)
@@ -262,7 +264,7 @@ class DescriptorTestCase(unittest.TestCase):
         """Check that logcontexts are set and restored correctly when
         using the cache."""
 
-        complete_lookup = defer.Deferred()
+        complete_lookup: Deferred = defer.Deferred()
 
         class Cls:
             @descriptors.cached()
@@ -772,10 +774,14 @@ class CachedListDescriptorTestCase(unittest.TestCase):
 
             @descriptors.cachedList(cached_method_name="fn", list_name="args1")
             async def list_fn(self, args1, arg2):
-                assert current_context().name == "c1"
+                context = current_context()
+                assert isinstance(context, LoggingContext)
+                assert context.name == "c1"
                 # we want this to behave like an asynchronous function
                 await run_on_reactor()
-                assert current_context().name == "c1"
+                context = current_context()
+                assert isinstance(context, LoggingContext)
+                assert context.name == "c1"
                 return self.mock(args1, arg2)
 
         with LoggingContext("c1") as c1:
@@ -834,7 +840,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
                 return self.mock(args1)
 
         obj = Cls()
-        deferred_result = Deferred()
+        deferred_result: "Deferred[dict]" = Deferred()
         obj.mock.return_value = deferred_result
 
         # start off several concurrent lookups of the same key