summary refs log tree commit diff
path: root/tests/replication/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/_base.py')
-rw-r--r--tests/replication/_base.py44
1 files changed, 16 insertions, 28 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 20940c8107..67b7913666 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,9 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple
-
-import attr
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type
 
 from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.protocol import Protocol
@@ -158,10 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # Set up client side protocol
         client_protocol = client_factory.buildProtocol(None)
 
-        request_factory = OneShotRequestFactory()
-
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
+        channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -183,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         server_to_client_transport.loseConnection()
         client_to_server_transport.loseConnection()
 
-        return request_factory.request
+        return channel.request
 
     def assert_request_is_get_repl_stream_updates(
         self, request: SynapseRequest, stream_name: str
@@ -237,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         if self.hs.config.redis.redis_enabled:
             # Handle attempts to connect to fake redis server.
             self.reactor.add_tcp_client_callback(
-                "localhost",
+                b"localhost",
                 6379,
                 self.connect_any_redis_attempts,
             )
@@ -392,10 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # Set up client side protocol
         client_protocol = client_factory.buildProtocol(None)
 
-        request_factory = OneShotRequestFactory()
-
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
+        channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -421,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         clients = self.reactor.tcpClients
         while clients:
             (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
-            self.assertEqual(host, "localhost")
+            self.assertEqual(host, b"localhost")
             self.assertEqual(port, 6379)
 
             client_protocol = client_factory.buildProtocol(None)
@@ -453,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler):
             self.received_rdata_rows.append((stream_name, token, r))
 
 
-@attr.s()
-class OneShotRequestFactory:
-    """A simple request factory that generates a single `SynapseRequest` and
-    stores it for future use. Can only be used once.
-    """
-
-    request = attr.ib(default=None)
-
-    def __call__(self, *args, **kwargs):
-        assert self.request is None
-
-        self.request = SynapseRequest(*args, **kwargs)
-        return self.request
-
-
 class _PushHTTPChannel(HTTPChannel):
     """A HTTPChannel that wraps pull producers to push producers.
 
@@ -479,7 +458,7 @@ class _PushHTTPChannel(HTTPChannel):
     """
 
     def __init__(
-        self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
+        self, reactor: IReactorTime, request_factory: Type[Request], site: Site
     ):
         super().__init__()
         self.reactor = reactor
@@ -510,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel):
         request.responseHeaders.setRawHeaders(b"connection", [b"close"])
         return False
 
+    def requestDone(self, request):
+        # Store the request for inspection.
+        self.request = request
+        super().requestDone(request)
+
 
 class _PullToPushProducer:
     """A push producer that wraps a pull producer."""
@@ -597,6 +581,8 @@ class FakeRedisPubSubServer:
 class FakeRedisPubSubProtocol(Protocol):
     """A connection from a client talking to the fake Redis server."""
 
+    transport = None  # type: Optional[FakeTransport]
+
     def __init__(self, server: FakeRedisPubSubServer):
         self._server = server
         self._reader = hiredis.Reader()
@@ -641,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol):
 
     def send(self, msg):
         """Send a message back to the client."""
+        assert self.transport is not None
+
         raw = self.encode(msg).encode("utf-8")
 
         self.transport.write(raw)