summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/replication/_base.py27
-rw-r--r--tests/server.py2
-rw-r--r--tests/test_utils/logging_setup.py2
3 files changed, 19 insertions, 12 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index f6a6aed35e..20940c8107 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -22,6 +22,7 @@ from twisted.internet.protocol import Protocol
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
 from twisted.web.resource import Resource
+from twisted.web.server import Request, Site
 
 from synapse.app.generic_worker import (
     GenericWorkerReplicationHandler,
@@ -32,7 +33,10 @@ from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.replication.tcp.resource import (
+    ReplicationStreamProtocolFactory,
+    ServerReplicationStreamProtocol,
+)
 from synapse.server import HomeServer
 from synapse.util import Clock
 
@@ -59,7 +63,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
         self.streamer = hs.get_replication_streamer()
-        self.server = server_factory.buildProtocol(None)
+        self.server = server_factory.buildProtocol(
+            None
+        )  # type: ServerReplicationStreamProtocol
 
         # Make a new HomeServer object for the worker
         self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -155,9 +161,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         request_factory = OneShotRequestFactory()
 
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor)
-        channel.requestFactory = request_factory
-        channel.site = self.site
+        channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -188,8 +192,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         fetching updates for given stream.
         """
 
+        path = request.path  # type: bytes  # type: ignore
         self.assertRegex(
-            request.path,
+            path,
             br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
             % (stream_name.encode("ascii"),),
         )
@@ -390,9 +395,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         request_factory = OneShotRequestFactory()
 
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor)
-        channel.requestFactory = request_factory
-        channel.site = self._hs_to_site[hs]
+        channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -475,9 +478,13 @@ class _PushHTTPChannel(HTTPChannel):
     makes it very hard to test.
     """
 
-    def __init__(self, reactor: IReactorTime):
+    def __init__(
+        self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
+    ):
         super().__init__()
         self.reactor = reactor
+        self.requestFactory = request_factory
+        self.site = site
 
         self._pull_to_push_producer = None  # type: Optional[_PullToPushProducer]
 
diff --git a/tests/server.py b/tests/server.py
index 939a0008ca..863f6da738 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -188,7 +188,7 @@ class FakeSite:
 
 def make_request(
     reactor,
-    site: Site,
+    site: Union[Site, FakeSite],
     method,
     path,
     content=b"",
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 52ae5c5713..74568b34f8 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -28,7 +28,7 @@ class ToTwistedHandler(logging.Handler):
     def emit(self, record):
         log_entry = self.format(record)
         log_level = record.levelname.lower().replace("warning", "warn")
-        self.tx_log.emit(
+        self.tx_log.emit(  # type: ignore
             twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
         )