summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-01-15 14:05:55 +0000
committerErik Johnston <erik@matrix.org>2021-01-15 14:05:55 +0000
commit029c9ef967e214cbee7abb3a976031aa1d95970c (patch)
tree1660b11854b6a4e8f4b82cb2cd839b61366ee865
parentMerge remote-tracking branch 'origin/develop' into matrix-org-hotfixes (diff)
parentImprove UsernamePickerTestCase (#9112) (diff)
downloadsynapse-029c9ef967e214cbee7abb3a976031aa1d95970c.tar.xz
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
-rw-r--r--.gitignore1
-rw-r--r--INSTALL.md3
-rw-r--r--README.rst21
-rw-r--r--changelog.d/8997.doc1
-rw-r--r--changelog.d/9091.feature1
-rw-r--r--changelog.d/9109.feature1
-rw-r--r--changelog.d/9112.misc1
-rw-r--r--changelog.d/9114.bugfix2
-rw-r--r--changelog.d/9116.bugfix1
-rw-r--r--changelog.d/9118.misc1
-rw-r--r--docs/sample_config.yaml8
-rw-r--r--synapse/config/oidc_config.py26
-rw-r--r--synapse/config/sso.py10
-rw-r--r--synapse/handlers/auth.py25
-rw-r--r--synapse/handlers/oidc_handler.py22
-rw-r--r--synapse/handlers/sso.py45
-rw-r--r--synapse/res/templates/sso_auth_bad_user.html18
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py41
-rw-r--r--synapse/util/iterutils.py2
-rw-r--r--tests/handlers/test_oidc.py123
-rw-r--r--tests/rest/client/v1/test_login.py105
-rw-r--r--tests/rest/client/v1/utils.py11
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py29
-rw-r--r--tests/util/test_itertools.py8
24 files changed, 324 insertions, 182 deletions
diff --git a/.gitignore b/.gitignore
index 2bccf19997..2cef1b0a5a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,6 +12,7 @@
 _trial_temp/
 _trial_temp*/
 /out
+.DS_Store
 
 # stuff that is likely to exist when you run a server locally
 /*.db
diff --git a/INSTALL.md b/INSTALL.md
index 656833637c..d405d9fe55 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -190,7 +190,8 @@ via brew and inform `pip` about it so that `psycopg2` builds:
 
 ```sh
 brew install openssl@1.1
-export LDFLAGS=-L/usr/local/Cellar/openssl\@1.1/1.1.1d/lib/
+export LDFLAGS="-L/usr/local/opt/openssl/lib"
+export CPPFLAGS="-I/usr/local/opt/openssl/include"
 ```
 
 ##### OpenSUSE
diff --git a/README.rst b/README.rst
index 9ff375708b..af914d71a8 100644
--- a/README.rst
+++ b/README.rst
@@ -280,6 +280,27 @@ differ)::
 
     PASSED (skips=15, successes=1322)
 
+We recommend using the demo which starts 3 federated instances running on ports `8080` - `8082`
+
+    ./demo/start.sh
+
+(to stop, you can use `./demo/stop.sh`)
+
+If you just want to start a single instance of the app and run it directly:
+
+    # Create the homeserver.yaml config once
+    python -m synapse.app.homeserver \
+      --server-name my.domain.name \
+      --config-path homeserver.yaml \
+      --generate-config \
+      --report-stats=[yes|no]
+
+    # Start the app
+    python -m synapse.app.homeserver --config-path homeserver.yaml
+
+
+
+
 Running the Integration Tests
 =============================
 
diff --git a/changelog.d/8997.doc b/changelog.d/8997.doc
new file mode 100644
index 0000000000..dd1a882301
--- /dev/null
+++ b/changelog.d/8997.doc
@@ -0,0 +1 @@
+Add some extra docs for getting Synapse running on macOS.
diff --git a/changelog.d/9091.feature b/changelog.d/9091.feature
new file mode 100644
index 0000000000..79fcd701f8
--- /dev/null
+++ b/changelog.d/9091.feature
@@ -0,0 +1 @@
+During user-interactive authentication via single-sign-on, give a better error if the user uses the wrong account on the SSO IdP.
diff --git a/changelog.d/9109.feature b/changelog.d/9109.feature
new file mode 100644
index 0000000000..01a24dcf49
--- /dev/null
+++ b/changelog.d/9109.feature
@@ -0,0 +1 @@
+Add support for multiple SSO Identity Providers.
diff --git a/changelog.d/9112.misc b/changelog.d/9112.misc
new file mode 100644
index 0000000000..691f9d8b43
--- /dev/null
+++ b/changelog.d/9112.misc
@@ -0,0 +1 @@
+Improve `UsernamePickerTestCase`.
diff --git a/changelog.d/9114.bugfix b/changelog.d/9114.bugfix
index 77112abd5c..211f26589d 100644
--- a/changelog.d/9114.bugfix
+++ b/changelog.d/9114.bugfix
@@ -1 +1 @@
-Fix bug in federation catchup logic that caused outbound federation to be delayed for large servers after start up. Introduced in v1.21.0.
+Fix bugs in federation catchup logic that caused outbound federation to be delayed for large servers after start up. Introduced in v1.8.0 and v1.21.0.
diff --git a/changelog.d/9116.bugfix b/changelog.d/9116.bugfix
new file mode 100644
index 0000000000..211f26589d
--- /dev/null
+++ b/changelog.d/9116.bugfix
@@ -0,0 +1 @@
+Fix bugs in federation catchup logic that caused outbound federation to be delayed for large servers after start up. Introduced in v1.8.0 and v1.21.0.
diff --git a/changelog.d/9118.misc b/changelog.d/9118.misc
new file mode 100644
index 0000000000..346741d982
--- /dev/null
+++ b/changelog.d/9118.misc
@@ -0,0 +1 @@
+Improve efficiency of large state resolutions.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index c8ae46d1b3..9da351f9f3 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1969,6 +1969,14 @@ sso:
     #
     #   This template has no additional variables.
     #
+    # * HTML page shown after a user-interactive authentication session which
+    #   does not map correctly onto the expected user: 'sso_auth_bad_user.html'.
+    #
+    #   When rendering, this template is given the following variables:
+    #     * server_name: the homeserver's name.
+    #     * user_id_to_verify: the MXID of the user that we are trying to
+    #       validate.
+    #
     # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
     #   attempts to login: 'sso_account_deactivated.html'.
     #
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index c705de5694..fddca19223 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2020 Quentin Gliech
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import string
 from typing import Optional, Type
 
 import attr
@@ -38,7 +39,7 @@ class OIDCConfig(Config):
 
         oidc_config = config.get("oidc_config")
         if oidc_config and oidc_config.get("enabled", False):
-            validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, "oidc_config")
+            validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
             self.oidc_provider = _parse_oidc_config_dict(oidc_config)
 
         if not self.oidc_provider:
@@ -205,6 +206,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
     "type": "object",
     "required": ["issuer", "client_id", "client_secret"],
     "properties": {
+        "idp_id": {"type": "string", "minLength": 1, "maxLength": 128},
+        "idp_name": {"type": "string"},
         "discover": {"type": "boolean"},
         "issuer": {"type": "string"},
         "client_id": {"type": "string"},
@@ -277,7 +280,17 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
             "methods: %s" % (", ".join(missing_methods),)
         )
 
+    # MSC2858 will appy certain limits in what can be used as an IdP id, so let's
+    # enforce those limits now.
+    idp_id = oidc_config.get("idp_id", "oidc")
+    valid_idp_chars = set(string.ascii_letters + string.digits + "-._~")
+
+    if any(c not in valid_idp_chars for c in idp_id):
+        raise ConfigError('idp_id may only contain A-Z, a-z, 0-9, "-", ".", "_", "~"')
+
     return OidcProviderConfig(
+        idp_id=idp_id,
+        idp_name=oidc_config.get("idp_name", "OIDC"),
         discover=oidc_config.get("discover", True),
         issuer=oidc_config["issuer"],
         client_id=oidc_config["client_id"],
@@ -296,8 +309,15 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
     )
 
 
-@attr.s
+@attr.s(slots=True, frozen=True)
 class OidcProviderConfig:
+    # a unique identifier for this identity provider. Used in the 'user_external_ids'
+    # table, as well as the query/path parameter used in the login protocol.
+    idp_id = attr.ib(type=str)
+
+    # user-facing name for this identity provider.
+    idp_name = attr.ib(type=str)
+
     # whether the OIDC discovery mechanism is used to discover endpoints
     discover = attr.ib(type=bool)
 
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 1aeb1c5c92..366f0d4698 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -37,6 +37,7 @@ class SSOConfig(Config):
             self.sso_error_template,
             sso_account_deactivated_template,
             sso_auth_success_template,
+            self.sso_auth_bad_user_template,
         ) = self.read_templates(
             [
                 "sso_login_idp_picker.html",
@@ -45,6 +46,7 @@ class SSOConfig(Config):
                 "sso_error.html",
                 "sso_account_deactivated.html",
                 "sso_auth_success.html",
+                "sso_auth_bad_user.html",
             ],
             template_dir,
         )
@@ -160,6 +162,14 @@ class SSOConfig(Config):
             #
             #   This template has no additional variables.
             #
+            # * HTML page shown after a user-interactive authentication session which
+            #   does not map correctly onto the expected user: 'sso_auth_bad_user.html'.
+            #
+            #   When rendering, this template is given the following variables:
+            #     * server_name: the homeserver's name.
+            #     * user_id_to_verify: the MXID of the user that we are trying to
+            #       validate.
+            #
             # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
             #   attempts to login: 'sso_account_deactivated.html'.
             #
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 4f881a439a..18cd2b62f0 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -263,10 +263,6 @@ class AuthHandler(BaseHandler):
         # authenticating for an operation to occur on their account.
         self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
 
-        # The following template is shown after a successful user interactive
-        # authentication session. It tells the user they can close the window.
-        self._sso_auth_success_template = hs.config.sso_auth_success_template
-
         # The following template is shown during the SSO authentication process if
         # the account is deactivated.
         self._sso_account_deactivated_template = (
@@ -1394,27 +1390,6 @@ class AuthHandler(BaseHandler):
             description=session.description, redirect_url=redirect_url,
         )
 
-    async def complete_sso_ui_auth(
-        self, registered_user_id: str, session_id: str, request: Request,
-    ):
-        """Having figured out a mxid for this user, complete the HTTP request
-
-        Args:
-            registered_user_id: The registered user ID to complete SSO login for.
-            session_id: The ID of the user-interactive auth session.
-            request: The request to complete.
-        """
-        # Mark the stage of the authentication as successful.
-        # Save the user who authenticated with SSO, this will be used to ensure
-        # that the account be modified is also the person who logged in.
-        await self.store.mark_ui_auth_stage_complete(
-            session_id, LoginType.SSO, registered_user_id
-        )
-
-        # Render the HTML and return.
-        html = self._sso_auth_success_template
-        respond_with_html(request, 200, html)
-
     async def complete_sso_login(
         self,
         registered_user_id: str,
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index d6347bb1b8..f63a90ec5c 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -175,7 +175,7 @@ class OidcHandler:
             session_data = self._token_generator.verify_oidc_session_token(
                 session, state
             )
-        except MacaroonDeserializationException as e:
+        except (MacaroonDeserializationException, ValueError) as e:
             logger.exception("Invalid session")
             self._sso_handler.render_error(request, "invalid_session", str(e))
             return
@@ -253,10 +253,10 @@ class OidcProvider:
         self._server_name = hs.config.server_name  # type: str
 
         # identifier for the external_ids table
-        self.idp_id = "oidc"
+        self.idp_id = provider.idp_id
 
         # user-facing name of this auth provider
-        self.idp_name = "OIDC"
+        self.idp_name = provider.idp_name
 
         self._sso_handler = hs.get_sso_handler()
 
@@ -656,6 +656,7 @@ class OidcProvider:
         cookie = self._token_generator.generate_oidc_session_token(
             state=state,
             session_data=OidcSessionData(
+                idp_id=self.idp_id,
                 nonce=nonce,
                 client_redirect_url=client_redirect_url.decode(),
                 ui_auth_session_id=ui_auth_session_id,
@@ -924,6 +925,7 @@ class OidcSessionTokenGenerator:
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("type = session")
         macaroon.add_first_party_caveat("state = %s" % (state,))
+        macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,))
         macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
         macaroon.add_first_party_caveat(
             "client_redirect_url = %s" % (session_data.client_redirect_url,)
@@ -952,6 +954,9 @@ class OidcSessionTokenGenerator:
 
         Returns:
             The data extracted from the session cookie
+
+        Raises:
+            ValueError if an expected caveat is missing from the macaroon.
         """
         macaroon = pymacaroons.Macaroon.deserialize(session)
 
@@ -960,6 +965,7 @@ class OidcSessionTokenGenerator:
         v.satisfy_exact("type = session")
         v.satisfy_exact("state = %s" % (state,))
         v.satisfy_general(lambda c: c.startswith("nonce = "))
+        v.satisfy_general(lambda c: c.startswith("idp_id = "))
         v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
         # Sometimes there's a UI auth session ID, it seems to be OK to attempt
         # to always satisfy this.
@@ -968,9 +974,9 @@ class OidcSessionTokenGenerator:
 
         v.verify(macaroon, self._macaroon_secret_key)
 
-        # Extract the `nonce`, `client_redirect_url`, and maybe the
-        # `ui_auth_session_id` from the token.
+        # Extract the session data from the token.
         nonce = self._get_value_from_macaroon(macaroon, "nonce")
+        idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
         client_redirect_url = self._get_value_from_macaroon(
             macaroon, "client_redirect_url"
         )
@@ -983,6 +989,7 @@ class OidcSessionTokenGenerator:
 
         return OidcSessionData(
             nonce=nonce,
+            idp_id=idp_id,
             client_redirect_url=client_redirect_url,
             ui_auth_session_id=ui_auth_session_id,
         )
@@ -998,7 +1005,7 @@ class OidcSessionTokenGenerator:
             The extracted value
 
         Raises:
-            Exception: if the caveat was not in the macaroon
+            ValueError: if the caveat was not in the macaroon
         """
         prefix = key + " = "
         for caveat in macaroon.caveats:
@@ -1019,6 +1026,9 @@ class OidcSessionTokenGenerator:
 class OidcSessionData:
     """The attributes which are stored in a OIDC session cookie"""
 
+    # the Identity Provider being used
+    idp_id = attr.ib(type=str)
+
     # The `nonce` parameter passed to the OIDC provider.
     nonce = attr.ib(type=str)
 
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index d096e0b091..dcc85e9871 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -22,7 +22,9 @@ from typing_extensions import NoReturn, Protocol
 
 from twisted.web.http import Request
 
+from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
 from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
@@ -146,8 +148,13 @@ class SsoHandler:
         self._store = hs.get_datastore()
         self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
-        self._error_template = hs.config.sso_error_template
         self._auth_handler = hs.get_auth_handler()
+        self._error_template = hs.config.sso_error_template
+        self._bad_user_template = hs.config.sso_auth_bad_user_template
+
+        # The following template is shown after a successful user interactive
+        # authentication session. It tells the user they can close the window.
+        self._sso_auth_success_template = hs.config.sso_auth_success_template
 
         # a lock on the mappings
         self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
@@ -577,19 +584,45 @@ class SsoHandler:
             auth_provider_id, remote_user_id,
         )
 
+        user_id_to_verify = await self._auth_handler.get_session_data(
+            ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+        )  # type: str
+
         if not user_id:
             logger.warning(
                 "Remote user %s/%s has not previously logged in here: UIA will fail",
                 auth_provider_id,
                 remote_user_id,
             )
-            # Let the UIA flow handle this the same as if they presented creds for a
-            # different user.
-            user_id = ""
+        elif user_id != user_id_to_verify:
+            logger.warning(
+                "Remote user %s/%s mapped onto incorrect user %s: UIA will fail",
+                auth_provider_id,
+                remote_user_id,
+                user_id,
+            )
+        else:
+            # success!
+            # Mark the stage of the authentication as successful.
+            await self._store.mark_ui_auth_stage_complete(
+                ui_auth_session_id, LoginType.SSO, user_id
+            )
+
+            # Render the HTML confirmation page and return.
+            html = self._sso_auth_success_template
+            respond_with_html(request, 200, html)
+            return
+
+        # the user_id didn't match: mark the stage of the authentication as unsuccessful
+        await self._store.mark_ui_auth_stage_complete(
+            ui_auth_session_id, LoginType.SSO, ""
+        )
 
-        await self._auth_handler.complete_sso_ui_auth(
-            user_id, ui_auth_session_id, request
+        # render an error page.
+        html = self._bad_user_template.render(
+            server_name=self._server_name, user_id_to_verify=user_id_to_verify,
         )
+        respond_with_html(request, 200, html)
 
     async def check_username_availability(
         self, localpart: str, session_id: str,
diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html
new file mode 100644
index 0000000000..3611191bf9
--- /dev/null
+++ b/synapse/res/templates/sso_auth_bad_user.html
@@ -0,0 +1,18 @@
+<html>
+<head>
+    <title>Authentication Failed</title>
+</head>
+    <body>
+        <div>
+            <p>
+                We were unable to validate your <tt>{{server_name | e}}</tt> account via
+                single-sign-on (SSO), because the SSO Identity Provider returned
+                different details than when you logged in.
+            </p>
+            <p>
+                Try the operation again, and ensure that you use the same details on
+                the Identity Provider as when you log into your account.
+            </p>
+        </div>
+    </body>
+</html>
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1b6ccd51c8..c128889bf9 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict
 from synapse.util import json_encoder
@@ -513,21 +514,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
         for user_chunk in batch_iter(user_ids, 100):
             clause, params = make_in_list_sql_clause(
-                txn.database_engine, "k.user_id", user_chunk
-            )
-            sql = (
-                """
-                SELECT k.user_id, k.keytype, k.keydata, k.stream_id
-                  FROM e2e_cross_signing_keys k
-                  INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
-                                FROM e2e_cross_signing_keys
-                               GROUP BY user_id, keytype) s
-                 USING (user_id, stream_id, keytype)
-                 WHERE
-            """
-                + clause
+                txn.database_engine, "user_id", user_chunk
             )
 
+            # Fetch the latest key for each type per user.
+            if isinstance(self.database_engine, PostgresEngine):
+                # The `DISTINCT ON` clause will pick the *first* row it
+                # encounters, so ordering by stream ID desc will ensure we get
+                # the latest key.
+                sql = """
+                    SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id
+                        FROM e2e_cross_signing_keys
+                        WHERE %(clause)s
+                        ORDER BY user_id, keytype, stream_id DESC
+                """ % {
+                    "clause": clause
+                }
+            else:
+                # SQLite has special handling for bare columns when using
+                # MIN/MAX with a `GROUP BY` clause where it picks the value from
+                # a row that matches the MIN/MAX.
+                sql = """
+                    SELECT user_id, keytype, keydata, MAX(stream_id)
+                        FROM e2e_cross_signing_keys
+                        WHERE %(clause)s
+                        GROUP BY user_id, keytype
+                """ % {
+                    "clause": clause
+                }
+
             txn.execute(sql, params)
             rows = self.db_pool.cursor_to_dict(txn)
 
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index f7b4857a84..6ef2b008a4 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -92,7 +92,7 @@ def sorted_topologically(
         node = heapq.heappop(zero_degree)
         yield node
 
-        for edge in reverse_graph[node]:
+        for edge in reverse_graph.get(node, []):
             if edge in degree_map:
                 degree_map[edge] -= 1
                 if degree_map[edge] == 0:
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 5d338bea87..02e21ed6ca 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,20 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-import re
-from typing import Dict, Optional
-from urllib.parse import parse_qs, urlencode, urlparse
+from typing import Optional
+from urllib.parse import parse_qs, urlparse
 
 from mock import ANY, Mock, patch
 
 import pymacaroons
 
-from twisted.web.resource import Resource
-
-from synapse.api.errors import RedirectException
 from synapse.handlers.sso import MappingException
-from synapse.rest.client.v1 import login
-from synapse.rest.synapse.client.pick_username import pick_username_resource
 from synapse.server import HomeServer
 from synapse.types import UserID
 
@@ -848,6 +842,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return self.handler._token_generator.generate_oidc_session_token(
             state=state,
             session_data=OidcSessionData(
+                idp_id="oidc",
                 nonce=nonce,
                 client_redirect_url=client_redirect_url,
                 ui_auth_session_id=ui_auth_session_id,
@@ -855,116 +850,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
 
-class UsernamePickerTestCase(HomeserverTestCase):
-    if not HAS_OIDC:
-        skip = "requires OIDC"
-
-    servlets = [login.register_servlets]
-
-    def default_config(self):
-        config = super().default_config()
-        config["public_baseurl"] = BASE_URL
-        oidc_config = {
-            "enabled": True,
-            "client_id": CLIENT_ID,
-            "client_secret": CLIENT_SECRET,
-            "issuer": ISSUER,
-            "scopes": SCOPES,
-            "user_mapping_provider": {
-                "config": {"display_name_template": "{{ user.displayname }}"}
-            },
-        }
-
-        # Update this config with what's in the default config so that
-        # override_config works as expected.
-        oidc_config.update(config.get("oidc_config", {}))
-        config["oidc_config"] = oidc_config
-
-        # whitelist this client URI so we redirect straight to it rather than
-        # serving a confirmation page
-        config["sso"] = {"client_whitelist": ["https://whitelisted.client"]}
-        return config
-
-    def create_resource_dict(self) -> Dict[str, Resource]:
-        d = super().create_resource_dict()
-        d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
-        return d
-
-    def test_username_picker(self):
-        """Test the happy path of a username picker flow."""
-        client_redirect_url = "https://whitelisted.client"
-
-        # first of all, mock up an OIDC callback to the OidcHandler, which should
-        # raise a RedirectException
-        userinfo = {"sub": "tester", "displayname": "Jonny"}
-        f = self.get_failure(
-            _make_callback_with_userinfo(
-                self.hs, userinfo, client_redirect_url=client_redirect_url
-            ),
-            RedirectException,
-        )
-
-        # check the Location and cookies returned by the RedirectException
-        self.assertEqual(f.value.location, b"/_synapse/client/pick_username")
-        cookieheader = f.value.cookies[0]
-        regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);")
-        m = regex.search(cookieheader)
-        if not m:
-            self.fail("cookie header %s does not match %s" % (cookieheader, regex))
-
-        # introspect the sso handler a bit to check that the username mapping session
-        # looks ok.
-        session_id = m.group(1).decode("ascii")
-        username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
-        self.assertIn(
-            session_id, username_mapping_sessions, "session id not found in map"
-        )
-        session = username_mapping_sessions[session_id]
-        self.assertEqual(session.remote_user_id, "tester")
-        self.assertEqual(session.display_name, "Jonny")
-        self.assertEqual(session.client_redirect_url, client_redirect_url)
-
-        # the expiry time should be about 15 minutes away
-        expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
-        self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
-
-        # Now, submit a username to the username picker, which should serve a redirect
-        # back to the client
-        submit_path = f.value.location + b"/submit"
-        content = urlencode({b"username": b"bobby"}).encode("utf8")
-        chan = self.make_request(
-            "POST",
-            path=submit_path,
-            content=content,
-            content_is_form=True,
-            custom_headers=[
-                ("Cookie", cookieheader),
-                # old versions of twisted don't do form-parsing without a valid
-                # content-length header.
-                ("Content-Length", str(len(content))),
-            ],
-        )
-        self.assertEqual(chan.code, 302, chan.result)
-        location_headers = chan.headers.getRawHeaders("Location")
-        # ensure that the returned location starts with the requested redirect URL
-        self.assertEqual(
-            location_headers[0][: len(client_redirect_url)], client_redirect_url
-        )
-
-        # fish the login token out of the returned redirect uri
-        parts = urlparse(location_headers[0])
-        query = parse_qs(parts.query)
-        login_token = query["loginToken"][0]
-
-        # finally, submit the matrix login token to the login API, which gives us our
-        # matrix access token, mxid, and device id.
-        chan = self.make_request(
-            "POST", "/login", content={"type": "m.login.token", "token": login_token},
-        )
-        self.assertEqual(chan.code, 200, chan.result)
-        self.assertEqual(chan.json_body["user_id"], "@bobby:test")
-
-
 async def _make_callback_with_userinfo(
     hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
 ) -> None:
@@ -990,7 +875,7 @@ async def _make_callback_with_userinfo(
     session = handler._token_generator.generate_oidc_session_token(
         state=state,
         session_data=OidcSessionData(
-            nonce="nonce", client_redirect_url=client_redirect_url,
+            idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url,
         ),
     )
     request = _build_callback_request("code", state, session)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index f9b8011961..73a009efd1 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -17,6 +17,7 @@ import time
 import urllib.parse
 from html.parser import HTMLParser
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from urllib.parse import parse_qs, urlencode, urlparse
 
 from mock import Mock
 
@@ -30,13 +31,14 @@ from synapse.rest.client.v1 import login, logout
 from synapse.rest.client.v2_alpha import devices, register
 from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
 from synapse.rest.synapse.client.pick_idp import PickIdpResource
+from synapse.rest.synapse.client.pick_username import pick_username_resource
 from synapse.types import create_requester
 
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
 from tests.handlers.test_saml import has_saml2
 from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
-from tests.unittest import override_config, skip_unless
+from tests.unittest import HomeserverTestCase, override_config, skip_unless
 
 try:
     import jwt
@@ -1060,3 +1062,104 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
         self.assertEquals(channel.result["code"], b"401", channel.result)
+
+
+@skip_unless(HAS_OIDC, "requires OIDC")
+class UsernamePickerTestCase(HomeserverTestCase):
+    """Tests for the username picker flow of SSO login"""
+
+    servlets = [login.register_servlets]
+
+    def default_config(self):
+        config = super().default_config()
+        config["public_baseurl"] = BASE_URL
+
+        config["oidc_config"] = {}
+        config["oidc_config"].update(TEST_OIDC_CONFIG)
+        config["oidc_config"]["user_mapping_provider"] = {
+            "config": {"display_name_template": "{{ user.displayname }}"}
+        }
+
+        # whitelist this client URI so we redirect straight to it rather than
+        # serving a confirmation page
+        config["sso"] = {"client_whitelist": ["https://whitelisted.client"]}
+        return config
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        from synapse.rest.oidc import OIDCResource
+
+        d = super().create_resource_dict()
+        d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
+        d["/_synapse/oidc"] = OIDCResource(self.hs)
+        return d
+
+    def test_username_picker(self):
+        """Test the happy path of a username picker flow."""
+        client_redirect_url = "https://whitelisted.client"
+
+        # do the start of the login flow
+        channel = self.helper.auth_via_oidc(
+            {"sub": "tester", "displayname": "Jonny"}, client_redirect_url
+        )
+
+        # that should redirect to the username picker
+        self.assertEqual(channel.code, 302, channel.result)
+        picker_url = channel.headers.getRawHeaders("Location")[0]
+        self.assertEqual(picker_url, "/_synapse/client/pick_username")
+
+        # ... with a username_mapping_session cookie
+        cookies = {}  # type: Dict[str,str]
+        channel.extract_cookies(cookies)
+        self.assertIn("username_mapping_session", cookies)
+        session_id = cookies["username_mapping_session"]
+
+        # introspect the sso handler a bit to check that the username mapping session
+        # looks ok.
+        username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
+        self.assertIn(
+            session_id, username_mapping_sessions, "session id not found in map",
+        )
+        session = username_mapping_sessions[session_id]
+        self.assertEqual(session.remote_user_id, "tester")
+        self.assertEqual(session.display_name, "Jonny")
+        self.assertEqual(session.client_redirect_url, client_redirect_url)
+
+        # the expiry time should be about 15 minutes away
+        expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
+        self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
+
+        # Now, submit a username to the username picker, which should serve a redirect
+        # back to the client
+        submit_path = picker_url + "/submit"
+        content = urlencode({b"username": b"bobby"}).encode("utf8")
+        chan = self.make_request(
+            "POST",
+            path=submit_path,
+            content=content,
+            content_is_form=True,
+            custom_headers=[
+                ("Cookie", "username_mapping_session=" + session_id),
+                # old versions of twisted don't do form-parsing without a valid
+                # content-length header.
+                ("Content-Length", str(len(content))),
+            ],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+        # ensure that the returned location starts with the requested redirect URL
+        self.assertEqual(
+            location_headers[0][: len(client_redirect_url)], client_redirect_url
+        )
+
+        # fish the login token out of the returned redirect uri
+        parts = urlparse(location_headers[0])
+        query = parse_qs(parts.query)
+        login_token = query["loginToken"][0]
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token, mxid, and device id.
+        chan = self.make_request(
+            "POST", "/login", content={"type": "m.login.token", "token": login_token},
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.json_body["user_id"], "@bobby:test")
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 85d1709ead..c6647dbe08 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -363,10 +363,10 @@ class RestHelper:
         the normal places.
         """
         client_redirect_url = "https://x"
-        channel = self.auth_via_oidc(remote_user_id, client_redirect_url)
+        channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
 
         # expect a confirmation page
-        assert channel.code == 200
+        assert channel.code == 200, channel.result
 
         # fish the matrix login token out of the body of the confirmation page
         m = re.search(
@@ -390,7 +390,7 @@ class RestHelper:
 
     def auth_via_oidc(
         self,
-        remote_user_id: str,
+        user_info_dict: JsonDict,
         client_redirect_url: Optional[str] = None,
         ui_auth_session_id: Optional[str] = None,
     ) -> FakeChannel:
@@ -411,7 +411,8 @@ class RestHelper:
         the normal places.
 
         Args:
-            remote_user_id: the remote id that the OIDC provider should present
+            user_info_dict: the remote userinfo that the OIDC provider should present.
+                Typically this should be '{"sub": "<remote user id>"}'.
             client_redirect_url: for a login flow, the client redirect URL to pass to
                 the login redirect endpoint
             ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
@@ -457,7 +458,7 @@ class RestHelper:
             # a dummy OIDC access token
             ("https://issuer.test/token", {"access_token": "TEST"}),
             # and then one to the user_info endpoint, which returns our remote user id.
-            ("https://issuer.test/userinfo", {"sub": remote_user_id}),
+            ("https://issuer.test/userinfo", user_info_dict),
         ]
 
         async def mock_req(method: str, uri: str, data=None, headers=None):
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 5f6ca23b06..3e8661f9b9 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -411,7 +411,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         # run the UIA-via-SSO flow
         session_id = channel.json_body["session"]
         channel = self.helper.auth_via_oidc(
-            remote_user_id=remote_user_id, ui_auth_session_id=session_id
+            {"sub": remote_user_id}, ui_auth_session_id=session_id
         )
 
         # that should serve a confirmation page
@@ -457,3 +457,30 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.assertIn({"stages": ["m.login.password"]}, flows)
         self.assertIn({"stages": ["m.login.sso"]}, flows)
         self.assertEqual(len(flows), 2)
+
+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config({"oidc_config": TEST_OIDC_CONFIG})
+    def test_ui_auth_fails_for_incorrect_sso_user(self):
+        """If the user tries to authenticate with the wrong SSO user, they get an error
+        """
+        # log the user in
+        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        self.assertEqual(login_resp["user_id"], self.user)
+
+        # start a UI Auth flow by attempting to delete a device
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+        flows = channel.json_body["flows"]
+        self.assertIn({"stages": ["m.login.sso"]}, flows)
+        session_id = channel.json_body["session"]
+
+        # do the OIDC auth, but auth as the wrong user
+        channel = self.helper.auth_via_oidc("wrong_user", ui_auth_session_id=session_id)
+
+        # that should return a failure message
+        self.assertSubstring("We were unable to validate", channel.text_body)
+
+        # ... and the delete op should now fail with a 403
+        self.delete_device(
+            self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}}
+        )
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 1184cea5a3..522c8061f9 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -56,6 +56,14 @@ class SortTopologically(TestCase):
         graph = {}  # type: Dict[int, List[int]]
         self.assertEqual(list(sorted_topologically([], graph)), [])
 
+    def test_handle_empty_graph(self):
+        "Test that a graph where a node doesn't have an entry is treated as empty"
+
+        graph = {}  # type: Dict[int, List[int]]
+
+        # For disconnected nodes the output is simply sorted.
+        self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
     def test_disconnected(self):
         "Test that a graph with no edges work"