summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/http/site.py37
-rw-r--r--tests/logging/test_terse_json.py2
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py4
3 files changed, 26 insertions, 17 deletions
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 4c1bf5d90e..55fe6aa3bd 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,13 +14,14 @@
 import contextlib
 import logging
 import time
-from typing import Optional, Tuple, Union
+from typing import Generator, Optional, Tuple, Union
 
 import attr
 from zope.interface import implementer
 
 from twisted.internet.interfaces import IAddress, IReactorTime
 from twisted.python.failure import Failure
+from twisted.web.http import HTTPChannel
 from twisted.web.resource import IResource, Resource
 from twisted.web.server import Request, Site
 
@@ -61,10 +62,17 @@ class SynapseRequest(Request):
         logcontext: the log context for this request
     """
 
-    def __init__(self, channel, *args, max_request_body_size: int = 1024, **kw):
-        Request.__init__(self, channel, *args, **kw)
+    def __init__(
+        self,
+        channel: HTTPChannel,
+        site: "SynapseSite",
+        *args,
+        max_request_body_size: int = 1024,
+        **kw,
+    ):
+        super().__init__(channel, *args, **kw)
         self._max_request_body_size = max_request_body_size
-        self.site: SynapseSite = channel.site
+        self.synapse_site = site
         self._channel = channel  # this is used by the tests
         self.start_time = 0.0
 
@@ -97,7 +105,7 @@ class SynapseRequest(Request):
             self.get_method(),
             self.get_redacted_uri(),
             self.clientproto.decode("ascii", errors="replace"),
-            self.site.site_tag,
+            self.synapse_site.site_tag,
         )
 
     def handleContentChunk(self, data: bytes) -> None:
@@ -216,7 +224,7 @@ class SynapseRequest(Request):
             request=ContextRequest(
                 request_id=request_id,
                 ip_address=self.getClientIP(),
-                site_tag=self.site.site_tag,
+                site_tag=self.synapse_site.site_tag,
                 # The requester is going to be unknown at this point.
                 requester=None,
                 authenticated_entity=None,
@@ -228,7 +236,7 @@ class SynapseRequest(Request):
         )
 
         # override the Server header which is set by twisted
-        self.setHeader("Server", self.site.server_version_string)
+        self.setHeader("Server", self.synapse_site.server_version_string)
 
         with PreserveLoggingContext(self.logcontext):
             # we start the request metrics timer here with an initial stab
@@ -247,7 +255,7 @@ class SynapseRequest(Request):
             requests_counter.labels(self.get_method(), self.request_metrics.name).inc()
 
     @contextlib.contextmanager
-    def processing(self):
+    def processing(self) -> Generator[None, None, None]:
         """Record the fact that we are processing this request.
 
         Returns a context manager; the correct way to use this is:
@@ -346,10 +354,10 @@ class SynapseRequest(Request):
             self.start_time, name=servlet_name, method=self.get_method()
         )
 
-        self.site.access_logger.debug(
+        self.synapse_site.access_logger.debug(
             "%s - %s - Received request: %s %s",
             self.getClientIP(),
-            self.site.site_tag,
+            self.synapse_site.site_tag,
             self.get_method(),
             self.get_redacted_uri(),
         )
@@ -388,13 +396,13 @@ class SynapseRequest(Request):
         if authenticated_entity:
             requester = f"{authenticated_entity}|{requester}"
 
-        self.site.access_logger.log(
+        self.synapse_site.access_logger.log(
             log_level,
             "%s - %s - {%s}"
             " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
             ' %sB %s "%s %s %s" "%s" [%d dbevts]',
             self.getClientIP(),
-            self.site.site_tag,
+            self.synapse_site.site_tag,
             requester,
             processing_time,
             response_send_time,
@@ -522,7 +530,7 @@ class SynapseSite(Site):
         site_tag: str,
         config: ListenerConfig,
         resource: IResource,
-        server_version_string,
+        server_version_string: str,
         max_request_body_size: int,
         reactor: IReactorTime,
     ):
@@ -547,9 +555,10 @@ class SynapseSite(Site):
         proxied = config.http_options.x_forwarded
         request_class = XForwardedForRequest if proxied else SynapseRequest
 
-        def request_factory(channel, queued) -> Request:
+        def request_factory(channel: HTTPChannel, queued: bool) -> Request:
             return request_class(
                 channel,
+                self,
                 max_request_body_size=max_request_body_size,
                 queued=queued,
             )
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 1160716929..b93d69e86a 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -152,7 +152,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"])
         site.site_tag = "test-site"
         site.server_version_string = "Server v1"
-        request = SynapseRequest(FakeChannel(site, None))
+        request = SynapseRequest(FakeChannel(site, None), site)
         # Call requestReceived to finish instantiating the object.
         request.content = BytesIO()
         # Partially skip some of the internal processing of SynapseRequest.
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index a75c0ea3f0..4672a68596 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -84,7 +84,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
         Checks that the response is a 200 and returns the decoded json body.
         """
         channel = FakeChannel(self.site, self.reactor)
-        req = SynapseRequest(channel)
+        req = SynapseRequest(channel, self.site)
         req.content = BytesIO(b"")
         req.requestReceived(
             b"GET",
@@ -183,7 +183,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
             )
 
             channel = FakeChannel(self.site, self.reactor)
-            req = SynapseRequest(channel)
+            req = SynapseRequest(channel, self.site)
             req.content = BytesIO(encode_canonical_json(data))
 
             req.requestReceived(