summary refs log tree commit diff
path: root/tests/util/test_logcontext.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util/test_logcontext.py')
-rw-r--r--tests/util/test_logcontext.py86
1 files changed, 56 insertions, 30 deletions
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 2ad321e184..d64c162e1d 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -1,5 +1,21 @@
+# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Generator, cast
+
 import twisted.python.failure
-from twisted.internet import defer, reactor
+from twisted.internet import defer, reactor as _reactor
 
 from synapse.logging.context import (
     SENTINEL_CONTEXT,
@@ -10,25 +26,30 @@ from synapse.logging.context import (
     nested_logging_context,
     run_in_background,
 )
+from synapse.types import ISynapseReactor
 from synapse.util import Clock
 
 from .. import unittest
 
+reactor = cast(ISynapseReactor, _reactor)
+
 
 class LoggingContextTestCase(unittest.TestCase):
-    def _check_test_key(self, value):
-        self.assertEqual(current_context().name, value)
+    def _check_test_key(self, value: str) -> None:
+        context = current_context()
+        assert isinstance(context, LoggingContext)
+        self.assertEqual(context.name, value)
 
-    def test_with_context(self):
+    def test_with_context(self) -> None:
         with LoggingContext("test"):
             self._check_test_key("test")
 
     @defer.inlineCallbacks
-    def test_sleep(self):
+    def test_sleep(self) -> Generator["defer.Deferred[object]", object, None]:
         clock = Clock(reactor)
 
         @defer.inlineCallbacks
-        def competing_callback():
+        def competing_callback() -> Generator["defer.Deferred[object]", object, None]:
             with LoggingContext("competing"):
                 yield clock.sleep(0)
                 self._check_test_key("competing")
@@ -39,17 +60,18 @@ class LoggingContextTestCase(unittest.TestCase):
             yield clock.sleep(0)
             self._check_test_key("one")
 
-    def _test_run_in_background(self, function):
+    def _test_run_in_background(self, function: Callable[[], object]) -> defer.Deferred:
         sentinel_context = current_context()
 
-        callback_completed = [False]
+        callback_completed = False
 
         with LoggingContext("one"):
             # fire off function, but don't wait on it.
             d2 = run_in_background(function)
 
-            def cb(res):
-                callback_completed[0] = True
+            def cb(res: object) -> object:
+                nonlocal callback_completed
+                callback_completed = True
                 return res
 
             d2.addCallback(cb)
@@ -60,8 +82,8 @@ class LoggingContextTestCase(unittest.TestCase):
         # the logcontext is left in a sane state.
         d2 = defer.Deferred()
 
-        def check_logcontext():
-            if not callback_completed[0]:
+        def check_logcontext() -> None:
+            if not callback_completed:
                 reactor.callLater(0.01, check_logcontext)
                 return
 
@@ -78,31 +100,31 @@ class LoggingContextTestCase(unittest.TestCase):
         # test is done once d2 finishes
         return d2
 
-    def test_run_in_background_with_blocking_fn(self):
+    def test_run_in_background_with_blocking_fn(self) -> defer.Deferred:
         @defer.inlineCallbacks
-        def blocking_function():
+        def blocking_function() -> Generator["defer.Deferred[object]", object, None]:
             yield Clock(reactor).sleep(0)
 
         return self._test_run_in_background(blocking_function)
 
-    def test_run_in_background_with_non_blocking_fn(self):
+    def test_run_in_background_with_non_blocking_fn(self) -> defer.Deferred:
         @defer.inlineCallbacks
-        def nonblocking_function():
+        def nonblocking_function() -> Generator["defer.Deferred[object]", object, None]:
             with PreserveLoggingContext():
                 yield defer.succeed(None)
 
         return self._test_run_in_background(nonblocking_function)
 
-    def test_run_in_background_with_chained_deferred(self):
+    def test_run_in_background_with_chained_deferred(self) -> defer.Deferred:
         # a function which returns a deferred which looks like it has been
         # called, but is actually paused
-        def testfunc():
+        def testfunc() -> defer.Deferred:
             return make_deferred_yieldable(_chained_deferred_function())
 
         return self._test_run_in_background(testfunc)
 
-    def test_run_in_background_with_coroutine(self):
-        async def testfunc():
+    def test_run_in_background_with_coroutine(self) -> defer.Deferred:
+        async def testfunc() -> None:
             self._check_test_key("one")
             d = Clock(reactor).sleep(0)
             self.assertIs(current_context(), SENTINEL_CONTEXT)
@@ -111,18 +133,20 @@ class LoggingContextTestCase(unittest.TestCase):
 
         return self._test_run_in_background(testfunc)
 
-    def test_run_in_background_with_nonblocking_coroutine(self):
-        async def testfunc():
+    def test_run_in_background_with_nonblocking_coroutine(self) -> defer.Deferred:
+        async def testfunc() -> None:
             self._check_test_key("one")
 
         return self._test_run_in_background(testfunc)
 
     @defer.inlineCallbacks
-    def test_make_deferred_yieldable(self):
+    def test_make_deferred_yieldable(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
         # a function which returns an incomplete deferred, but doesn't follow
         # the synapse rules.
-        def blocking_function():
-            d = defer.Deferred()
+        def blocking_function() -> defer.Deferred:
+            d: defer.Deferred = defer.Deferred()
             reactor.callLater(0, d.callback, None)
             return d
 
@@ -139,7 +163,9 @@ class LoggingContextTestCase(unittest.TestCase):
             self._check_test_key("one")
 
     @defer.inlineCallbacks
-    def test_make_deferred_yieldable_with_chained_deferreds(self):
+    def test_make_deferred_yieldable_with_chained_deferreds(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
         sentinel_context = current_context()
 
         with LoggingContext("one"):
@@ -152,7 +178,7 @@ class LoggingContextTestCase(unittest.TestCase):
             # now it should be restored
             self._check_test_key("one")
 
-    def test_nested_logging_context(self):
+    def test_nested_logging_context(self) -> None:
         with LoggingContext("foo"):
             nested_context = nested_logging_context(suffix="bar")
             self.assertEqual(nested_context.name, "foo-bar")
@@ -161,11 +187,11 @@ class LoggingContextTestCase(unittest.TestCase):
 # a function which returns a deferred which has been "called", but
 # which had a function which returned another incomplete deferred on
 # its callback list, so won't yet call any other new callbacks.
-def _chained_deferred_function():
+def _chained_deferred_function() -> defer.Deferred:
     d = defer.succeed(None)
 
-    def cb(res):
-        d2 = defer.Deferred()
+    def cb(res: object) -> defer.Deferred:
+        d2: defer.Deferred = defer.Deferred()
         reactor.callLater(0, d2.callback, res)
         return d2