summary refs log tree commit diff
path: root/synapse/http
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-02-26 14:02:06 +0000
committerGitHub <noreply@github.com>2021-02-26 14:02:06 +0000
commit15090de85075c9d7d54479b4bfd79057de64059b (patch)
tree70ba6e557818865d128d77839227e6331942df7d /synapse/http
parentCall out the need for an X-Forwarded-Proto in the upgrade notes (#9501) (diff)
downloadsynapse-15090de85075c9d7d54479b4bfd79057de64059b.tar.xz
SSO: redirect to public URL before setting cookies (#9436)
... otherwise, we don't get the cookie back.
Diffstat (limited to 'synapse/http')
-rw-r--r--synapse/http/__init__.py37
1 files changed, 36 insertions, 1 deletions
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index c658862fe6..142b007d01 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -14,8 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import re
+from typing import Union
 
-from twisted.internet import task
+from twisted.internet import address, task
 from twisted.web.client import FileBodyProducer
 from twisted.web.iweb import IRequest
 
@@ -53,6 +54,40 @@ class QuieterFileBodyProducer(FileBodyProducer):
             pass
 
 
+def get_request_uri(request: IRequest) -> bytes:
+    """Return the full URI that was requested by the client"""
+    return b"%s://%s%s" % (
+        b"https" if request.isSecure() else b"http",
+        _get_requested_host(request),
+        # despite its name, "request.uri" is only the path and query-string.
+        request.uri,
+    )
+
+
+def _get_requested_host(request: IRequest) -> bytes:
+    hostname = request.getHeader(b"host")
+    if hostname:
+        return hostname
+
+    # no Host header, use the address/port that the request arrived on
+    host = request.getHost()  # type: Union[address.IPv4Address, address.IPv6Address]
+
+    hostname = host.host.encode("ascii")
+
+    if request.isSecure() and host.port == 443:
+        # default port for https
+        return hostname
+
+    if not request.isSecure() and host.port == 80:
+        # default port for http
+        return hostname
+
+    return b"%s:%i" % (
+        hostname,
+        host.port,
+    )
+
+
 def get_request_user_agent(request: IRequest, default: str = "") -> str:
     """Return the last User-Agent header, or the given default."""
     # There could be raw utf-8 bytes in the User-Agent header.