diff --git a/synapse/http/site.py b/synapse/http/site.py
index 32b5e19c09..671fd3fbcc 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,13 +14,14 @@
import contextlib
import logging
import time
-from typing import Optional, Tuple, Type, Union
+from typing import Optional, Tuple, Union
import attr
from zope.interface import implementer
-from twisted.internet.interfaces import IAddress
+from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.python.failure import Failure
+from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.server import ListenerConfig
@@ -49,6 +50,7 @@ class SynapseRequest(Request):
* Redaction of access_token query-params in __repr__
* Logging at start and end
* Metrics to record CPU, wallclock and DB time by endpoint.
+ * A limit to the size of request which will be accepted
It also provides a method `processing`, which returns a context manager. If this
method is called, the request won't be logged until the context manager is closed;
@@ -59,8 +61,9 @@ class SynapseRequest(Request):
logcontext: the log context for this request
"""
- def __init__(self, channel, *args, **kw):
+ def __init__(self, channel, *args, max_request_body_size=1024, **kw):
Request.__init__(self, channel, *args, **kw)
+ self._max_request_body_size = max_request_body_size
self.site = channel.site # type: SynapseSite
self._channel = channel # this is used by the tests
self.start_time = 0.0
@@ -97,6 +100,18 @@ class SynapseRequest(Request):
self.site.site_tag,
)
+ def handleContentChunk(self, data):
+ # we should have a `content` by now.
+ assert self.content, "handleContentChunk() called before gotLength()"
+ if self.content.tell() + len(data) > self._max_request_body_size:
+ logger.warning(
+ "Aborting connection from %s because the request exceeds maximum size",
+ self.client,
+ )
+ self.transport.abortConnection()
+ return
+ super().handleContentChunk(data)
+
@property
def requester(self) -> Optional[Union[Requester, str]]:
return self._requester
@@ -485,29 +500,55 @@ class _XForwardedForAddress:
class SynapseSite(Site):
"""
- Subclass of a twisted http Site that does access logging with python's
- standard logging
+ Synapse-specific twisted http Site
+
+ This does two main things.
+
+ First, it replaces the requestFactory in use so that we build SynapseRequests
+ instead of regular t.w.server.Requests. All of the constructor params are really
+ just parameters for SynapseRequest.
+
+ Second, it inhibits the log() method called by Request.finish, since SynapseRequest
+ does its own logging.
"""
def __init__(
self,
- logger_name,
- site_tag,
+ logger_name: str,
+ site_tag: str,
config: ListenerConfig,
- resource,
+ resource: IResource,
server_version_string,
- *args,
- **kwargs,
+ max_request_body_size: int,
+ reactor: IReactorTime,
):
- Site.__init__(self, resource, *args, **kwargs)
+ """
+
+ Args:
+ logger_name: The name of the logger to use for access logs.
+ site_tag: A tag to use for this site - mostly in access logs.
+ config: Configuration for the HTTP listener corresponding to this site
+ resource: The base of the resource tree to be used for serving requests on
+ this site
+ server_version_string: A string to present for the Server header
+ max_request_body_size: Maximum request body length to allow before
+ dropping the connection
+ reactor: reactor to be used to manage connection timeouts
+ """
+ Site.__init__(self, resource, reactor=reactor)
self.site_tag = site_tag
assert config.http_options is not None
proxied = config.http_options.x_forwarded
- self.requestFactory = (
- XForwardedForRequest if proxied else SynapseRequest
- ) # type: Type[Request]
+ request_class = XForwardedForRequest if proxied else SynapseRequest
+
+ def request_factory(channel, queued) -> Request:
+ return request_class(
+ channel, max_request_body_size=max_request_body_size, queued=queued
+ )
+
+ self.requestFactory = request_factory # type: ignore
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")
|