summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/http/__init__.py37
-rw-r--r--synapse/rest/client/v1/login.py28
2 files changed, 64 insertions, 1 deletions
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index c658862fe6..142b007d01 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -14,8 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import re
+from typing import Union
 
-from twisted.internet import task
+from twisted.internet import address, task
 from twisted.web.client import FileBodyProducer
 from twisted.web.iweb import IRequest
 
@@ -53,6 +54,40 @@ class QuieterFileBodyProducer(FileBodyProducer):
             pass
 
 
+def get_request_uri(request: IRequest) -> bytes:
+    """Return the full URI that was requested by the client"""
+    return b"%s://%s%s" % (
+        b"https" if request.isSecure() else b"http",
+        _get_requested_host(request),
+        # despite its name, "request.uri" is only the path and query-string.
+        request.uri,
+    )
+
+
+def _get_requested_host(request: IRequest) -> bytes:
+    hostname = request.getHeader(b"host")
+    if hostname:
+        return hostname
+
+    # no Host header, use the address/port that the request arrived on
+    host = request.getHost()  # type: Union[address.IPv4Address, address.IPv6Address]
+
+    hostname = host.host.encode("ascii")
+
+    if request.isSecure() and host.port == 443:
+        # default port for https
+        return hostname
+
+    if not request.isSecure() and host.port == 80:
+        # default port for http
+        return hostname
+
+    return b"%s:%i" % (
+        hostname,
+        host.port,
+    )
+
+
 def get_request_user_agent(request: IRequest, default: str = "") -> str:
     """Return the last User-Agent header, or the given default."""
     # There could be raw utf-8 bytes in the User-Agent header.
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 6e2fbedd99..925edfc402 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -20,6 +20,7 @@ from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.appservice import ApplicationService
 from synapse.handlers.sso import SsoIdentityProvider
+from synapse.http import get_request_uri
 from synapse.http.server import HttpServer, finish_request
 from synapse.http.servlet import (
     RestServlet,
@@ -354,6 +355,7 @@ class SsoRedirectServlet(RestServlet):
             hs.get_oidc_handler()
         self._sso_handler = hs.get_sso_handler()
         self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+        self._public_baseurl = hs.config.public_baseurl
 
     def register(self, http_server: HttpServer) -> None:
         super().register(http_server)
@@ -373,6 +375,32 @@ class SsoRedirectServlet(RestServlet):
     async def on_GET(
         self, request: SynapseRequest, idp_id: Optional[str] = None
     ) -> None:
+        if not self._public_baseurl:
+            raise SynapseError(400, "SSO requires a valid public_baseurl")
+
+        # if this isn't the expected hostname, redirect to the right one, so that we
+        # get our cookies back.
+        requested_uri = get_request_uri(request)
+        baseurl_bytes = self._public_baseurl.encode("utf-8")
+        if not requested_uri.startswith(baseurl_bytes):
+            # swap out the incorrect base URL for the right one.
+            #
+            # The idea here is to redirect from
+            #    https://foo.bar/whatever/_matrix/...
+            # to
+            #    https://public.baseurl/_matrix/...
+            #
+            i = requested_uri.index(b"/_matrix")
+            new_uri = baseurl_bytes[:-1] + requested_uri[i:]
+            logger.info(
+                "Requested URI %s is not canonical: redirecting to %s",
+                requested_uri.decode("utf-8", errors="replace"),
+                new_uri.decode("utf-8", errors="replace"),
+            )
+            request.redirect(new_uri)
+            finish_request(request)
+            return
+
         client_redirect_url = parse_string(
             request, "redirectUrl", required=True, encoding=None
         )