diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 6e2fbedd99..3e6a21e20f 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -354,6 +354,7 @@ class SsoRedirectServlet(RestServlet):
hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler()
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+ self._public_baseurl = hs.config.public_baseurl
def register(self, http_server: HttpServer) -> None:
super().register(http_server)
@@ -373,6 +374,28 @@ class SsoRedirectServlet(RestServlet):
async def on_GET(
self, request: SynapseRequest, idp_id: Optional[str] = None
) -> None:
+ if not self._public_baseurl:
+ raise SynapseError(400, "SSO requires a valid public_baseurl")
+
+ # if this isn't the expected hostname, redirect to the right one, so that we
+ # get our cookies back.
+ requested_uri = b"%s://%s%s" % (
+ b"https" if request.isSecure() else b"http",
+ request.getHeader(b"host"),
+ request.uri,
+ )
+ baseurl_bytes = self._public_baseurl.encode("utf-8")
+ if not requested_uri.startswith(baseurl_bytes):
+ i = requested_uri.index(b"/_matrix")
+ new_uri = baseurl_bytes[:-1] + requested_uri[i:]
+ logger.info(
+ "Requested URI %s is not canonical: redirecting to %s",
+ requested_uri.decode("utf-8", errors="replace"),
+ new_uri.decode("utf-8", errors="replace"),
+ )
+ request.redirect(new_uri)
+ finish_request(request)
+
client_redirect_url = parse_string(
request, "redirectUrl", required=True, encoding=None
)
|