summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-03-26 16:49:46 +0000
committerGitHub <noreply@github.com>2021-03-26 16:49:46 +0000
commitb5efcb577e2c9b8b38cb86f87cf65fa93eb2566b (patch)
tree1172fe29e337b163f56283f0eb8f898324ef32cf /tests
parentMerge branch 'master' into develop (diff)
downloadsynapse-b5efcb577e2c9b8b38cb86f87cf65fa93eb2566b.tar.xz
Make it possible to use dmypy (#9692)
Running `dmypy run` will do a `mypy` check while spinning up a daemon
that makes rerunning `dmypy run` a lot faster.

`dmypy` doesn't support `follow_imports = silent` and has
`local_partial_types` enabled, so this PR enables those options and
fixes the issues that were newly raised. Note that `local_partial_types`
will be enabled by default in upcoming mypy releases.
Diffstat (limited to '')
-rw-r--r--tests/replication/tcp/streams/test_typing.py1
-rw-r--r--tests/replication/test_multi_media_repo.py4
-rw-r--r--tests/server.py28
3 files changed, 23 insertions, 10 deletions
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 5acfb3e53e..ca49d4dd3a 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -69,6 +69,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
         # The from token should be the token from the last RDATA we got.
+        assert request.args is not None
         self.assertEqual(int(request.args[b"from_token"][0]), token)
 
         self.test_handler.on_rdata.assert_called_once()
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 7ff11cde10..b0800f9840 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -15,7 +15,7 @@
 import logging
 import os
 from binascii import unhexlify
-from typing import Tuple
+from typing import Optional, Tuple
 
 from twisted.internet.protocol import Factory
 from twisted.protocols.tls import TLSMemoryBIOFactory
@@ -32,7 +32,7 @@ from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
 
 logger = logging.getLogger(__name__)
 
-test_server_connection_factory = None
+test_server_connection_factory = None  # type: Optional[TestServerTLSConnectionFactory]
 
 
 class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
diff --git a/tests/server.py b/tests/server.py
index 57cc4ac605..b535a5d886 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -2,7 +2,7 @@ import json
 import logging
 from collections import deque
 from io import SEEK_END, BytesIO
-from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union
 
 import attr
 from typing_extensions import Deque
@@ -13,8 +13,11 @@ from twisted.internet._resolver import SimpleResolverComplexifier
 from twisted.internet.defer import Deferred, fail, succeed
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import (
+    IHostnameResolver,
+    IProtocol,
+    IPullProducer,
+    IPushProducer,
     IReactorPluggableNameResolver,
-    IReactorTCP,
     IResolverSimple,
     ITransport,
 )
@@ -45,11 +48,11 @@ class FakeChannel:
     wire).
     """
 
-    site = attr.ib(type=Site)
+    site = attr.ib(type=Union[Site, "FakeSite"])
     _reactor = attr.ib()
     result = attr.ib(type=dict, default=attr.Factory(dict))
     _ip = attr.ib(type=str, default="127.0.0.1")
-    _producer = None
+    _producer = None  # type: Optional[Union[IPullProducer, IPushProducer]]
 
     @property
     def json_body(self):
@@ -159,7 +162,11 @@ class FakeChannel:
 
         Any cookines found are added to the given dict
         """
-        for h in self.headers.getRawHeaders("Set-Cookie"):
+        headers = self.headers.getRawHeaders("Set-Cookie")
+        if not headers:
+            return
+
+        for h in headers:
             parts = h.split(";")
             k, v = parts[0].split("=", maxsplit=1)
             cookies[k] = v
@@ -311,8 +318,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 
         self._tcp_callbacks = {}
         self._udp = []
-        lookups = self.lookups = {}
-        self._thread_callbacks = deque()  # type: Deque[Callable[[], None]]()
+        lookups = self.lookups = {}  # type: Dict[str, str]
+        self._thread_callbacks = deque()  # type: Deque[Callable[[], None]]
 
         @implementer(IResolverSimple)
         class FakeResolver:
@@ -324,6 +331,9 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         self.nameResolver = SimpleResolverComplexifier(FakeResolver())
         super().__init__()
 
+    def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
+        raise NotImplementedError()
+
     def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
         p = udp.Port(port, protocol, interface, maxPacketSize, self)
         p.startListening()
@@ -621,7 +631,9 @@ class FakeTransport:
             self.disconnected = True
 
 
-def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
+def connect_client(
+    reactor: ThreadedMemoryReactorClock, client_id: int
+) -> Tuple[IProtocol, AccumulatingProtocol]:
     """
     Connect a client to a fake TCP transport.