diff --git a/tests/http/test_site.py b/tests/http/test_site.py
new file mode 100644
index 0000000000..8c13b4f693
--- /dev/null
+++ b/tests/http/test_site.py
@@ -0,0 +1,83 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet.address import IPv6Address
+from twisted.test.proto_helpers import StringTransport
+
+from synapse.app.homeserver import SynapseHomeServer
+
+from tests.unittest import HomeserverTestCase
+
+
+class SynapseRequestTestCase(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer)
+
+ def test_large_request(self):
+ """overlarge HTTP requests should be rejected"""
+ self.hs.start_listening()
+
+ # find the HTTP server which is configured to listen on port 0
+ (port, factory, _backlog, interface) = self.reactor.tcpServers[0]
+ self.assertEqual(interface, "::")
+ self.assertEqual(port, 0)
+
+ # as a control case, first send a regular request.
+
+ # complete the connection and wire it up to a fake transport
+        client_address = IPv6Address("TCP", "::1", 2345)
+ protocol = factory.buildProtocol(client_address)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
+ protocol.dataReceived(
+ b"POST / HTTP/1.1\r\n"
+ b"Connection: close\r\n"
+ b"Transfer-Encoding: chunked\r\n"
+ b"\r\n"
+ b"0\r\n"
+ b"\r\n"
+ )
+
+ while not transport.disconnecting:
+ self.reactor.advance(1)
+
+ # we should get a 404
+ self.assertRegex(transport.value().decode(), r"^HTTP/1\.1 404 ")
+
+ # now send an oversized request
+ protocol = factory.buildProtocol(client_address)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
+ protocol.dataReceived(
+ b"POST / HTTP/1.1\r\n"
+ b"Connection: close\r\n"
+ b"Transfer-Encoding: chunked\r\n"
+ b"\r\n"
+ )
+
+        # we deliberately declare all of the data as one enormous chunk, to check
+        # that twisted isn't buffering it all up in the chunked-transfer decoder.
+        # we start with the chunk size, in hex. (We won't actually send this much)
+ protocol.dataReceived(b"10000000\r\n")
+ sent = 0
+ while not transport.disconnected:
+ self.assertLess(sent, 0x10000000, "connection did not drop")
+ protocol.dataReceived(b"\0" * 1024)
+ sent += 1024
+
+ # default max upload size is 50M, so it should drop on the next buffer after
+ # that.
+ self.assertEqual(sent, 50 * 1024 * 1024 + 1024)
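A note on the harness used by this new test (not part of the patch): it never opens a socket. It asks the real HTTP factory for a protocol, attaches it to a StringTransport, and feeds raw bytes in, which is a standard Twisted testing pattern. The sketch below shows the same pattern against a plain twisted.web Site; _Hello and drive_request are illustrative names, not Synapse code, and it assumes a reasonably recent Twisted in which Site accepts a reactor argument.

from twisted.test.proto_helpers import MemoryReactorClock, StringTransport
from twisted.web.resource import Resource
from twisted.web.server import Site


class _Hello(Resource):
    """A trivial resource so the request has something to render."""

    isLeaf = True

    def render_GET(self, request):
        return b"hello"


def drive_request(raw_request: bytes) -> bytes:
    """Feed raw HTTP bytes to a Site without opening a real socket."""
    site = Site(_Hello(), reactor=MemoryReactorClock())
    channel = site.buildProtocol(None)  # a stock twisted HTTP channel
    transport = StringTransport()
    channel.makeConnection(transport)
    channel.dataReceived(raw_request)
    return transport.value()  # every byte the server "wrote" back


print(drive_request(b"GET / HTTP/1.1\r\nHost: example\r\nConnection: close\r\n\r\n"))

Because everything happens synchronously on the fake transport, the oversized-request loop above can count exactly how many bytes are accepted before the channel drops the connection, which is where the expected value of 50 * 1024 * 1024 + 1024 comes from.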
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index c9d04aef29..624bd1b927 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,14 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type
+from typing import Any, Callable, Dict, List, Optional, Tuple
-from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
-from twisted.internet.task import LoopingCall
-from twisted.web.http import HTTPChannel
from twisted.web.resource import Resource
-from twisted.web.server import Request, Site
from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.server import JsonResource
@@ -33,7 +29,6 @@ from synapse.replication.tcp.resource import (
ServerReplicationStreamProtocol,
)
from synapse.server import HomeServer
-from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport
@@ -154,7 +149,19 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
client_protocol = client_factory.buildProtocol(None)
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
+ channel = self.site.buildProtocol(None)
+
+ # hook into the channel's request factory so that we can keep a record
+ # of the requests
+ requests: List[SynapseRequest] = []
+ real_request_factory = channel.requestFactory
+
+ def request_factory(*args, **kwargs):
+ request = real_request_factory(*args, **kwargs)
+ requests.append(request)
+ return request
+
+ channel.requestFactory = request_factory
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -176,7 +183,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()
- return channel.request
+ # there should have been exactly one request
+ self.assertEqual(len(requests), 1)
+
+ return requests[0]
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
@@ -349,6 +359,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config=worker_hs.config.server.listeners[0],
resource=resource,
server_version_string="1",
+ max_request_body_size=4096,
+ reactor=self.reactor,
)
if worker_hs.config.redis.redis_enabled:
@@ -386,7 +398,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
client_protocol = client_factory.buildProtocol(None)
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
+ channel = self._hs_to_site[hs].buildProtocol(None)
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -444,112 +456,6 @@ class TestReplicationDataHandler(ReplicationDataHandler):
self.received_rdata_rows.append((stream_name, token, r))
-class _PushHTTPChannel(HTTPChannel):
- """A HTTPChannel that wraps pull producers to push producers.
-
- This is a hack to get around the fact that HTTPChannel transparently wraps a
- pull producer (which is what Synapse uses to reply to requests) with
- `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
- uses the standard reactor rather than letting us use our test reactor, which
- makes it very hard to test.
- """
-
- def __init__(
- self, reactor: IReactorTime, request_factory: Type[Request], site: Site
- ):
- super().__init__()
- self.reactor = reactor
- self.requestFactory = request_factory
- self.site = site
-
- self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
-
- def registerProducer(self, producer, streaming):
- # Convert pull producers to push producer.
- if not streaming:
- self._pull_to_push_producer = _PullToPushProducer(
- self.reactor, producer, self
- )
- producer = self._pull_to_push_producer
-
- super().registerProducer(producer, True)
-
- def unregisterProducer(self):
- if self._pull_to_push_producer:
- # We need to manually stop the _PullToPushProducer.
- self._pull_to_push_producer.stop()
-
- def checkPersistence(self, request, version):
- """Check whether the connection can be re-used"""
- # We hijack this to always say no for ease of wiring stuff up in
- # `handle_http_replication_attempt`.
- request.responseHeaders.setRawHeaders(b"connection", [b"close"])
- return False
-
- def requestDone(self, request):
- # Store the request for inspection.
- self.request = request
- super().requestDone(request)
-
-
-class _PullToPushProducer:
- """A push producer that wraps a pull producer."""
-
- def __init__(
- self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
- ):
- self._clock = Clock(reactor)
- self._producer = producer
- self._consumer = consumer
-
- # While running we use a looping call with a zero delay to call
- # resumeProducing on given producer.
- self._looping_call = None # type: Optional[LoopingCall]
-
- # We start writing next reactor tick.
- self._start_loop()
-
- def _start_loop(self):
- """Start the looping call to"""
-
- if not self._looping_call:
- # Start a looping call which runs every tick.
- self._looping_call = self._clock.looping_call(self._run_once, 0)
-
- def stop(self):
- """Stops calling resumeProducing."""
- if self._looping_call:
- self._looping_call.stop()
- self._looping_call = None
-
- def pauseProducing(self):
- """Implements IPushProducer"""
- self.stop()
-
- def resumeProducing(self):
- """Implements IPushProducer"""
- self._start_loop()
-
- def stopProducing(self):
- """Implements IPushProducer"""
- self.stop()
- self._producer.stopProducing()
-
- def _run_once(self):
- """Calls resumeProducing on producer once."""
-
- try:
- self._producer.resumeProducing()
- except Exception:
- logger.exception("Failed to call resumeProducing")
- try:
- self._consumer.unregisterProducer()
- except Exception:
- pass
-
- self.stopProducing()
-
-
class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""
diff --git a/tests/server.py b/tests/server.py
index b535a5d886..9df8cda24f 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -603,12 +603,6 @@ class FakeTransport:
if self.disconnected:
return
- if not hasattr(self.other, "transport"):
- # the other has no transport yet; reschedule
- if self.autoflush:
- self._reactor.callLater(0.0, self.flush)
- return
-
if maxbytes is not None:
to_write = self.buffer[:maxbytes]
else:
diff --git a/tests/test_server.py b/tests/test_server.py
index 55cde7f62f..407e172e41 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -202,6 +202,8 @@ class OptionsResourceTests(unittest.TestCase):
parse_listener_def({"type": "http", "port": 0}),
self.resource,
"1.0",
+ max_request_body_size=1234,
+ reactor=self.reactor,
)
# render the request and return the channel
diff --git a/tests/unittest.py b/tests/unittest.py
index ee22a53849..9bd02bd9c4 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -247,6 +247,8 @@ class HomeserverTestCase(TestCase):
config=self.hs.config.server.listeners[0],
resource=self.resource,
server_version_string="1",
+ max_request_body_size=1234,
+ reactor=self.reactor,
)
from tests.rest.client.v1.utils import RestHelper
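One concrete effect of the reactor argument now passed to SynapseSite in these harnesses (not part of the patch, and sketched here against a plain twisted.web Site rather than SynapseSite): Twisted's HTTPFactory schedules its idle-connection timeout on the reactor it was built with, so handing the site the test reactor keeps that scheduling under the test's control. A minimal sketch, assuming a Twisted version whose Site/HTTPFactory accept a reactor argument:

from twisted.test.proto_helpers import MemoryReactorClock, StringTransport
from twisted.web.resource import Resource
from twisted.web.server import Site

clock = MemoryReactorClock()
site = Site(Resource(), reactor=clock)
site.timeOut = 5  # drop idle connections after five (simulated) seconds

channel = site.buildProtocol(None)
transport = StringTransport()
channel.makeConnection(transport)

clock.advance(10)  # no real sleeping: only the test clock moves
assert transport.disconnecting, "idle timeout fired on the reactor we control"

With the default global reactor instead, that callLater would land outside the MemoryReactorClock, and a test like the ones above could neither see it nor advance past it.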