diff --git a/changelog.d/6619.misc b/changelog.d/6619.misc
new file mode 100644
index 0000000000..b608133219
--- /dev/null
+++ b/changelog.d/6619.misc
@@ -0,0 +1 @@
+Simplify http handling by removing redundant SynapseRequestFactory.
diff --git a/synapse/http/site.py b/synapse/http/site.py
index ff8184a3d0..9f2d035fa0 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -47,9 +47,9 @@ class SynapseRequest(Request):
logcontext(LoggingContext) : the log context for this request
"""
- def __init__(self, site, channel, *args, **kw):
+ def __init__(self, channel, *args, **kw):
Request.__init__(self, channel, *args, **kw)
- self.site = site
+ self.site = channel.site
self._channel = channel # this is used by the tests
self.authenticated_entity = None
self.start_time = 0
@@ -331,18 +331,6 @@ class XForwardedForRequest(SynapseRequest):
)
-class SynapseRequestFactory(object):
- def __init__(self, site, x_forwarded_for):
- self.site = site
- self.x_forwarded_for = x_forwarded_for
-
- def __call__(self, *args, **kwargs):
- if self.x_forwarded_for:
- return XForwardedForRequest(self.site, *args, **kwargs)
- else:
- return SynapseRequest(self.site, *args, **kwargs)
-
-
class SynapseSite(Site):
"""
Subclass of a twisted http Site that does access logging with python's
@@ -364,7 +352,7 @@ class SynapseSite(Site):
self.site_tag = site_tag
proxied = config.get("x_forwarded", False)
- self.requestFactory = SynapseRequestFactory(self, proxied)
+ self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")
diff --git a/tests/server.py b/tests/server.py
index a554dfdd57..1644710aa0 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -20,6 +20,7 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http import unquote
from twisted.web.http_headers import Headers
+from twisted.web.server import Site
from synapse.http.site import SynapseRequest
from synapse.util import Clock
@@ -42,6 +43,7 @@ class FakeChannel(object):
wire).
"""
+ site = attr.ib(type=Site)
_reactor = attr.ib()
result = attr.ib(default=attr.Factory(dict))
_producer = None
@@ -176,9 +178,9 @@ def make_request(
content = content.encode("utf8")
site = FakeSite()
- channel = FakeChannel(reactor)
+ channel = FakeChannel(site, reactor)
- req = request(site, channel)
+ req = request(channel)
req.process = lambda: b""
req.content = BytesIO(content)
req.postpath = list(map(unquote, path[1:].split(b"/")))
|