summary refs log tree commit diff
path: root/synapse/http/site.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/site.py')
-rw-r--r--synapse/http/site.py75
1 files changed, 45 insertions, 30 deletions
diff --git a/synapse/http/site.py b/synapse/http/site.py
index c665a9d5db..755ad56637 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,14 +14,15 @@
 import contextlib
 import logging
 import time
-from typing import Optional, Tuple, Union
+from typing import Generator, Optional, Tuple, Union
 
 import attr
 from zope.interface import implementer
 
 from twisted.internet.interfaces import IAddress, IReactorTime
 from twisted.python.failure import Failure
-from twisted.web.resource import IResource
+from twisted.web.http import HTTPChannel
+from twisted.web.resource import IResource, Resource
 from twisted.web.server import Request, Site
 
 from synapse.config.server import ListenerConfig
@@ -61,10 +62,18 @@ class SynapseRequest(Request):
         logcontext: the log context for this request
     """
 
-    def __init__(self, channel, *args, max_request_body_size=1024, **kw):
-        Request.__init__(self, channel, *args, **kw)
+    def __init__(
+        self,
+        channel: HTTPChannel,
+        site: "SynapseSite",
+        *args,
+        max_request_body_size: int = 1024,
+        **kw,
+    ):
+        super().__init__(channel, *args, **kw)
         self._max_request_body_size = max_request_body_size
-        self.site: SynapseSite = channel.site
+        self.synapse_site = site
+        self.reactor = site.reactor
         self._channel = channel  # this is used by the tests
         self.start_time = 0.0
 
@@ -83,13 +92,13 @@ class SynapseRequest(Request):
         self._is_processing = False
 
         # the time when the asynchronous request handler completed its processing
-        self._processing_finished_time = None
+        self._processing_finished_time: Optional[float] = None
 
         # what time we finished sending the response to the client (or the connection
         # dropped)
-        self.finish_time = None
+        self.finish_time: Optional[float] = None
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         # We overwrite this so that we don't log ``access_token``
         return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
             self.__class__.__name__,
@@ -97,10 +106,10 @@ class SynapseRequest(Request):
             self.get_method(),
             self.get_redacted_uri(),
             self.clientproto.decode("ascii", errors="replace"),
-            self.site.site_tag,
+            self.synapse_site.site_tag,
         )
 
-    def handleContentChunk(self, data):
+    def handleContentChunk(self, data: bytes) -> None:
         # we should have a `content` by now.
         assert self.content, "handleContentChunk() called before gotLength()"
         if self.content.tell() + len(data) > self._max_request_body_size:
@@ -139,7 +148,7 @@ class SynapseRequest(Request):
         # If there's no authenticated entity, it was the requester.
         self.logcontext.request.authenticated_entity = authenticated_entity or requester
 
-    def get_request_id(self):
+    def get_request_id(self) -> str:
         return "%s-%i" % (self.get_method(), self.request_seq)
 
     def get_redacted_uri(self) -> str:
@@ -205,7 +214,7 @@ class SynapseRequest(Request):
 
         return None, None
 
-    def render(self, resrc):
+    def render(self, resrc: Resource) -> None:
         # this is called once a Resource has been found to serve the request; in our
         # case the Resource in question will normally be a JsonResource.
 
@@ -216,7 +225,7 @@ class SynapseRequest(Request):
             request=ContextRequest(
                 request_id=request_id,
                 ip_address=self.getClientIP(),
-                site_tag=self.site.site_tag,
+                site_tag=self.synapse_site.site_tag,
                 # The requester is going to be unknown at this point.
                 requester=None,
                 authenticated_entity=None,
@@ -228,7 +237,7 @@ class SynapseRequest(Request):
         )
 
         # override the Server header which is set by twisted
-        self.setHeader("Server", self.site.server_version_string)
+        self.setHeader("Server", self.synapse_site.server_version_string)
 
         with PreserveLoggingContext(self.logcontext):
             # we start the request metrics timer here with an initial stab
@@ -247,7 +256,7 @@ class SynapseRequest(Request):
             requests_counter.labels(self.get_method(), self.request_metrics.name).inc()
 
     @contextlib.contextmanager
-    def processing(self):
+    def processing(self) -> Generator[None, None, None]:
         """Record the fact that we are processing this request.
 
         Returns a context manager; the correct way to use this is:
@@ -282,7 +291,7 @@ class SynapseRequest(Request):
             if self.finish_time is not None:
                 self._finished_processing()
 
-    def finish(self):
+    def finish(self) -> None:
         """Called when all response data has been written to this Request.
 
         Overrides twisted.web.server.Request.finish to record the finish time and do
@@ -295,7 +304,7 @@ class SynapseRequest(Request):
             with PreserveLoggingContext(self.logcontext):
                 self._finished_processing()
 
-    def connectionLost(self, reason):
+    def connectionLost(self, reason: Union[Failure, Exception]) -> None:
         """Called when the client connection is closed before the response is written.
 
         Overrides twisted.web.server.Request.connectionLost to record the finish time and
@@ -327,7 +336,7 @@ class SynapseRequest(Request):
             if not self._is_processing:
                 self._finished_processing()
 
-    def _started_processing(self, servlet_name):
+    def _started_processing(self, servlet_name: str) -> None:
         """Record the fact that we are processing this request.
 
         This will log the request's arrival. Once the request completes,
@@ -346,17 +355,19 @@ class SynapseRequest(Request):
             self.start_time, name=servlet_name, method=self.get_method()
         )
 
-        self.site.access_logger.debug(
+        self.synapse_site.access_logger.debug(
             "%s - %s - Received request: %s %s",
             self.getClientIP(),
-            self.site.site_tag,
+            self.synapse_site.site_tag,
             self.get_method(),
             self.get_redacted_uri(),
         )
 
-    def _finished_processing(self):
+    def _finished_processing(self) -> None:
         """Log the completion of this request and update the metrics"""
         assert self.logcontext is not None
+        assert self.finish_time is not None
+
         usage = self.logcontext.get_resource_usage()
 
         if self._processing_finished_time is None:
@@ -386,13 +397,13 @@ class SynapseRequest(Request):
         if authenticated_entity:
             requester = f"{authenticated_entity}|{requester}"
 
-        self.site.access_logger.log(
+        self.synapse_site.access_logger.log(
             log_level,
             "%s - %s - {%s}"
             " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
             ' %sB %s "%s %s %s" "%s" [%d dbevts]',
             self.getClientIP(),
-            self.site.site_tag,
+            self.synapse_site.site_tag,
             requester,
             processing_time,
             response_send_time,
@@ -437,7 +448,7 @@ class XForwardedForRequest(SynapseRequest):
     _forwarded_for: "Optional[_XForwardedForAddress]" = None
     _forwarded_https: bool = False
 
-    def requestReceived(self, command, path, version):
+    def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None:
         # this method is called by the Channel once the full request has been
         # received, to dispatch the request to a resource.
         # We can use it to set the IP address and protocol according to the
@@ -445,7 +456,7 @@ class XForwardedForRequest(SynapseRequest):
         self._process_forwarded_headers()
         return super().requestReceived(command, path, version)
 
-    def _process_forwarded_headers(self):
+    def _process_forwarded_headers(self) -> None:
         headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
         if not headers:
             return
@@ -470,7 +481,7 @@ class XForwardedForRequest(SynapseRequest):
             )
             self._forwarded_https = True
 
-    def isSecure(self):
+    def isSecure(self) -> bool:
         if self._forwarded_https:
             return True
         return super().isSecure()
@@ -520,7 +531,7 @@ class SynapseSite(Site):
         site_tag: str,
         config: ListenerConfig,
         resource: IResource,
-        server_version_string,
+        server_version_string: str,
         max_request_body_size: int,
         reactor: IReactorTime,
     ):
@@ -540,19 +551,23 @@ class SynapseSite(Site):
         Site.__init__(self, resource, reactor=reactor)
 
         self.site_tag = site_tag
+        self.reactor = reactor
 
         assert config.http_options is not None
         proxied = config.http_options.x_forwarded
         request_class = XForwardedForRequest if proxied else SynapseRequest
 
-        def request_factory(channel, queued) -> Request:
+        def request_factory(channel, queued: bool) -> Request:
             return request_class(
-                channel, max_request_body_size=max_request_body_size, queued=queued
+                channel,
+                self,
+                max_request_body_size=max_request_body_size,
+                queued=queued,
             )
 
         self.requestFactory = request_factory  # type: ignore
         self.access_logger = logging.getLogger(logger_name)
         self.server_version_string = server_version_string.encode("ascii")
 
-    def log(self, request):
+    def log(self, request: SynapseRequest) -> None:
         pass