summary refs log tree commit diff
path: root/synapse/http/site.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/site.py')
-rw-r--r--synapse/http/site.py85
1 files changed, 70 insertions, 15 deletions
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 4a4fb5ef26..30153237e3 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -16,6 +16,10 @@ import logging
 import time
 from typing import Optional, Union
 
+import attr
+from zope.interface import implementer
+
+from twisted.internet.interfaces import IAddress
 from twisted.python.failure import Failure
 from twisted.web.server import Request, Site
 
@@ -333,26 +337,77 @@ class SynapseRequest(Request):
 
 
 class XForwardedForRequest(SynapseRequest):
-    def __init__(self, *args, **kw):
-        SynapseRequest.__init__(self, *args, **kw)
+    """Request object which honours proxy headers
 
+    Extends SynapseRequest to replace getClientIP, getClientAddress, and isSecure with
+    information from request headers.
     """
-    Add a layer on top of another request that only uses the value of an
-    X-Forwarded-For header as the result of C{getClientIP}.
-    """
 
-    def getClientIP(self):
+    # the client IP and ssl flag, as extracted from the headers.
+    _forwarded_for = None  # type: Optional[_XForwardedForAddress]
+    _forwarded_https = False  # type: bool
+
+    def requestReceived(self, command, path, version):
+        # this method is called by the Channel once the full request has been
+        # received, to dispatch the request to a resource.
+        # We can use it to set the IP address and protocol according to the
+        # headers.
+        self._process_forwarded_headers()
+        return super().requestReceived(command, path, version)
+
+    def _process_forwarded_headers(self):
+        headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
+        if not headers:
+            return
+
+        # for now, we just use the first x-forwarded-for header. Really, we ought
+        # to start from the client IP address, and check whether it is trusted; if it
+        # is, work backwards through the headers until we find an untrusted address.
+        # see https://github.com/matrix-org/synapse/issues/9471
+        self._forwarded_for = _XForwardedForAddress(
+            headers[0].split(b",")[0].strip().decode("ascii")
+        )
+
+        # if we got an x-forwarded-for header, also look for an x-forwarded-proto header
+        header = self.getHeader(b"x-forwarded-proto")
+        if header is not None:
+            self._forwarded_https = header.lower() == b"https"
+        else:
+            # this is done largely for backwards-compatibility so that people that
+            # haven't set an x-forwarded-proto header don't get a redirect loop.
+            logger.warning(
+                "forwarded request lacks an x-forwarded-proto header: assuming https"
+            )
+            self._forwarded_https = True
+
+    def isSecure(self):
+        if self._forwarded_https:
+            return True
+        return super().isSecure()
+
+    def getClientIP(self) -> str:
         """
-        @return: The client address (the first address) in the value of the
-            I{X-Forwarded-For header}.  If the header is not present, return
-            C{b"-"}.
+        Return the IP address of the client who submitted this request.
+
+        This method is deprecated.  Use getClientAddress() instead.
         """
-        return (
-            self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0]
-            .split(b",")[0]
-            .strip()
-            .decode("ascii")
-        )
+        if self._forwarded_for is not None:
+            return self._forwarded_for.host
+        return super().getClientIP()
+
+    def getClientAddress(self) -> IAddress:
+        """
+        Return the address of the client who submitted this request.
+        """
+        if self._forwarded_for is not None:
+            return self._forwarded_for
+        return super().getClientAddress()
+
+
+@implementer(IAddress)
+@attr.s(frozen=True, slots=True)
+class _XForwardedForAddress:
+    host = attr.ib(type=str)
 
 
 class SynapseSite(Site):