summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9874.misc1
-rw-r--r--synapse/app/generic_worker.py1
-rw-r--r--synapse/app/homeserver.py25
-rw-r--r--synapse/http/site.py37
-rw-r--r--tests/replication/_base.py1
-rw-r--r--tests/test_server.py1
-rw-r--r--tests/unittest.py1
7 files changed, 43 insertions, 24 deletions
diff --git a/changelog.d/9874.misc b/changelog.d/9874.misc
new file mode 100644
index 0000000000..ba1097e65e
--- /dev/null
+++ b/changelog.d/9874.misc
@@ -0,0 +1 @@
+Pass a reactor into `SynapseSite` to make testing easier.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 7b2ac3ca64..70e07d0574 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -367,6 +367,7 @@ class GenericWorkerServer(HomeServer):
                 listener_config,
                 root_resource,
                 self.version_string,
+                reactor=self.get_reactor(),
             ),
             reactor=self.get_reactor(),
         )
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 8be8b520eb..140f6bcdee 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -126,19 +126,20 @@ class SynapseHomeServer(HomeServer):
         else:
             root_resource = OptionsResource()
 
-        root_resource = create_resource_tree(resources, root_resource)
+        site = SynapseSite(
+            "synapse.access.%s.%s" % ("https" if tls else "http", site_tag),
+            site_tag,
+            listener_config,
+            create_resource_tree(resources, root_resource),
+            self.version_string,
+            reactor=self.get_reactor(),
+        )
 
         if tls:
             ports = listen_ssl(
                 bind_addresses,
                 port,
-                SynapseSite(
-                    "synapse.access.https.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                    self.version_string,
-                ),
+                site,
                 self.tls_server_context_factory,
                 reactor=self.get_reactor(),
             )
@@ -148,13 +149,7 @@ class SynapseHomeServer(HomeServer):
             ports = listen_tcp(
                 bind_addresses,
                 port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                    self.version_string,
-                ),
+                site,
                 reactor=self.get_reactor(),
             )
             logger.info("Synapse now listening on TCP port %d", port)
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 32b5e19c09..e911ee4809 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -19,8 +19,9 @@ from typing import Optional, Tuple, Type, Union
 import attr
 from zope.interface import implementer
 
-from twisted.internet.interfaces import IAddress
+from twisted.internet.interfaces import IAddress, IReactorTime
 from twisted.python.failure import Failure
+from twisted.web.resource import IResource
 from twisted.web.server import Request, Site
 
 from synapse.config.server import ListenerConfig
@@ -485,21 +486,39 @@ class _XForwardedForAddress:
 
 class SynapseSite(Site):
     """
-    Subclass of a twisted http Site that does access logging with python's
-    standard logging
+    Synapse-specific twisted http Site
+
+    This does two main things.
+
+    First, it replaces the requestFactory in use so that we build SynapseRequests
+    instead of regular t.w.server.Requests. All of the  constructor params are really
+    just parameters for SynapseRequest.
+
+    Second, it inhibits the log() method called by Request.finish, since SynapseRequest
+    does its own logging.
     """
 
     def __init__(
         self,
-        logger_name,
-        site_tag,
+        logger_name: str,
+        site_tag: str,
         config: ListenerConfig,
-        resource,
+        resource: IResource,
         server_version_string,
-        *args,
-        **kwargs,
+        reactor: IReactorTime,
     ):
-        Site.__init__(self, resource, *args, **kwargs)
+        """
+
+        Args:
+            logger_name:  The name of the logger to use for access logs.
+            site_tag:  A tag to use for this site - mostly in access logs.
+            config:  Configuration for the HTTP listener corresponding to this site
+            resource:  The base of the resource tree to be used for serving requests on
+                this site
+            server_version_string: A string to present for the Server header
+            reactor: reactor to be used to manage connection timeouts
+        """
+        Site.__init__(self, resource, reactor=reactor)
 
         self.site_tag = site_tag
 
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index c9d04aef29..5cf58d8b60 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -349,6 +349,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             config=worker_hs.config.server.listeners[0],
             resource=resource,
             server_version_string="1",
+            reactor=self.reactor,
         )
 
         if worker_hs.config.redis.redis_enabled:
diff --git a/tests/test_server.py b/tests/test_server.py
index 55cde7f62f..45400be367 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -202,6 +202,7 @@ class OptionsResourceTests(unittest.TestCase):
             parse_listener_def({"type": "http", "port": 0}),
             self.resource,
             "1.0",
+            reactor=self.reactor,
         )
 
         # render the request and return the channel
diff --git a/tests/unittest.py b/tests/unittest.py
index ee22a53849..5353e75c7c 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -247,6 +247,7 @@ class HomeserverTestCase(TestCase):
             config=self.hs.config.server.listeners[0],
             resource=self.resource,
             server_version_string="1",
+            reactor=self.reactor,
         )
 
         from tests.rest.client.v1.utils import RestHelper