summary refs log tree commit diff
path: root/tests/replication/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/_base.py')
-rw-r--r--tests/replication/_base.py27
1 files changed, 17 insertions, 10 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index f6a6aed35e..20940c8107 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -22,6 +22,7 @@ from twisted.internet.protocol import Protocol
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
 from twisted.web.resource import Resource
+from twisted.web.server import Request, Site
 
 from synapse.app.generic_worker import (
     GenericWorkerReplicationHandler,
@@ -32,7 +33,10 @@ from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.replication.tcp.resource import (
+    ReplicationStreamProtocolFactory,
+    ServerReplicationStreamProtocol,
+)
 from synapse.server import HomeServer
 from synapse.util import Clock
 
@@ -59,7 +63,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
         self.streamer = hs.get_replication_streamer()
-        self.server = server_factory.buildProtocol(None)
+        self.server = server_factory.buildProtocol(
+            None
+        )  # type: ServerReplicationStreamProtocol
 
         # Make a new HomeServer object for the worker
         self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -155,9 +161,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         request_factory = OneShotRequestFactory()
 
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor)
-        channel.requestFactory = request_factory
-        channel.site = self.site
+        channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -188,8 +192,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         fetching updates for given stream.
         """
 
+        path = request.path  # type: bytes  # type: ignore
         self.assertRegex(
-            request.path,
+            path,
             br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
             % (stream_name.encode("ascii"),),
         )
@@ -390,9 +395,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         request_factory = OneShotRequestFactory()
 
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor)
-        channel.requestFactory = request_factory
-        channel.site = self._hs_to_site[hs]
+        channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -475,9 +478,13 @@ class _PushHTTPChannel(HTTPChannel):
     makes it very hard to test.
     """
 
-    def __init__(self, reactor: IReactorTime):
+    def __init__(
+        self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
+    ):
         super().__init__()
         self.reactor = reactor
+        self.requestFactory = request_factory
+        self.site = site
 
         self._pull_to_push_producer = None  # type: Optional[_PullToPushProducer]