summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/9436.bugfix1
-rw-r--r--synapse/http/__init__.py37
-rw-r--r--synapse/rest/client/v1/login.py28
-rw-r--r--tests/rest/client/v1/test_login.py61
-rw-r--r--tests/rest/client/v1/utils.py19
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py6
-rw-r--r--tests/server.py6
7 files changed, 130 insertions, 28 deletions
diff --git a/changelog.d/9436.bugfix b/changelog.d/9436.bugfix
new file mode 100644
index 0000000000..a530516eed
--- /dev/null
+++ b/changelog.d/9436.bugfix
@@ -0,0 +1 @@
+Fix a bug in single sign-on which could cause a "No session cookie found" error.
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index c658862fe6..142b007d01 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -14,8 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import re
+from typing import Union
 
-from twisted.internet import task
+from twisted.internet import address, task
 from twisted.web.client import FileBodyProducer
 from twisted.web.iweb import IRequest
 
@@ -53,6 +54,40 @@ class QuieterFileBodyProducer(FileBodyProducer):
             pass
 
 
+def get_request_uri(request: IRequest) -> bytes:
+    """Return the full URI that was requested by the client"""
+    return b"%s://%s%s" % (
+        b"https" if request.isSecure() else b"http",
+        _get_requested_host(request),
+        # despite its name, "request.uri" is only the path and query-string.
+        request.uri,
+    )
+
+
+def _get_requested_host(request: IRequest) -> bytes:
+    hostname = request.getHeader(b"host")
+    if hostname:
+        return hostname
+
+    # no Host header, use the address/port that the request arrived on
+    host = request.getHost()  # type: Union[address.IPv4Address, address.IPv6Address]
+
+    hostname = host.host.encode("ascii")
+
+    if request.isSecure() and host.port == 443:
+        # default port for https
+        return hostname
+
+    if not request.isSecure() and host.port == 80:
+        # default port for http
+        return hostname
+
+    return b"%s:%i" % (
+        hostname,
+        host.port,
+    )
+
+
 def get_request_user_agent(request: IRequest, default: str = "") -> str:
     """Return the last User-Agent header, or the given default."""
     # There could be raw utf-8 bytes in the User-Agent header.
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 6e2fbedd99..925edfc402 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -20,6 +20,7 @@ from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.appservice import ApplicationService
 from synapse.handlers.sso import SsoIdentityProvider
+from synapse.http import get_request_uri
 from synapse.http.server import HttpServer, finish_request
 from synapse.http.servlet import (
     RestServlet,
@@ -354,6 +355,7 @@ class SsoRedirectServlet(RestServlet):
             hs.get_oidc_handler()
         self._sso_handler = hs.get_sso_handler()
         self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+        self._public_baseurl = hs.config.public_baseurl
 
     def register(self, http_server: HttpServer) -> None:
         super().register(http_server)
@@ -373,6 +375,32 @@ class SsoRedirectServlet(RestServlet):
     async def on_GET(
         self, request: SynapseRequest, idp_id: Optional[str] = None
     ) -> None:
+        if not self._public_baseurl:
+            raise SynapseError(400, "SSO requires a valid public_baseurl")
+
+        # if this isn't the expected hostname, redirect to the right one, so that we
+        # get our cookies back.
+        requested_uri = get_request_uri(request)
+        baseurl_bytes = self._public_baseurl.encode("utf-8")
+        if not requested_uri.startswith(baseurl_bytes):
+            # swap out the incorrect base URL for the right one.
+            #
+            # The idea here is to redirect from
+            #    https://foo.bar/whatever/_matrix/...
+            # to
+            #    https://public.baseurl/_matrix/...
+            #
+            i = requested_uri.index(b"/_matrix")
+            new_uri = baseurl_bytes[:-1] + requested_uri[i:]
+            logger.info(
+                "Requested URI %s is not canonical: redirecting to %s",
+                requested_uri.decode("utf-8", errors="replace"),
+                new_uri.decode("utf-8", errors="replace"),
+            )
+            request.redirect(new_uri)
+            finish_request(request)
+            return
+
         client_redirect_url = parse_string(
             request, "redirectUrl", required=True, encoding=None
         )
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index fb29eaed6f..744d8d0941 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -15,7 +15,7 @@
 
 import time
 import urllib.parse
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 from urllib.parse import urlencode
 
 from mock import Mock
@@ -47,8 +47,14 @@ except ImportError:
     HAS_JWT = False
 
 
-# public_base_url used in some tests
-BASE_URL = "https://synapse/"
+# synapse server name: used to populate public_baseurl in some tests
+SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
+
+# public_baseurl for some tests. It uses an http:// scheme because
+# FakeChannel.isSecure() returns False, so synapse will see the requested uri as
+# http://..., so using http in the public_baseurl stops Synapse trying to redirect to
+# https://....
+BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
 
 # CAS server used in some tests
 CAS_SERVER = "https://fake.test"
@@ -480,11 +486,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
     def test_multi_sso_redirect(self):
         """/login/sso/redirect should redirect to an identity picker"""
         # first hit the redirect url, which should redirect to our idp picker
-        channel = self.make_request(
-            "GET",
-            "/_matrix/client/r0/login/sso/redirect?redirectUrl="
-            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
-        )
+        channel = self._make_sso_redirect_request(False, None)
         self.assertEqual(channel.code, 302, channel.result)
         uri = channel.headers.getRawHeaders("Location")[0]
 
@@ -628,34 +630,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
 
     def test_client_idp_redirect_msc2858_disabled(self):
         """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
-        channel = self.make_request(
-            "GET",
-            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
-            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
-        )
+        channel = self._make_sso_redirect_request(True, "oidc")
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
 
     @override_config({"experimental_features": {"msc2858_enabled": True}})
     def test_client_idp_redirect_to_unknown(self):
         """If the client tries to pick an unknown IdP, return a 404"""
-        channel = self.make_request(
-            "GET",
-            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
-            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
-        )
+        channel = self._make_sso_redirect_request(True, "xxx")
         self.assertEqual(channel.code, 404, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
 
     @override_config({"experimental_features": {"msc2858_enabled": True}})
     def test_client_idp_redirect_to_oidc(self):
         """If the client pick a known IdP, redirect to it"""
-        channel = self.make_request(
-            "GET",
-            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
-            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
-        )
-
+        channel = self._make_sso_redirect_request(True, "oidc")
         self.assertEqual(channel.code, 302, channel.result)
         oidc_uri = channel.headers.getRawHeaders("Location")[0]
         oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
@@ -663,6 +652,30 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         # it should redirect us to the auth page of the OIDC server
         self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
 
+    def _make_sso_redirect_request(
+        self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
+    ):
+        """Send a request to /_matrix/client/r0/login/sso/redirect
+
+        ... or the unstable equivalent
+
+        ... possibly specifying an IDP provider
+        """
+        endpoint = (
+            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect"
+            if unstable_endpoint
+            else "/_matrix/client/r0/login/sso/redirect"
+        )
+        if idp_prov is not None:
+            endpoint += "/" + idp_prov
+        endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+
+        return self.make_request(
+            "GET",
+            endpoint,
+            custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)],
+        )
+
     @staticmethod
     def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
         prefix = key + " = "
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 8231a423f3..946740aa5d 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -542,13 +542,30 @@ class RestHelper:
         if client_redirect_url:
             params["redirectUrl"] = client_redirect_url
 
-        # hit the redirect url (which will issue a cookie and state)
+        # hit the redirect url (which should redirect back to the redirect url. This
+        # is the easiest way of figuring out what the Host header ought to be set to
+        # to keep Synapse happy.
         channel = make_request(
             self.hs.get_reactor(),
             self.site,
             "GET",
             "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
         )
+        assert channel.code == 302
+
+        # hit the redirect url again with the right Host header, which should now issue
+        # a cookie and redirect to the SSO provider.
+        location = channel.headers.getRawHeaders("Location")[0]
+        parts = urllib.parse.urlsplit(location)
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "GET",
+            urllib.parse.urlunsplit(("", "") + parts[2:]),
+            custom_headers=[
+                ("Host", parts[1]),
+            ],
+        )
 
         assert channel.code == 302
         channel.extract_cookies(cookies)
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index c26ad824f7..9734a2159a 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -161,7 +161,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
     def default_config(self):
         config = super().default_config()
-        config["public_baseurl"] = "https://synapse.test"
+
+        # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
+        # False, so synapse will see the requested uri as http://..., so using http in
+        # the public_baseurl stops Synapse trying to redirect to https.
+        config["public_baseurl"] = "http://synapse.test"
 
         if HAS_OIDC:
             # we enable OIDC as a way of testing SSO flows
diff --git a/tests/server.py b/tests/server.py
index d4ece5c448..939a0008ca 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -124,7 +124,11 @@ class FakeChannel:
         return address.IPv4Address("TCP", self._ip, 3423)
 
     def getHost(self):
-        return None
+        # this is called by Request.__init__ to configure Request.host.
+        return address.IPv4Address("TCP", "127.0.0.1", 8888)
+
+    def isSecure(self):
+        return False
 
     @property
     def transport(self):