summary refs log tree commit diff
path: root/synapse/app/generic_worker.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/app/generic_worker.py')
-rw-r--r--synapse/app/generic_worker.py15
1 files changed, 13 insertions, 2 deletions
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index dc0d3eb725..274d582d07 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -23,6 +23,7 @@ from typing_extensions import ContextManager
 
 from twisted.internet import address
 from twisted.web.resource import IResource
+from twisted.web.server import Request
 
 import synapse
 import synapse.events
@@ -190,7 +191,7 @@ class KeyUploadServlet(RestServlet):
         self.http_client = hs.get_simple_http_client()
         self.main_uri = hs.config.worker_main_http_uri
 
-    async def on_POST(self, request, device_id):
+    async def on_POST(self, request: Request, device_id: Optional[str]):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         user_id = requester.user.to_string()
         body = parse_json_object_from_request(request)
@@ -223,10 +224,12 @@ class KeyUploadServlet(RestServlet):
                 header: request.requestHeaders.getRawHeaders(header, [])
                 for header in (b"Authorization", b"User-Agent")
             }
-            # Add the previous hop the the X-Forwarded-For header.
+            # Add the previous hop to the X-Forwarded-For header.
             x_forwarded_for = request.requestHeaders.getRawHeaders(
                 b"X-Forwarded-For", []
             )
+            # we use request.client here, since we want the previous hop, not the
+            # original client (as returned by request.getClientAddress()).
             if isinstance(request.client, (address.IPv4Address, address.IPv6Address)):
                 previous_host = request.client.host.encode("ascii")
                 # If the header exists, add to the comma-separated list of the first
@@ -239,6 +242,14 @@ class KeyUploadServlet(RestServlet):
                     x_forwarded_for = [previous_host]
             headers[b"X-Forwarded-For"] = x_forwarded_for
 
+            # Replicate the original X-Forwarded-Proto header. Note that
+            # XForwardedForRequest overrides isSecure() to give us the original protocol
+            # used by the client, as opposed to the protocol used by our upstream proxy
+            # - which is what we want here.
+            headers[b"X-Forwarded-Proto"] = [
+                b"https" if request.isSecure() else b"http"
+            ]
+
             try:
                 result = await self.http_client.post_json_get_json(
                     self.main_uri + request.uri.decode("ascii"), body, headers=headers