summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-xsynapse/app/homeserver.py31
-rw-r--r--synapse/config/server.py2
-rw-r--r--synapse/server.py12
3 files changed, 31 insertions, 14 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 4228bac673..12da0bc4b5 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -34,7 +34,7 @@ from twisted.application import service
 from twisted.enterprise import adbapi
 from twisted.web.resource import Resource, EncodingResourceWrapper
 from twisted.web.static import File
-from twisted.web.server import Site, GzipEncoderFactory
+from twisted.web.server import Site, GzipEncoderFactory, Request
 from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
 from synapse.http.server import JsonResource, RootRedirect
 from synapse.rest.media.v0.content_repository import ContentRepoResource
@@ -199,7 +199,7 @@ class SynapseHomeServer(HomeServer):
                 port,
                 SynapseSite(
                     "synapse.access.https",
-                    config,
+                    listener_config,
                     root_resource,
                 ),
                 self.tls_context_factory,
@@ -210,7 +210,7 @@ class SynapseHomeServer(HomeServer):
                 port,
                 SynapseSite(
                     "synapse.access.https",
-                    config,
+                    listener_config,
                     root_resource,
                 ),
                 interface=bind_address
@@ -441,6 +441,28 @@ class SynapseService(service.Service):
         return self._port.stopListening()
 
 
+class XForwardedForRequest(Request):
+    def __init__(self, *args, **kw):
+        Request.__init__(self, *args, **kw)
+
+    """
+    Add a layer on top of another request that only uses the value of an
+    X-Forwarded-For header as the result of C{getClientIP}.
+    """
+    def getClientIP(self):
+        """
+        @return: The client address (the first address) in the value of the
+            I{X-Forwarded-For header}.  If the header is not present, return
+            C{b"-"}.
+        """
+        return self.requestHeaders.getRawHeaders(
+            b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
+
+
+def XForwardedFactory(*args, **kwargs):
+    return XForwardedForRequest(*args, **kwargs)
+
+
 class SynapseSite(Site):
     """
     Subclass of a twisted http Site that does access logging with python's
@@ -448,7 +470,8 @@ class SynapseSite(Site):
     """
     def __init__(self, logger_name, config, resource, *args, **kwargs):
         Site.__init__(self, resource, *args, **kwargs)
-        if config.captcha_ip_origin_is_x_forwarded:
+        if config.get("x_forwarded", False):
+            self.requestFactory = XForwardedFactory
             self._log_formatter = proxiedLogFormatter
         else:
             self._log_formatter = combinedLogFormatter
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 26017c7efa..9dab167b21 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -157,6 +157,8 @@ class ServerConfig(Config):
             bind_address: ''
             type: http
 
+            x_forwarded: False
+
             resources:
               - names: [client, webclient]
                 compress: true
diff --git a/synapse/server.py b/synapse/server.py
index 8b3dc675cc..4d1fb1cbf6 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -132,16 +132,8 @@ class BaseHomeServer(object):
         setattr(BaseHomeServer, "get_%s" % (depname), _get)
 
     def get_ip_from_request(self, request):
-        # May be an X-Forwarding-For header depending on config
-        ip_addr = request.getClientIP()
-        if self.config.captcha_ip_origin_is_x_forwarded:
-            # use the header
-            if request.requestHeaders.hasHeader("X-Forwarded-For"):
-                ip_addr = request.requestHeaders.getRawHeaders(
-                    "X-Forwarded-For"
-                )[0]
-
-        return ip_addr
+        # X-Forwarded-For is handled by our custom request type.
+        return request.getClientIP()
 
     def is_mine(self, domain_specific_string):
         return domain_specific_string.domain == self.hostname