summary refs log tree commit diff
diff options
context:
space:
mode:
authorQuentin Gliech <quenting@element.io>2021-12-06 18:43:06 +0100
committerGitHub <noreply@github.com>2021-12-06 12:43:06 -0500
commita15a893df8428395df7cb95b729431575001c38a (patch)
tree7572abf2fa680c942dc882cc05e9062bb63b55b8
parentAdd admin API to get some information about federation status (#11407) (diff)
downloadsynapse-a15a893df8428395df7cb95b729431575001c38a.tar.xz
Save the OIDC session ID (sid) with the device on login (#11482)
As a step towards allowing back-channel logout for OIDC.
-rw-r--r--changelog.d/11482.misc1
-rw-r--r--synapse/handlers/auth.py34
-rw-r--r--synapse/handlers/device.py8
-rw-r--r--synapse/handlers/oidc.py58
-rw-r--r--synapse/handlers/register.py15
-rw-r--r--synapse/handlers/sso.py4
-rw-r--r--synapse/module_api/__init__.py2
-rw-r--r--synapse/replication/http/login.py8
-rw-r--r--synapse/rest/client/login.py7
-rw-r--r--synapse/storage/databases/main/devices.py50
-rw-r--r--synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql27
-rw-r--r--tests/handlers/test_auth.py6
-rw-r--r--tests/handlers/test_cas.py40
-rw-r--r--tests/handlers/test_oidc.py135
-rw-r--r--tests/handlers/test_saml.py40
15 files changed, 370 insertions, 65 deletions
diff --git a/changelog.d/11482.misc b/changelog.d/11482.misc
new file mode 100644
index 0000000000..e78662988f
--- /dev/null
+++ b/changelog.d/11482.misc
@@ -0,0 +1 @@
+Save the OpenID Connect session ID on login.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 4d9c4e5834..61607cf2ba 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -39,6 +39,7 @@ import attr
 import bcrypt
 import pymacaroons
 import unpaddedbase64
+from pymacaroons.exceptions import MacaroonVerificationFailedException
 
 from twisted.web.server import Request
 
@@ -182,8 +183,11 @@ class LoginTokenAttributes:
 
     user_id = attr.ib(type=str)
 
-    # the SSO Identity Provider that the user authenticated with, to get this token
     auth_provider_id = attr.ib(type=str)
+    """The SSO Identity Provider that the user authenticated with, to get this token."""
+
+    auth_provider_session_id = attr.ib(type=Optional[str])
+    """The session ID advertised by the SSO Identity Provider."""
 
 
 class AuthHandler:
@@ -1650,6 +1654,7 @@ class AuthHandler:
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
         new_user: bool = False,
+        auth_provider_session_id: Optional[str] = None,
     ) -> None:
         """Having figured out a mxid for this user, complete the HTTP request
 
@@ -1665,6 +1670,7 @@ class AuthHandler:
                 during successful login. Must be JSON serializable.
             new_user: True if we should use wording appropriate to a user who has just
                 registered.
+            auth_provider_session_id: The session ID from the SSO IdP received during login.
         """
         # If the account has been deactivated, do not proceed with the login
         # flow.
@@ -1685,6 +1691,7 @@ class AuthHandler:
             extra_attributes,
             new_user=new_user,
             user_profile_data=profile,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
     def _complete_sso_login(
@@ -1696,6 +1703,7 @@ class AuthHandler:
         extra_attributes: Optional[JsonDict] = None,
         new_user: bool = False,
         user_profile_data: Optional[ProfileInfo] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> None:
         """
         The synchronous portion of complete_sso_login.
@@ -1717,7 +1725,9 @@ class AuthHandler:
 
         # Create a login token
         login_token = self.macaroon_gen.generate_short_term_login_token(
-            registered_user_id, auth_provider_id=auth_provider_id
+            registered_user_id,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
         # Append the login token to the original redirect URL (i.e. with its query
@@ -1822,6 +1832,7 @@ class MacaroonGenerator:
         self,
         user_id: str,
         auth_provider_id: str,
+        auth_provider_session_id: Optional[str] = None,
         duration_in_ms: int = (2 * 60 * 1000),
     ) -> str:
         macaroon = self._generate_base_macaroon(user_id)
@@ -1830,6 +1841,10 @@ class MacaroonGenerator:
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
         macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
+        if auth_provider_session_id is not None:
+            macaroon.add_first_party_caveat(
+                "auth_provider_session_id = %s" % (auth_provider_session_id,)
+            )
         return macaroon.serialize()
 
     def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
@@ -1851,15 +1866,28 @@ class MacaroonGenerator:
         user_id = get_value_from_macaroon(macaroon, "user_id")
         auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
 
+        auth_provider_session_id: Optional[str] = None
+        try:
+            auth_provider_session_id = get_value_from_macaroon(
+                macaroon, "auth_provider_session_id"
+            )
+        except MacaroonVerificationFailedException:
+            pass
+
         v = pymacaroons.Verifier()
         v.satisfy_exact("gen = 1")
         v.satisfy_exact("type = login")
         v.satisfy_general(lambda c: c.startswith("user_id = "))
         v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+        v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
         satisfy_expiry(v, self.hs.get_clock().time_msec)
         v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
 
-        return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+        return LoginTokenAttributes(
+            user_id=user_id,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
 
     def generate_delete_pusher_token(self, user_id: str) -> str:
         macaroon = self._generate_base_macaroon(user_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 68b446eb66..82ee11e921 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -301,6 +301,8 @@ class DeviceHandler(DeviceWorkerHandler):
         user_id: str,
         device_id: Optional[str],
         initial_device_display_name: Optional[str] = None,
+        auth_provider_id: Optional[str] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> str:
         """
         If the given device has not been registered, register it with the
@@ -312,6 +314,8 @@ class DeviceHandler(DeviceWorkerHandler):
             user_id:  @user:id
             device_id: device id supplied by client
             initial_device_display_name: device display name from client
+            auth_provider_id: The SSO IdP the user used, if any.
+            auth_provider_session_id: The session ID (sid) got from the SSO IdP.
         Returns:
             device id (generated if none was supplied)
         """
@@ -323,6 +327,8 @@ class DeviceHandler(DeviceWorkerHandler):
                 user_id=user_id,
                 device_id=device_id,
                 initial_device_display_name=initial_device_display_name,
+                auth_provider_id=auth_provider_id,
+                auth_provider_session_id=auth_provider_session_id,
             )
             if new_device:
                 await self.notify_device_update(user_id, [device_id])
@@ -337,6 +343,8 @@ class DeviceHandler(DeviceWorkerHandler):
                 user_id=user_id,
                 device_id=new_device_id,
                 initial_device_display_name=initial_device_display_name,
+                auth_provider_id=auth_provider_id,
+                auth_provider_session_id=auth_provider_session_id,
             )
             if new_device:
                 await self.notify_device_update(user_id, [new_device_id])
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 3665d91513..deb3539751 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -23,7 +23,7 @@ from authlib.common.security import generate_token
 from authlib.jose import JsonWebToken, jwt
 from authlib.oauth2.auth import ClientAuth
 from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
-from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
+from authlib.oidc.core import CodeIDToken, UserInfo
 from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
 from jinja2 import Environment, Template
 from pymacaroons.exceptions import (
@@ -117,7 +117,8 @@ class OidcHandler:
         for idp_id, p in self._providers.items():
             try:
                 await p.load_metadata()
-                await p.load_jwks()
+                if not p._uses_userinfo:
+                    await p.load_jwks()
             except Exception as e:
                 raise Exception(
                     "Error while initialising OIDC provider %r" % (idp_id,)
@@ -498,10 +499,6 @@ class OidcProvider:
         return await self._jwks.get()
 
     async def _load_jwks(self) -> JWKS:
-        if self._uses_userinfo:
-            # We're not using jwt signing, return an empty jwk set
-            return {"keys": []}
-
         metadata = await self.load_metadata()
 
         # Load the JWKS using the `jwks_uri` metadata.
@@ -663,7 +660,7 @@ class OidcProvider:
 
         return UserInfo(resp)
 
-    async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
+    async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
         """Return an instance of UserInfo from token's ``id_token``.
 
         Args:
@@ -673,7 +670,7 @@ class OidcProvider:
                 request. This value should match the one inside the token.
 
         Returns:
-            An object representing the user.
+            The decoded claims in the ID token.
         """
         metadata = await self.load_metadata()
         claims_params = {
@@ -684,9 +681,6 @@ class OidcProvider:
             # If we got an `access_token`, there should be an `at_hash` claim
             # in the `id_token` that we can check against.
             claims_params["access_token"] = token["access_token"]
-            claims_cls = CodeIDToken
-        else:
-            claims_cls = ImplicitIDToken
 
         alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
         jwt = JsonWebToken(alg_values)
@@ -703,7 +697,7 @@ class OidcProvider:
             claims = jwt.decode(
                 id_token,
                 key=jwk_set,
-                claims_cls=claims_cls,
+                claims_cls=CodeIDToken,
                 claims_options=claim_options,
                 claims_params=claims_params,
             )
@@ -713,7 +707,7 @@ class OidcProvider:
             claims = jwt.decode(
                 id_token,
                 key=jwk_set,
-                claims_cls=claims_cls,
+                claims_cls=CodeIDToken,
                 claims_options=claim_options,
                 claims_params=claims_params,
             )
@@ -721,7 +715,8 @@ class OidcProvider:
         logger.debug("Decoded id_token JWT %r; validating", claims)
 
         claims.validate(leeway=120)  # allows 2 min of clock skew
-        return UserInfo(claims)
+
+        return claims
 
     async def handle_redirect_request(
         self,
@@ -837,8 +832,22 @@ class OidcProvider:
 
         logger.debug("Successfully obtained OAuth2 token data: %r", token)
 
-        # Now that we have a token, get the userinfo, either by decoding the
-        # `id_token` or by fetching the `userinfo_endpoint`.
+        # If there is an id_token, it should be validated, regardless of the
+        # userinfo endpoint is used or not.
+        if token.get("id_token") is not None:
+            try:
+                id_token = await self._parse_id_token(token, nonce=session_data.nonce)
+                sid = id_token.get("sid")
+            except Exception as e:
+                logger.exception("Invalid id_token")
+                self._sso_handler.render_error(request, "invalid_token", str(e))
+                return
+        else:
+            id_token = None
+            sid = None
+
+        # Now that we have a token, get the userinfo either from the `id_token`
+        # claims or by fetching the `userinfo_endpoint`.
         if self._uses_userinfo:
             try:
                 userinfo = await self._fetch_userinfo(token)
@@ -846,13 +855,14 @@ class OidcProvider:
                 logger.exception("Could not fetch userinfo")
                 self._sso_handler.render_error(request, "fetch_error", str(e))
                 return
+        elif id_token is not None:
+            userinfo = UserInfo(id_token)
         else:
-            try:
-                userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
-            except Exception as e:
-                logger.exception("Invalid id_token")
-                self._sso_handler.render_error(request, "invalid_token", str(e))
-                return
+            logger.error("Missing id_token in token response")
+            self._sso_handler.render_error(
+                request, "invalid_token", "Missing id_token in token response"
+            )
+            return
 
         # first check if we're doing a UIA
         if session_data.ui_auth_session_id:
@@ -884,7 +894,7 @@ class OidcProvider:
         # Call the mapper to register/login the user
         try:
             await self._complete_oidc_login(
-                userinfo, token, request, session_data.client_redirect_url
+                userinfo, token, request, session_data.client_redirect_url, sid
             )
         except MappingException as e:
             logger.exception("Could not map user")
@@ -896,6 +906,7 @@ class OidcProvider:
         token: Token,
         request: SynapseRequest,
         client_redirect_url: str,
+        sid: Optional[str],
     ) -> None:
         """Given a UserInfo response, complete the login flow
 
@@ -1008,6 +1019,7 @@ class OidcProvider:
             oidc_response_to_user_attributes,
             grandfather_existing_users,
             extra_attributes,
+            auth_provider_session_id=sid,
         )
 
     def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index b14ddd8267..f08a516a75 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -746,6 +746,7 @@ class RegistrationHandler:
         is_appservice_ghost: bool = False,
         auth_provider_id: Optional[str] = None,
         should_issue_refresh_token: bool = False,
+        auth_provider_session_id: Optional[str] = None,
     ) -> Tuple[str, str, Optional[int], Optional[str]]:
         """Register a device for a user and generate an access token.
 
@@ -756,9 +757,9 @@ class RegistrationHandler:
             device_id: The device ID to check, or None to generate a new one.
             initial_display_name: An optional display name for the device.
             is_guest: Whether this is a guest account
-            auth_provider_id: The SSO IdP the user used, if any (just used for the
-                prometheus metrics).
+            auth_provider_id: The SSO IdP the user used, if any.
             should_issue_refresh_token: Whether it should also issue a refresh token
+            auth_provider_session_id: The session ID received during login from the SSO IdP.
         Returns:
             Tuple of device ID, access token, access token expiration time and refresh token
         """
@@ -769,6 +770,8 @@ class RegistrationHandler:
             is_guest=is_guest,
             is_appservice_ghost=is_appservice_ghost,
             should_issue_refresh_token=should_issue_refresh_token,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
         login_counter.labels(
@@ -791,6 +794,8 @@ class RegistrationHandler:
         is_guest: bool = False,
         is_appservice_ghost: bool = False,
         should_issue_refresh_token: bool = False,
+        auth_provider_id: Optional[str] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> LoginDict:
         """Helper for register_device
 
@@ -822,7 +827,11 @@ class RegistrationHandler:
         refresh_token_id = None
 
         registered_device_id = await self.device_handler.check_device_registered(
-            user_id, device_id, initial_display_name
+            user_id,
+            device_id,
+            initial_display_name,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
         )
         if is_guest:
             assert access_token_expiry is None
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 49fde01cf0..65c27bc64a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -365,6 +365,7 @@ class SsoHandler:
         sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
         grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
         extra_login_attributes: Optional[JsonDict] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> None:
         """
         Given an SSO ID, retrieve the user ID for it and possibly register the user.
@@ -415,6 +416,8 @@ class SsoHandler:
             extra_login_attributes: An optional dictionary of extra
                 attributes to be provided to the client in the login response.
 
+            auth_provider_session_id: An optional session ID from the IdP.
+
         Raises:
             MappingException if there was a problem mapping the response to a user.
             RedirectException: if the mapping provider needs to redirect the user
@@ -490,6 +493,7 @@ class SsoHandler:
             client_redirect_url,
             extra_login_attributes,
             new_user=new_user,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
     async def _call_attribute_mapper(
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index a8154168be..6bfb4b8d1b 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -626,6 +626,7 @@ class ModuleApi:
         user_id: str,
         duration_in_ms: int = (2 * 60 * 1000),
         auth_provider_id: str = "",
+        auth_provider_session_id: Optional[str] = None,
     ) -> str:
         """Generate a login token suitable for m.login.token authentication
 
@@ -643,6 +644,7 @@ class ModuleApi:
         return self._hs.get_macaroon_generator().generate_short_term_login_token(
             user_id,
             auth_provider_id,
+            auth_provider_session_id,
             duration_in_ms,
         )
 
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 0db419ea57..daacc34cea 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -46,6 +46,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         is_guest,
         is_appservice_ghost,
         should_issue_refresh_token,
+        auth_provider_id,
+        auth_provider_session_id,
     ):
         """
         Args:
@@ -63,6 +65,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
             "is_guest": is_guest,
             "is_appservice_ghost": is_appservice_ghost,
             "should_issue_refresh_token": should_issue_refresh_token,
+            "auth_provider_id": auth_provider_id,
+            "auth_provider_session_id": auth_provider_session_id,
         }
 
     async def _handle_request(self, request, user_id):
@@ -73,6 +77,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         is_guest = content["is_guest"]
         is_appservice_ghost = content["is_appservice_ghost"]
         should_issue_refresh_token = content["should_issue_refresh_token"]
+        auth_provider_id = content["auth_provider_id"]
+        auth_provider_session_id = content["auth_provider_session_id"]
 
         res = await self.registration_handler.register_device_inner(
             user_id,
@@ -81,6 +87,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
             is_guest,
             is_appservice_ghost=is_appservice_ghost,
             should_issue_refresh_token=should_issue_refresh_token,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
         return 200, res
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index a66ee4fb3d..1b23fa18cf 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -303,6 +303,7 @@ class LoginRestServlet(RestServlet):
         ratelimit: bool = True,
         auth_provider_id: Optional[str] = None,
         should_issue_refresh_token: bool = False,
+        auth_provider_session_id: Optional[str] = None,
     ) -> LoginResponse:
         """Called when we've successfully authed the user and now need to
         actually login them in (e.g. create devices). This gets called on
@@ -318,10 +319,10 @@ class LoginRestServlet(RestServlet):
             create_non_existent_users: Whether to create the user if they don't
                 exist. Defaults to False.
             ratelimit: Whether to ratelimit the login request.
-            auth_provider_id: The SSO IdP the user used, if any (just used for the
-                prometheus metrics).
+            auth_provider_id: The SSO IdP the user used, if any.
             should_issue_refresh_token: True if this login should issue
                 a refresh token alongside the access token.
+            auth_provider_session_id: The session ID got during login from the SSO IdP.
 
         Returns:
             result: Dictionary of account information after successful login.
@@ -354,6 +355,7 @@ class LoginRestServlet(RestServlet):
             initial_display_name,
             auth_provider_id=auth_provider_id,
             should_issue_refresh_token=should_issue_refresh_token,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
         result = LoginResponse(
@@ -399,6 +401,7 @@ class LoginRestServlet(RestServlet):
             self.auth_handler._sso_login_callback,
             auth_provider_id=res.auth_provider_id,
             should_issue_refresh_token=should_issue_refresh_token,
+            auth_provider_session_id=res.auth_provider_session_id,
         )
 
     async def _do_jwt_login(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9ccc66e589..d5a4a661cd 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return {d["device_id"]: d for d in devices}
 
+    async def get_devices_by_auth_provider_session_id(
+        self, auth_provider_id: str, auth_provider_session_id: str
+    ) -> List[Dict[str, Any]]:
+        """Retrieve the list of devices associated with a SSO IdP session ID.
+
+        Args:
+            auth_provider_id: The SSO IdP ID as defined in the server config
+            auth_provider_session_id: The session ID within the IdP
+        Returns:
+            A list of dicts containing the device_id and the user_id of each device
+        """
+        return await self.db_pool.simple_select_list(
+            table="device_auth_providers",
+            keyvalues={
+                "auth_provider_id": auth_provider_id,
+                "auth_provider_session_id": auth_provider_session_id,
+            },
+            retcols=("user_id", "device_id"),
+            desc="get_devices_by_auth_provider_session_id",
+        )
+
     @trace
     async def get_device_updates_by_remote(
         self, destination: str, from_stream_id: int, limit: int
@@ -1070,7 +1091,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     async def store_device(
-        self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
+        self,
+        user_id: str,
+        device_id: str,
+        initial_device_display_name: Optional[str],
+        auth_provider_id: Optional[str] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> bool:
         """Ensure the given device is known; add it to the store if not
 
@@ -1079,6 +1105,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             device_id: id of device
             initial_device_display_name: initial displayname of the device.
                 Ignored if device exists.
+            auth_provider_id: The SSO IdP the user used, if any.
+            auth_provider_session_id: The session ID (sid) got from a OIDC login.
 
         Returns:
             Whether the device was inserted or an existing device existed with that ID.
@@ -1115,6 +1143,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 if hidden:
                     raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
 
+            if auth_provider_id and auth_provider_session_id:
+                await self.db_pool.simple_insert(
+                    "device_auth_providers",
+                    values={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "auth_provider_id": auth_provider_id,
+                        "auth_provider_session_id": auth_provider_session_id,
+                    },
+                    desc="store_device_auth_provider",
+                )
+
             self.device_id_exists_cache.set(key, True)
             return inserted
         except StoreError:
@@ -1168,6 +1208,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 keyvalues={"user_id": user_id},
             )
 
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="device_auth_providers",
+                column="device_id",
+                values=device_ids,
+                keyvalues={"user_id": user_id},
+            )
+
         await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
         for device_id in device_ids:
             self.device_id_exists_cache.invalidate((user_id, device_id))
diff --git a/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql
new file mode 100644
index 0000000000..a65bfb520d
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql
@@ -0,0 +1,27 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Track the auth provider used by each login as well as the session ID
+CREATE TABLE device_auth_providers (
+  user_id TEXT NOT NULL,
+  device_id TEXT NOT NULL,
+  auth_provider_id TEXT NOT NULL,
+  auth_provider_session_id TEXT NOT NULL
+);
+
+CREATE INDEX device_auth_providers_devices
+  ON device_auth_providers (user_id, device_id);
+CREATE INDEX device_auth_providers_sessions
+  ON device_auth_providers (auth_provider_id, auth_provider_session_id);
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 72e176da75..03b8b8615c 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -71,7 +71,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def test_short_term_login_token_gives_user_id(self):
         token = self.macaroon_generator.generate_short_term_login_token(
-            self.user1, "", 5000
+            self.user1, "", duration_in_ms=5000
         )
         res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
         self.assertEqual(self.user1, res.user_id)
@@ -94,7 +94,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def test_short_term_login_token_cannot_replace_user_id(self):
         token = self.macaroon_generator.generate_short_term_login_token(
-            self.user1, "", 5000
+            self.user1, "", duration_in_ms=5000
         )
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
@@ -213,6 +213,6 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def _get_macaroon(self):
         token = self.macaroon_generator.generate_short_term_login_token(
-            self.user1, "", 5000
+            self.user1, "", duration_in_ms=5000
         )
         return pymacaroons.Macaroon.deserialize(token)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index b625995d12..8705ff8943 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -66,7 +66,13 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
+            "@test_user:test",
+            "cas",
+            request,
+            "redirect_uri",
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
 
     def test_map_cas_user_to_existing_user(self):
@@ -89,7 +95,13 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
+            "@test_user:test",
+            "cas",
+            request,
+            "redirect_uri",
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
 
         # Subsequent calls should map to the same mxid.
@@ -98,7 +110,13 @@ class CasHandlerTestCase(HomeserverTestCase):
             self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
+            "@test_user:test",
+            "cas",
+            request,
+            "redirect_uri",
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
 
     def test_map_cas_user_to_invalid_localpart(self):
@@ -116,7 +134,13 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
+            "@f=c3=b6=c3=b6:test",
+            "cas",
+            request,
+            "redirect_uri",
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
 
     @override_config(
@@ -160,7 +184,13 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
+            "@test_user:test",
+            "cas",
+            request,
+            "redirect_uri",
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
 
 
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a25c89bd5b..cfe3de5266 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -252,13 +252,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
         with patch.object(self.provider, "load_metadata", patched_load_metadata):
             self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
-        # Return empty key set if JWKS are not used
-        self.provider._scopes = []  # not asking the openid scope
-        self.http_client.get_json.reset_mock()
-        jwks = self.get_success(self.provider.load_jwks(force=True))
-        self.http_client.get_json.assert_not_called()
-        self.assertEqual(jwks, {"keys": []})
-
     @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_validate_config(self):
         """Provider metadatas are extensively validated."""
@@ -455,7 +448,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
+            expected_user_id,
+            "oidc",
+            request,
+            client_redirect_url,
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -482,17 +481,58 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.provider._fetch_userinfo.reset_mock()
 
         # With userinfo fetching
-        self.provider._scopes = []  # do not ask the "openid" scope
+        self.provider._user_profile_method = "userinfo_endpoint"
+        token = {
+            "type": "bearer",
+            "access_token": "access_token",
+        }
+        self.provider._exchange_code = simple_async_mock(return_value=token)
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
+            expected_user_id,
+            "oidc",
+            request,
+            client_redirect_url,
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_not_called()
         self.provider._fetch_userinfo.assert_called_once_with(token)
         self.render_error.assert_not_called()
 
+        # With an ID token, userinfo fetching and sid in the ID token
+        self.provider._user_profile_method = "userinfo_endpoint"
+        token = {
+            "type": "bearer",
+            "access_token": "access_token",
+            "id_token": "id_token",
+        }
+        id_token = {
+            "sid": "abcdefgh",
+        }
+        self.provider._parse_id_token = simple_async_mock(return_value=id_token)
+        self.provider._exchange_code = simple_async_mock(return_value=token)
+        auth_handler.complete_sso_login.reset_mock()
+        self.provider._fetch_userinfo.reset_mock()
+        self.get_success(self.handler.handle_oidc_callback(request))
+
+        auth_handler.complete_sso_login.assert_called_once_with(
+            expected_user_id,
+            "oidc",
+            request,
+            client_redirect_url,
+            None,
+            new_user=False,
+            auth_provider_session_id=id_token["sid"],
+        )
+        self.provider._exchange_code.assert_called_once_with(code)
+        self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
+        self.provider._fetch_userinfo.assert_called_once_with(token)
+        self.render_error.assert_not_called()
+
         # Handle userinfo fetching error
         self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
         self.get_success(self.handler.handle_oidc_callback(request))
@@ -776,6 +816,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             client_redirect_url,
             {"phone": "1234567"},
             new_user=True,
+            auth_provider_session_id=None,
         )
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
@@ -790,7 +831,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "oidc", ANY, ANY, None, new_user=True
+            "@test_user:test",
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -801,7 +848,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
+            "@test_user_2:test",
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -838,14 +891,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), "oidc", ANY, ANY, None, new_user=False
+            user.to_string(),
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), "oidc", ANY, ANY, None, new_user=False
+            user.to_string(),
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -860,7 +925,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), "oidc", ANY, ANY, None, new_user=False
+            user.to_string(),
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -896,7 +967,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
+            "@TEST_USER_2:test",
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
@@ -934,7 +1011,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", "oidc", ANY, ANY, None, new_user=True
+            "@test_user1:test",
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -1018,7 +1101,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@tester:test", "oidc", ANY, ANY, None, new_user=True
+            "@tester:test",
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
 
     @override_config(
@@ -1043,7 +1132,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@tester:test", "oidc", ANY, ANY, None, new_user=True
+            "@tester:test",
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
 
     @override_config(
@@ -1156,7 +1251,7 @@ async def _make_callback_with_userinfo(
 
     handler = hs.get_oidc_handler()
     provider = handler._providers["oidc"]
-    provider._exchange_code = simple_async_mock(return_value={})
+    provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
     provider._parse_id_token = simple_async_mock(return_value=userinfo)
     provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
 
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 8cfc184fef..50551aa6e3 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -130,7 +130,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
+            "@test_user:test",
+            "saml",
+            request,
+            "redirect_uri",
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -156,7 +162,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "saml", request, "", None, new_user=False
+            "@test_user:test",
+            "saml",
+            request,
+            "",
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
 
         # Subsequent calls should map to the same mxid.
@@ -165,7 +177,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
             self.handler._handle_authn_response(request, saml_response, "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "saml", request, "", None, new_user=False
+            "@test_user:test",
+            "saml",
+            request,
+            "",
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
 
     def test_map_saml_response_to_invalid_localpart(self):
@@ -213,7 +231,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", "saml", request, "", None, new_user=True
+            "@test_user1:test",
+            "saml",
+            request,
+            "",
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -309,7 +333,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
+            "@test_user:test",
+            "saml",
+            request,
+            "redirect_uri",
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )