Diffstat (limited to 'tests/replication/_base.py')
-rw-r--r--  tests/replication/_base.py  136
1 file changed, 21 insertions, 115 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index c9d04aef29..624bd1b927 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,14 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
-from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.protocol import Protocol
-from twisted.internet.task import LoopingCall
-from twisted.web.http import HTTPChannel
 from twisted.web.resource import Resource
-from twisted.web.server import Request, Site
 
 from synapse.app.generic_worker import GenericWorkerServer
 from synapse.http.server import JsonResource
@@ -33,7 +29,6 @@ from synapse.replication.tcp.resource import (
     ServerReplicationStreamProtocol,
 )
 from synapse.server import HomeServer
-from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeTransport
@@ -154,7 +149,19 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         client_protocol = client_factory.buildProtocol(None)
 
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
+        channel = self.site.buildProtocol(None)
+
+        # hook into the channel's request factory so that we can keep a record
+        # of the requests
+        requests: List[SynapseRequest] = []
+        real_request_factory = channel.requestFactory
+
+        def request_factory(*args, **kwargs):
+            request = real_request_factory(*args, **kwargs)
+            requests.append(request)
+            return request
+
+        channel.requestFactory = request_factory
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -176,7 +183,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         server_to_client_transport.loseConnection()
         client_to_server_transport.loseConnection()
 
-        return channel.request
+        # there should have been exactly one request
+        self.assertEqual(len(requests), 1)
+
+        return requests[0]
 
     def assert_request_is_get_repl_stream_updates(
         self, request: SynapseRequest, stream_name: str
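
The requestFactory hook added in the two hunks above is plain attribute wrapping on the channel that twisted.web.server.Site builds. As a rough standalone sketch of the same pattern outside Synapse (RootResource is an illustrative stand-in, not something from this file; Site, buildProtocol and requestFactory are the real Twisted APIs):

    from typing import List

    from twisted.web.resource import Resource
    from twisted.web.server import Request, Site


    class RootResource(Resource):
        isLeaf = True

        def render_GET(self, request):
            return b"ok"


    site = Site(RootResource())

    # Site.buildProtocol() returns the server-side HTTP protocol instance and
    # points its requestFactory at the Site's request class.
    channel = site.buildProtocol(None)

    requests: List[Request] = []
    real_request_factory = channel.requestFactory


    def recording_request_factory(*args, **kwargs):
        # Delegate to the real factory, but remember every request the channel
        # creates so it can be inspected later.
        request = real_request_factory(*args, **kwargs)
        requests.append(request)
        return request


    channel.requestFactory = recording_request_factory

In the test itself the same wrapping is applied to the channel built from self.site, and after the fake transports are torn down the helper asserts that exactly one request was recorded and returns it.
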
@@ -349,6 +359,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             config=worker_hs.config.server.listeners[0],
             resource=resource,
             server_version_string="1",
+            max_request_body_size=4096,
+            reactor=self.reactor,
         )
 
         if worker_hs.config.redis.redis_enabled:
@@ -386,7 +398,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         client_protocol = client_factory.buildProtocol(None)
 
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
+        channel = self._hs_to_site[hs].buildProtocol(None)
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -444,112 +456,6 @@ class TestReplicationDataHandler(ReplicationDataHandler):
             self.received_rdata_rows.append((stream_name, token, r))
 
 
-class _PushHTTPChannel(HTTPChannel):
-    """A HTTPChannel that wraps pull producers to push producers.
-
-    This is a hack to get around the fact that HTTPChannel transparently wraps a
-    pull producer (which is what Synapse uses to reply to requests) with
-    `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
-    uses the standard reactor rather than letting us use our test reactor, which
-    makes it very hard to test.
-    """
-
-    def __init__(
-        self, reactor: IReactorTime, request_factory: Type[Request], site: Site
-    ):
-        super().__init__()
-        self.reactor = reactor
-        self.requestFactory = request_factory
-        self.site = site
-
-        self._pull_to_push_producer = None  # type: Optional[_PullToPushProducer]
-
-    def registerProducer(self, producer, streaming):
-        # Convert pull producers to push producer.
-        if not streaming:
-            self._pull_to_push_producer = _PullToPushProducer(
-                self.reactor, producer, self
-            )
-            producer = self._pull_to_push_producer
-
-        super().registerProducer(producer, True)
-
-    def unregisterProducer(self):
-        if self._pull_to_push_producer:
-            # We need to manually stop the _PullToPushProducer.
-            self._pull_to_push_producer.stop()
-
-    def checkPersistence(self, request, version):
-        """Check whether the connection can be re-used"""
-        # We hijack this to always say no for ease of wiring stuff up in
-        # `handle_http_replication_attempt`.
-        request.responseHeaders.setRawHeaders(b"connection", [b"close"])
-        return False
-
-    def requestDone(self, request):
-        # Store the request for inspection.
-        self.request = request
-        super().requestDone(request)
-
-
-class _PullToPushProducer:
-    """A push producer that wraps a pull producer."""
-
-    def __init__(
-        self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
-    ):
-        self._clock = Clock(reactor)
-        self._producer = producer
-        self._consumer = consumer
-
-        # While running we use a looping call with a zero delay to call
-        # resumeProducing on given producer.
-        self._looping_call = None  # type: Optional[LoopingCall]
-
-        # We start writing next reactor tick.
-        self._start_loop()
-
-    def _start_loop(self):
-        """Start the looping call to"""
-
-        if not self._looping_call:
-            # Start a looping call which runs every tick.
-            self._looping_call = self._clock.looping_call(self._run_once, 0)
-
-    def stop(self):
-        """Stops calling resumeProducing."""
-        if self._looping_call:
-            self._looping_call.stop()
-            self._looping_call = None
-
-    def pauseProducing(self):
-        """Implements IPushProducer"""
-        self.stop()
-
-    def resumeProducing(self):
-        """Implements IPushProducer"""
-        self._start_loop()
-
-    def stopProducing(self):
-        """Implements IPushProducer"""
-        self.stop()
-        self._producer.stopProducing()
-
-    def _run_once(self):
-        """Calls resumeProducing on producer once."""
-
-        try:
-            self._producer.resumeProducing()
-        except Exception:
-            logger.exception("Failed to call resumeProducing")
-            try:
-                self._consumer.unregisterProducer()
-            except Exception:
-                pass
-
-            self.stopProducing()
-
-
 class FakeRedisPubSubServer:
     """A fake Redis server for pub/sub."""