summary refs log tree commit diff
path: root/synapse/http/site.py
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-04-23 19:20:44 +0100
committerGitHub <noreply@github.com>2021-04-23 19:20:44 +0100
commit3ff225175462dde8376aa584e3a47c43b1f0e790 (patch)
tree550b35fb107b0af201c0e409818d9630d8d3a599 /synapse/http/site.py
parentKill off `_PushHTTPChannel`. (#9878) (diff)
downloadsynapse-3ff225175462dde8376aa584e3a47c43b1f0e790.tar.xz
Improved validation for received requests (#9817)
* Simplify `start_listening` callpath

* Correctly check the size of uploaded files
Diffstat (limited to 'synapse/http/site.py')
-rw-r--r--synapse/http/site.py32
1 files changed, 27 insertions, 5 deletions
diff --git a/synapse/http/site.py b/synapse/http/site.py
index e911ee4809..671fd3fbcc 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
 import contextlib
 import logging
 import time
-from typing import Optional, Tuple, Type, Union
+from typing import Optional, Tuple, Union
 
 import attr
 from zope.interface import implementer
@@ -50,6 +50,7 @@ class SynapseRequest(Request):
      * Redaction of access_token query-params in __repr__
      * Logging at start and end
      * Metrics to record CPU, wallclock and DB time by endpoint.
+     * A limit to the size of request which will be accepted
 
     It also provides a method `processing`, which returns a context manager. If this
     method is called, the request won't be logged until the context manager is closed;
@@ -60,8 +61,9 @@ class SynapseRequest(Request):
         logcontext: the log context for this request
     """
 
-    def __init__(self, channel, *args, **kw):
+    def __init__(self, channel, *args, max_request_body_size=1024, **kw):
         Request.__init__(self, channel, *args, **kw)
+        self._max_request_body_size = max_request_body_size
         self.site = channel.site  # type: SynapseSite
         self._channel = channel  # this is used by the tests
         self.start_time = 0.0
@@ -98,6 +100,18 @@ class SynapseRequest(Request):
             self.site.site_tag,
         )
 
+    def handleContentChunk(self, data):
+        # we should have a `content` by now.
+        assert self.content, "handleContentChunk() called before gotLength()"
+        if self.content.tell() + len(data) > self._max_request_body_size:
+            logger.warning(
+                "Aborting connection from %s because the request exceeds maximum size",
+                self.client,
+            )
+            self.transport.abortConnection()
+            return
+        super().handleContentChunk(data)
+
     @property
     def requester(self) -> Optional[Union[Requester, str]]:
         return self._requester
@@ -505,6 +519,7 @@ class SynapseSite(Site):
         config: ListenerConfig,
         resource: IResource,
         server_version_string,
+        max_request_body_size: int,
         reactor: IReactorTime,
     ):
         """
@@ -516,6 +531,8 @@ class SynapseSite(Site):
             resource:  The base of the resource tree to be used for serving requests on
                 this site
             server_version_string: A string to present for the Server header
+            max_request_body_size: Maximum request body length to allow before
+                dropping the connection
             reactor: reactor to be used to manage connection timeouts
         """
         Site.__init__(self, resource, reactor=reactor)
@@ -524,9 +541,14 @@ class SynapseSite(Site):
 
         assert config.http_options is not None
         proxied = config.http_options.x_forwarded
-        self.requestFactory = (
-            XForwardedForRequest if proxied else SynapseRequest
-        )  # type: Type[Request]
+        request_class = XForwardedForRequest if proxied else SynapseRequest
+
+        def request_factory(channel, queued) -> Request:
+            return request_class(
+                channel, max_request_body_size=max_request_body_size, queued=queued
+            )
+
+        self.requestFactory = request_factory  # type: ignore
         self.access_logger = logging.getLogger(logger_name)
         self.server_version_string = server_version_string.encode("ascii")