summary refs log tree commit diff
path: root/synapse/http/site.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-09-24 11:01:25 +0100
committerGitHub <noreply@github.com>2021-09-24 11:01:25 +0100
commit50022cff966a3991fbd8a1e5c98f490d9b335442 (patch)
treeb3a86c3f0d2f8ef5ea352f5c671651e6169ab3b4 /synapse/http/site.py
parentFix AuthBlocking check when requester is appservice (#10881) (diff)
downloadsynapse-50022cff966a3991fbd8a1e5c98f490d9b335442.tar.xz
Add reactor to `SynapseRequest` and fix up types. (#10868)
Diffstat (limited to 'synapse/http/site.py')
-rw-r--r--synapse/http/site.py37
1 files changed, 24 insertions, 13 deletions
diff --git a/synapse/http/site.py b/synapse/http/site.py
index dd4c749e16..755ad56637 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,13 +14,14 @@
 import contextlib
 import logging
 import time
-from typing import Optional, Tuple, Union
+from typing import Generator, Optional, Tuple, Union
 
 import attr
 from zope.interface import implementer
 
 from twisted.internet.interfaces import IAddress, IReactorTime
 from twisted.python.failure import Failure
+from twisted.web.http import HTTPChannel
 from twisted.web.resource import IResource, Resource
 from twisted.web.server import Request, Site
 
@@ -61,10 +62,18 @@ class SynapseRequest(Request):
         logcontext: the log context for this request
     """
 
-    def __init__(self, channel, *args, max_request_body_size: int = 1024, **kw):
-        Request.__init__(self, channel, *args, **kw)
+    def __init__(
+        self,
+        channel: HTTPChannel,
+        site: "SynapseSite",
+        *args,
+        max_request_body_size: int = 1024,
+        **kw,
+    ):
+        super().__init__(channel, *args, **kw)
         self._max_request_body_size = max_request_body_size
-        self.site: SynapseSite = channel.site
+        self.synapse_site = site
+        self.reactor = site.reactor
         self._channel = channel  # this is used by the tests
         self.start_time = 0.0
 
@@ -97,7 +106,7 @@ class SynapseRequest(Request):
             self.get_method(),
             self.get_redacted_uri(),
             self.clientproto.decode("ascii", errors="replace"),
-            self.site.site_tag,
+            self.synapse_site.site_tag,
         )
 
     def handleContentChunk(self, data: bytes) -> None:
@@ -216,7 +225,7 @@ class SynapseRequest(Request):
             request=ContextRequest(
                 request_id=request_id,
                 ip_address=self.getClientIP(),
-                site_tag=self.site.site_tag,
+                site_tag=self.synapse_site.site_tag,
                 # The requester is going to be unknown at this point.
                 requester=None,
                 authenticated_entity=None,
@@ -228,7 +237,7 @@ class SynapseRequest(Request):
         )
 
         # override the Server header which is set by twisted
-        self.setHeader("Server", self.site.server_version_string)
+        self.setHeader("Server", self.synapse_site.server_version_string)
 
         with PreserveLoggingContext(self.logcontext):
             # we start the request metrics timer here with an initial stab
@@ -247,7 +256,7 @@ class SynapseRequest(Request):
             requests_counter.labels(self.get_method(), self.request_metrics.name).inc()
 
     @contextlib.contextmanager
-    def processing(self):
+    def processing(self) -> Generator[None, None, None]:
         """Record the fact that we are processing this request.
 
         Returns a context manager; the correct way to use this is:
@@ -346,10 +355,10 @@ class SynapseRequest(Request):
             self.start_time, name=servlet_name, method=self.get_method()
         )
 
-        self.site.access_logger.debug(
+        self.synapse_site.access_logger.debug(
             "%s - %s - Received request: %s %s",
             self.getClientIP(),
-            self.site.site_tag,
+            self.synapse_site.site_tag,
             self.get_method(),
             self.get_redacted_uri(),
         )
@@ -388,13 +397,13 @@ class SynapseRequest(Request):
         if authenticated_entity:
             requester = f"{authenticated_entity}|{requester}"
 
-        self.site.access_logger.log(
+        self.synapse_site.access_logger.log(
             log_level,
             "%s - %s - {%s}"
             " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
             ' %sB %s "%s %s %s" "%s" [%d dbevts]',
             self.getClientIP(),
-            self.site.site_tag,
+            self.synapse_site.site_tag,
             requester,
             processing_time,
             response_send_time,
@@ -522,7 +531,7 @@ class SynapseSite(Site):
         site_tag: str,
         config: ListenerConfig,
         resource: IResource,
-        server_version_string,
+        server_version_string: str,
         max_request_body_size: int,
         reactor: IReactorTime,
     ):
@@ -542,6 +551,7 @@ class SynapseSite(Site):
         Site.__init__(self, resource, reactor=reactor)
 
         self.site_tag = site_tag
+        self.reactor = reactor
 
         assert config.http_options is not None
         proxied = config.http_options.x_forwarded
@@ -550,6 +560,7 @@ class SynapseSite(Site):
         def request_factory(channel, queued: bool) -> Request:
             return request_class(
                 channel,
+                self,
                 max_request_body_size=max_request_body_size,
                 queued=queued,
             )