summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/logging/__init__.py6
-rw-r--r--tests/logging/test_opentracing.py4
-rw-r--r--tests/logging/test_remote_handler.py25
-rw-r--r--tests/logging/test_terse_json.py30
-rw-r--r--tests/rest/client/test_transactions.py42
5 files changed, 69 insertions, 38 deletions
diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py
index 1acf5666a8..1c5de95a80 100644
--- a/tests/logging/__init__.py
+++ b/tests/logging/__init__.py
@@ -13,9 +13,11 @@
 # limitations under the License.
 import logging
 
+from tests.unittest import TestCase
 
-class LoggerCleanupMixin:
-    def get_logger(self, handler):
+
+class LoggerCleanupMixin(TestCase):
+    def get_logger(self, handler: logging.Handler) -> logging.Logger:
         """
         Attach a handler to a logger and add clean-ups to remove revert this.
         """
diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
index 0917e478a5..e28ba84cc2 100644
--- a/tests/logging/test_opentracing.py
+++ b/tests/logging/test_opentracing.py
@@ -153,7 +153,7 @@ class LogContextScopeManagerTestCase(TestCase):
 
         scopes = []
 
-        async def task(i: int):
+        async def task(i: int) -> None:
             scope = start_active_span(
                 f"task{i}",
                 tracer=self._tracer,
@@ -165,7 +165,7 @@ class LogContextScopeManagerTestCase(TestCase):
             self.assertEqual(self._tracer.active_span, scope.span)
             scope.close()
 
-        async def root():
+        async def root() -> None:
             with start_active_span("root span", tracer=self._tracer) as root_scope:
                 self.assertEqual(self._tracer.active_span, root_scope.span)
                 scopes.append(root_scope)
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
index b0d046fe00..c08954d887 100644
--- a/tests/logging/test_remote_handler.py
+++ b/tests/logging/test_remote_handler.py
@@ -11,7 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from twisted.test.proto_helpers import AccumulatingProtocol
+from typing import Tuple
+
+from twisted.internet.protocol import Protocol
+from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
 
 from synapse.logging import RemoteHandler
 
@@ -20,7 +23,9 @@ from tests.server import FakeTransport, get_clock
 from tests.unittest import TestCase
 
 
-def connect_logging_client(reactor, client_id):
+def connect_logging_client(
+    reactor: MemoryReactorClock, client_id: int
+) -> Tuple[Protocol, AccumulatingProtocol]:
     # This is essentially tests.server.connect_client, but disabling autoflush on
     # the client transport. This is necessary to avoid an infinite loop due to
     # sending of data via the logging transport causing additional logs to be
@@ -35,10 +40,10 @@ def connect_logging_client(reactor, client_id):
 
 
 class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.reactor, _ = get_clock()
 
-    def test_log_output(self):
+    def test_log_output(self) -> None:
         """
         The remote handler delivers logs over TCP.
         """
@@ -51,6 +56,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
         client, server = connect_logging_client(self.reactor, 0)
 
         # Trigger data being sent
+        assert isinstance(client.transport, FakeTransport)
         client.transport.flush()
 
         # One log message, with a single trailing newline
@@ -61,7 +67,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
         # Ensure the data passed through properly.
         self.assertEqual(logs[0], "Hello there, wally!")
 
-    def test_log_backpressure_debug(self):
+    def test_log_backpressure_debug(self) -> None:
         """
         When backpressure is hit, DEBUG logs will be shed.
         """
@@ -83,6 +89,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
 
         # Allow the reconnection
         client, server = connect_logging_client(self.reactor, 0)
+        assert isinstance(client.transport, FakeTransport)
         client.transport.flush()
 
         # Only the 7 infos made it through, the debugs were elided
@@ -90,7 +97,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
         self.assertEqual(len(logs), 7)
         self.assertNotIn(b"debug", server.data)
 
-    def test_log_backpressure_info(self):
+    def test_log_backpressure_info(self) -> None:
         """
         When backpressure is hit, DEBUG and INFO logs will be shed.
         """
@@ -116,6 +123,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
 
         # Allow the reconnection
         client, server = connect_logging_client(self.reactor, 0)
+        assert isinstance(client.transport, FakeTransport)
         client.transport.flush()
 
         # The 10 warnings made it through, the debugs and infos were elided
@@ -124,7 +132,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
         self.assertNotIn(b"debug", server.data)
         self.assertNotIn(b"info", server.data)
 
-    def test_log_backpressure_cut_middle(self):
+    def test_log_backpressure_cut_middle(self) -> None:
         """
         When backpressure is hit, and no more DEBUG and INFOs cannot be culled,
         it will cut the middle messages out.
@@ -140,6 +148,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
 
         # Allow the reconnection
         client, server = connect_logging_client(self.reactor, 0)
+        assert isinstance(client.transport, FakeTransport)
         client.transport.flush()
 
         # The first five and last five warnings made it through, the debugs and
@@ -151,7 +160,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
             logs,
         )
 
-    def test_cancel_connection(self):
+    def test_cancel_connection(self) -> None:
         """
         Gracefully handle the connection being cancelled.
         """
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 0b0d8737c1..fa27f1279a 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -14,24 +14,28 @@
 import json
 import logging
 from io import BytesIO, StringIO
+from typing import cast
 from unittest.mock import Mock, patch
 
+from twisted.web.http import HTTPChannel
 from twisted.web.server import Request
 
 from synapse.http.site import SynapseRequest
 from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
 from synapse.logging.context import LoggingContext, LoggingContextFilter
+from synapse.types import JsonDict
 
 from tests.logging import LoggerCleanupMixin
-from tests.server import FakeChannel
+from tests.server import FakeChannel, get_clock
 from tests.unittest import TestCase
 
 
 class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.output = StringIO()
+        self.reactor, _ = get_clock()
 
-    def get_log_line(self):
+    def get_log_line(self) -> JsonDict:
         # One log message, with a single trailing newline.
         data = self.output.getvalue()
         logs = data.splitlines()
@@ -39,7 +43,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         self.assertEqual(data.count("\n"), 1)
         return json.loads(logs[0])
 
-    def test_terse_json_output(self):
+    def test_terse_json_output(self) -> None:
         """
         The Terse JSON formatter converts log messages to JSON.
         """
@@ -61,7 +65,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         self.assertCountEqual(log.keys(), expected_log_keys)
         self.assertEqual(log["log"], "Hello there, wally!")
 
-    def test_extra_data(self):
+    def test_extra_data(self) -> None:
         """
         Additional information can be included in the structured logging.
         """
@@ -93,7 +97,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         self.assertEqual(log["int"], 3)
         self.assertIs(log["bool"], True)
 
-    def test_json_output(self):
+    def test_json_output(self) -> None:
         """
         The Terse JSON formatter converts log messages to JSON.
         """
@@ -114,7 +118,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         self.assertCountEqual(log.keys(), expected_log_keys)
         self.assertEqual(log["log"], "Hello there, wally!")
 
-    def test_with_context(self):
+    def test_with_context(self) -> None:
         """
         The logging context should be added to the JSON response.
         """
@@ -139,7 +143,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         self.assertEqual(log["log"], "Hello there, wally!")
         self.assertEqual(log["request"], "name")
 
-    def test_with_request_context(self):
+    def test_with_request_context(self) -> None:
         """
         Information from the logging context request should be added to the JSON response.
         """
@@ -154,11 +158,13 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         site.server_version_string = "Server v1"
         site.reactor = Mock()
         site.experimental_cors_msc3886 = False
-        request = SynapseRequest(FakeChannel(site, None), site)
+        request = SynapseRequest(
+            cast(HTTPChannel, FakeChannel(site, self.reactor)), site
+        )
         # Call requestReceived to finish instantiating the object.
         request.content = BytesIO()
-        # Partially skip some of the internal processing of SynapseRequest.
-        request._started_processing = Mock()
+        # Partially skip some internal processing of SynapseRequest.
+        request._started_processing = Mock()  # type: ignore[assignment]
         request.request_metrics = Mock(spec=["name"])
         with patch.object(Request, "render"):
             request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1")
@@ -200,7 +206,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         self.assertEqual(log["protocol"], "1.1")
         self.assertEqual(log["user_agent"], "")
 
-    def test_with_exception(self):
+    def test_with_exception(self) -> None:
         """
         The logging exception type & value should be added to the JSON response.
         """
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 21a1ca2a68..3086e1b565 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -13,18 +13,22 @@
 # limitations under the License.
 
 from http import HTTPStatus
+from typing import Any, Generator, Tuple, cast
 from unittest.mock import Mock, call
 
-from twisted.internet import defer, reactor
+from twisted.internet import defer, reactor as _reactor
 
 from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
 from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache
+from synapse.types import ISynapseReactor, JsonDict
 from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
 from tests.utils import MockClock
 
+reactor = cast(ISynapseReactor, _reactor)
+
 
 class HttpTransactionCacheTestCase(unittest.TestCase):
     def setUp(self) -> None:
@@ -34,11 +38,13 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
         self.hs.get_auth = Mock()
         self.cache = HttpTransactionCache(self.hs)
 
-        self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
+        self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"})
         self.mock_key = "foo"
 
     @defer.inlineCallbacks
-    def test_executes_given_function(self):
+    def test_executes_given_function(
+        self,
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         cb = Mock(return_value=make_awaitable(self.mock_http_response))
         res = yield self.cache.fetch_or_execute(
             self.mock_key, cb, "some_arg", keyword="arg"
@@ -47,7 +53,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
         self.assertEqual(res, self.mock_http_response)
 
     @defer.inlineCallbacks
-    def test_deduplicates_based_on_key(self):
+    def test_deduplicates_based_on_key(
+        self,
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         cb = Mock(return_value=make_awaitable(self.mock_http_response))
         for i in range(3):  # invoke multiple times
             res = yield self.cache.fetch_or_execute(
@@ -58,18 +66,20 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
         cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0)
 
     @defer.inlineCallbacks
-    def test_logcontexts_with_async_result(self):
+    def test_logcontexts_with_async_result(
+        self,
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         @defer.inlineCallbacks
-        def cb():
+        def cb() -> Generator["defer.Deferred[object]", object, Tuple[int, JsonDict]]:
             yield Clock(reactor).sleep(0)
-            return "yay"
+            return 1, {}
 
         @defer.inlineCallbacks
-        def test():
+        def test() -> Generator["defer.Deferred[Any]", object, None]:
             with LoggingContext("c") as c1:
                 res = yield self.cache.fetch_or_execute(self.mock_key, cb)
                 self.assertIs(current_context(), c1)
-                self.assertEqual(res, "yay")
+                self.assertEqual(res, (1, {}))
 
         # run the test twice in parallel
         d = defer.gatherResults([test(), test()])
@@ -78,13 +88,15 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
         self.assertIs(current_context(), SENTINEL_CONTEXT)
 
     @defer.inlineCallbacks
-    def test_does_not_cache_exceptions(self):
+    def test_does_not_cache_exceptions(
+        self,
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         """Checks that, if the callback throws an exception, it is called again
         for the next request.
         """
         called = [False]
 
-        def cb():
+        def cb() -> "defer.Deferred[Tuple[int, JsonDict]]":
             if called[0]:
                 # return a valid result the second time
                 return defer.succeed(self.mock_http_response)
@@ -104,13 +116,15 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
             self.assertIs(current_context(), test_context)
 
     @defer.inlineCallbacks
-    def test_does_not_cache_failures(self):
+    def test_does_not_cache_failures(
+        self,
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         """Checks that, if the callback returns a failure, it is called again
         for the next request.
         """
         called = [False]
 
-        def cb():
+        def cb() -> "defer.Deferred[Tuple[int, JsonDict]]":
             if called[0]:
                 # return a valid result the second time
                 return defer.succeed(self.mock_http_response)
@@ -130,7 +144,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
             self.assertIs(current_context(), test_context)
 
     @defer.inlineCallbacks
-    def test_cleans_up(self):
+    def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
         cb = Mock(return_value=make_awaitable(self.mock_http_response))
         yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
         # should NOT have cleaned up yet