diff --git a/changelog.d/9878.misc b/changelog.d/9878.misc
new file mode 100644
index 0000000000..927876852d
--- /dev/null
+++ b/changelog.d/9878.misc
@@ -0,0 +1 @@
+Remove redundant `_PushHTTPChannel` test class.
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 5cf58d8b60..dc3519ea13 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,14 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type
+from typing import Any, Callable, Dict, List, Optional, Tuple
-from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
-from twisted.internet.task import LoopingCall
-from twisted.web.http import HTTPChannel
from twisted.web.resource import Resource
-from twisted.web.server import Request, Site
from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.server import JsonResource
@@ -33,7 +29,6 @@ from synapse.replication.tcp.resource import (
ServerReplicationStreamProtocol,
)
from synapse.server import HomeServer
-from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport
@@ -154,7 +149,19 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
client_protocol = client_factory.buildProtocol(None)
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
+ channel = self.site.buildProtocol(None)
+
+ # hook into the channel's request factory so that we can keep a record
+ # of the requests
+ requests: List[SynapseRequest] = []
+ real_request_factory = channel.requestFactory
+
+ def request_factory(*args, **kwargs):
+ request = real_request_factory(*args, **kwargs)
+ requests.append(request)
+ return request
+
+ channel.requestFactory = request_factory
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -176,7 +183,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()
- return channel.request
+ # there should have been exactly one request
+ self.assertEqual(len(requests), 1)
+
+ return requests[0]
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
@@ -387,7 +397,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
client_protocol = client_factory.buildProtocol(None)
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
+ channel = self._hs_to_site[hs].buildProtocol(None)
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -445,112 +455,6 @@ class TestReplicationDataHandler(ReplicationDataHandler):
self.received_rdata_rows.append((stream_name, token, r))
-class _PushHTTPChannel(HTTPChannel):
- """A HTTPChannel that wraps pull producers to push producers.
-
- This is a hack to get around the fact that HTTPChannel transparently wraps a
- pull producer (which is what Synapse uses to reply to requests) with
- `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
- uses the standard reactor rather than letting us use our test reactor, which
- makes it very hard to test.
- """
-
- def __init__(
- self, reactor: IReactorTime, request_factory: Type[Request], site: Site
- ):
- super().__init__()
- self.reactor = reactor
- self.requestFactory = request_factory
- self.site = site
-
- self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
-
- def registerProducer(self, producer, streaming):
- # Convert pull producers to push producer.
- if not streaming:
- self._pull_to_push_producer = _PullToPushProducer(
- self.reactor, producer, self
- )
- producer = self._pull_to_push_producer
-
- super().registerProducer(producer, True)
-
- def unregisterProducer(self):
- if self._pull_to_push_producer:
- # We need to manually stop the _PullToPushProducer.
- self._pull_to_push_producer.stop()
-
- def checkPersistence(self, request, version):
- """Check whether the connection can be re-used"""
- # We hijack this to always say no for ease of wiring stuff up in
- # `handle_http_replication_attempt`.
- request.responseHeaders.setRawHeaders(b"connection", [b"close"])
- return False
-
- def requestDone(self, request):
- # Store the request for inspection.
- self.request = request
- super().requestDone(request)
-
-
-class _PullToPushProducer:
- """A push producer that wraps a pull producer."""
-
- def __init__(
- self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
- ):
- self._clock = Clock(reactor)
- self._producer = producer
- self._consumer = consumer
-
- # While running we use a looping call with a zero delay to call
- # resumeProducing on given producer.
- self._looping_call = None # type: Optional[LoopingCall]
-
- # We start writing next reactor tick.
- self._start_loop()
-
- def _start_loop(self):
- """Start the looping call to"""
-
- if not self._looping_call:
- # Start a looping call which runs every tick.
- self._looping_call = self._clock.looping_call(self._run_once, 0)
-
- def stop(self):
- """Stops calling resumeProducing."""
- if self._looping_call:
- self._looping_call.stop()
- self._looping_call = None
-
- def pauseProducing(self):
- """Implements IPushProducer"""
- self.stop()
-
- def resumeProducing(self):
- """Implements IPushProducer"""
- self._start_loop()
-
- def stopProducing(self):
- """Implements IPushProducer"""
- self.stop()
- self._producer.stopProducing()
-
- def _run_once(self):
- """Calls resumeProducing on producer once."""
-
- try:
- self._producer.resumeProducing()
- except Exception:
- logger.exception("Failed to call resumeProducing")
- try:
- self._consumer.unregisterProducer()
- except Exception:
- pass
-
- self.stopProducing()
-
-
class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""
diff --git a/tests/server.py b/tests/server.py
index b535a5d886..9df8cda24f 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -603,12 +603,6 @@ class FakeTransport:
if self.disconnected:
return
- if not hasattr(self.other, "transport"):
- # the other has no transport yet; reschedule
- if self.autoflush:
- self._reactor.callLater(0.0, self.flush)
- return
-
if maxbytes is not None:
to_write = self.buffer[:maxbytes]
else:
|