diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 4d9c4e5834..61607cf2ba 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -39,6 +39,7 @@ import attr
import bcrypt
import pymacaroons
import unpaddedbase64
+from pymacaroons.exceptions import MacaroonVerificationFailedException
from twisted.web.server import Request
@@ -182,8 +183,11 @@ class LoginTokenAttributes:
user_id = attr.ib(type=str)
- # the SSO Identity Provider that the user authenticated with, to get this token
auth_provider_id = attr.ib(type=str)
+ """The SSO Identity Provider that the user authenticated with, to get this token."""
+
+ auth_provider_session_id = attr.ib(type=Optional[str])
+ """The session ID advertised by the SSO Identity Provider."""
class AuthHandler:
@@ -1650,6 +1654,7 @@ class AuthHandler:
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
+ auth_provider_session_id: Optional[str] = None,
) -> None:
"""Having figured out a mxid for this user, complete the HTTP request
@@ -1665,6 +1670,7 @@ class AuthHandler:
during successful login. Must be JSON serializable.
new_user: True if we should use wording appropriate to a user who has just
registered.
+ auth_provider_session_id: The session ID from the SSO IdP received during login.
"""
# If the account has been deactivated, do not proceed with the login
# flow.
@@ -1685,6 +1691,7 @@ class AuthHandler:
extra_attributes,
new_user=new_user,
user_profile_data=profile,
+ auth_provider_session_id=auth_provider_session_id,
)
def _complete_sso_login(
@@ -1696,6 +1703,7 @@ class AuthHandler:
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
user_profile_data: Optional[ProfileInfo] = None,
+ auth_provider_session_id: Optional[str] = None,
) -> None:
"""
The synchronous portion of complete_sso_login.
@@ -1717,7 +1725,9 @@ class AuthHandler:
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
- registered_user_id, auth_provider_id=auth_provider_id
+ registered_user_id,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
)
# Append the login token to the original redirect URL (i.e. with its query
@@ -1822,6 +1832,7 @@ class MacaroonGenerator:
self,
user_id: str,
auth_provider_id: str,
+ auth_provider_session_id: Optional[str] = None,
duration_in_ms: int = (2 * 60 * 1000),
) -> str:
macaroon = self._generate_base_macaroon(user_id)
@@ -1830,6 +1841,10 @@ class MacaroonGenerator:
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
+ if auth_provider_session_id is not None:
+ macaroon.add_first_party_caveat(
+ "auth_provider_session_id = %s" % (auth_provider_session_id,)
+ )
return macaroon.serialize()
def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
@@ -1851,15 +1866,28 @@ class MacaroonGenerator:
user_id = get_value_from_macaroon(macaroon, "user_id")
auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
+ auth_provider_session_id: Optional[str] = None
+ try:
+ auth_provider_session_id = get_value_from_macaroon(
+ macaroon, "auth_provider_session_id"
+ )
+ except MacaroonVerificationFailedException:
+ pass
+
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = login")
v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+ v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
satisfy_expiry(v, self.hs.get_clock().time_msec)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
- return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+ return LoginTokenAttributes(
+ user_id=user_id,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
+ )
def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 68b446eb66..82ee11e921 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -301,6 +301,8 @@ class DeviceHandler(DeviceWorkerHandler):
user_id: str,
device_id: Optional[str],
initial_device_display_name: Optional[str] = None,
+ auth_provider_id: Optional[str] = None,
+ auth_provider_session_id: Optional[str] = None,
) -> str:
"""
If the given device has not been registered, register it with the
@@ -312,6 +314,8 @@ class DeviceHandler(DeviceWorkerHandler):
user_id: @user:id
device_id: device id supplied by client
initial_device_display_name: device display name from client
+ auth_provider_id: The SSO IdP the user used, if any.
+ auth_provider_session_id: The session ID (sid) got from the SSO IdP.
Returns:
device id (generated if none was supplied)
"""
@@ -323,6 +327,8 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
)
if new_device:
await self.notify_device_update(user_id, [device_id])
@@ -337,6 +343,8 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id,
device_id=new_device_id,
initial_device_display_name=initial_device_display_name,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
)
if new_device:
await self.notify_device_update(user_id, [new_device_id])
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 3665d91513..deb3539751 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -23,7 +23,7 @@ from authlib.common.security import generate_token
from authlib.jose import JsonWebToken, jwt
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
-from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
+from authlib.oidc.core import CodeIDToken, UserInfo
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
from pymacaroons.exceptions import (
@@ -117,7 +117,8 @@ class OidcHandler:
for idp_id, p in self._providers.items():
try:
await p.load_metadata()
- await p.load_jwks()
+ if not p._uses_userinfo:
+ await p.load_jwks()
except Exception as e:
raise Exception(
"Error while initialising OIDC provider %r" % (idp_id,)
@@ -498,10 +499,6 @@ class OidcProvider:
return await self._jwks.get()
async def _load_jwks(self) -> JWKS:
- if self._uses_userinfo:
- # We're not using jwt signing, return an empty jwk set
- return {"keys": []}
-
metadata = await self.load_metadata()
# Load the JWKS using the `jwks_uri` metadata.
@@ -663,7 +660,7 @@ class OidcProvider:
return UserInfo(resp)
- async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
+ async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
"""Return an instance of UserInfo from token's ``id_token``.
Args:
@@ -673,7 +670,7 @@ class OidcProvider:
request. This value should match the one inside the token.
Returns:
- An object representing the user.
+ The decoded claims in the ID token.
"""
metadata = await self.load_metadata()
claims_params = {
@@ -684,9 +681,6 @@ class OidcProvider:
# If we got an `access_token`, there should be an `at_hash` claim
# in the `id_token` that we can check against.
claims_params["access_token"] = token["access_token"]
- claims_cls = CodeIDToken
- else:
- claims_cls = ImplicitIDToken
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
jwt = JsonWebToken(alg_values)
@@ -703,7 +697,7 @@ class OidcProvider:
claims = jwt.decode(
id_token,
key=jwk_set,
- claims_cls=claims_cls,
+ claims_cls=CodeIDToken,
claims_options=claim_options,
claims_params=claims_params,
)
@@ -713,7 +707,7 @@ class OidcProvider:
claims = jwt.decode(
id_token,
key=jwk_set,
- claims_cls=claims_cls,
+ claims_cls=CodeIDToken,
claims_options=claim_options,
claims_params=claims_params,
)
@@ -721,7 +715,8 @@ class OidcProvider:
logger.debug("Decoded id_token JWT %r; validating", claims)
claims.validate(leeway=120) # allows 2 min of clock skew
- return UserInfo(claims)
+
+ return claims
async def handle_redirect_request(
self,
@@ -837,8 +832,22 @@ class OidcProvider:
logger.debug("Successfully obtained OAuth2 token data: %r", token)
- # Now that we have a token, get the userinfo, either by decoding the
- # `id_token` or by fetching the `userinfo_endpoint`.
+ # If there is an id_token, it should be validated, regardless of the
+ # userinfo endpoint is used or not.
+ if token.get("id_token") is not None:
+ try:
+ id_token = await self._parse_id_token(token, nonce=session_data.nonce)
+ sid = id_token.get("sid")
+ except Exception as e:
+ logger.exception("Invalid id_token")
+ self._sso_handler.render_error(request, "invalid_token", str(e))
+ return
+ else:
+ id_token = None
+ sid = None
+
+ # Now that we have a token, get the userinfo either from the `id_token`
+ # claims or by fetching the `userinfo_endpoint`.
if self._uses_userinfo:
try:
userinfo = await self._fetch_userinfo(token)
@@ -846,13 +855,14 @@ class OidcProvider:
logger.exception("Could not fetch userinfo")
self._sso_handler.render_error(request, "fetch_error", str(e))
return
+ elif id_token is not None:
+ userinfo = UserInfo(id_token)
else:
- try:
- userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
- except Exception as e:
- logger.exception("Invalid id_token")
- self._sso_handler.render_error(request, "invalid_token", str(e))
- return
+ logger.error("Missing id_token in token response")
+ self._sso_handler.render_error(
+ request, "invalid_token", "Missing id_token in token response"
+ )
+ return
# first check if we're doing a UIA
if session_data.ui_auth_session_id:
@@ -884,7 +894,7 @@ class OidcProvider:
# Call the mapper to register/login the user
try:
await self._complete_oidc_login(
- userinfo, token, request, session_data.client_redirect_url
+ userinfo, token, request, session_data.client_redirect_url, sid
)
except MappingException as e:
logger.exception("Could not map user")
@@ -896,6 +906,7 @@ class OidcProvider:
token: Token,
request: SynapseRequest,
client_redirect_url: str,
+ sid: Optional[str],
) -> None:
"""Given a UserInfo response, complete the login flow
@@ -1008,6 +1019,7 @@ class OidcProvider:
oidc_response_to_user_attributes,
grandfather_existing_users,
extra_attributes,
+ auth_provider_session_id=sid,
)
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index b14ddd8267..f08a516a75 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -746,6 +746,7 @@ class RegistrationHandler:
is_appservice_ghost: bool = False,
auth_provider_id: Optional[str] = None,
should_issue_refresh_token: bool = False,
+ auth_provider_session_id: Optional[str] = None,
) -> Tuple[str, str, Optional[int], Optional[str]]:
"""Register a device for a user and generate an access token.
@@ -756,9 +757,9 @@ class RegistrationHandler:
device_id: The device ID to check, or None to generate a new one.
initial_display_name: An optional display name for the device.
is_guest: Whether this is a guest account
- auth_provider_id: The SSO IdP the user used, if any (just used for the
- prometheus metrics).
+ auth_provider_id: The SSO IdP the user used, if any.
should_issue_refresh_token: Whether it should also issue a refresh token
+ auth_provider_session_id: The session ID received during login from the SSO IdP.
Returns:
Tuple of device ID, access token, access token expiration time and refresh token
"""
@@ -769,6 +770,8 @@ class RegistrationHandler:
is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
should_issue_refresh_token=should_issue_refresh_token,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
)
login_counter.labels(
@@ -791,6 +794,8 @@ class RegistrationHandler:
is_guest: bool = False,
is_appservice_ghost: bool = False,
should_issue_refresh_token: bool = False,
+ auth_provider_id: Optional[str] = None,
+ auth_provider_session_id: Optional[str] = None,
) -> LoginDict:
"""Helper for register_device
@@ -822,7 +827,11 @@ class RegistrationHandler:
refresh_token_id = None
registered_device_id = await self.device_handler.check_device_registered(
- user_id, device_id, initial_display_name
+ user_id,
+ device_id,
+ initial_display_name,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
)
if is_guest:
assert access_token_expiry is None
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 49fde01cf0..65c27bc64a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -365,6 +365,7 @@ class SsoHandler:
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
extra_login_attributes: Optional[JsonDict] = None,
+ auth_provider_session_id: Optional[str] = None,
) -> None:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
@@ -415,6 +416,8 @@ class SsoHandler:
extra_login_attributes: An optional dictionary of extra
attributes to be provided to the client in the login response.
+ auth_provider_session_id: An optional session ID from the IdP.
+
Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: if the mapping provider needs to redirect the user
@@ -490,6 +493,7 @@ class SsoHandler:
client_redirect_url,
extra_login_attributes,
new_user=new_user,
+ auth_provider_session_id=auth_provider_session_id,
)
async def _call_attribute_mapper(
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index a8154168be..6bfb4b8d1b 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -626,6 +626,7 @@ class ModuleApi:
user_id: str,
duration_in_ms: int = (2 * 60 * 1000),
auth_provider_id: str = "",
+ auth_provider_session_id: Optional[str] = None,
) -> str:
"""Generate a login token suitable for m.login.token authentication
@@ -643,6 +644,7 @@ class ModuleApi:
return self._hs.get_macaroon_generator().generate_short_term_login_token(
user_id,
auth_provider_id,
+ auth_provider_session_id,
duration_in_ms,
)
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 0db419ea57..daacc34cea 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -46,6 +46,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest,
is_appservice_ghost,
should_issue_refresh_token,
+ auth_provider_id,
+ auth_provider_session_id,
):
"""
Args:
@@ -63,6 +65,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"is_guest": is_guest,
"is_appservice_ghost": is_appservice_ghost,
"should_issue_refresh_token": should_issue_refresh_token,
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
}
async def _handle_request(self, request, user_id):
@@ -73,6 +77,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest = content["is_guest"]
is_appservice_ghost = content["is_appservice_ghost"]
should_issue_refresh_token = content["should_issue_refresh_token"]
+ auth_provider_id = content["auth_provider_id"]
+ auth_provider_session_id = content["auth_provider_session_id"]
res = await self.registration_handler.register_device_inner(
user_id,
@@ -81,6 +87,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest,
is_appservice_ghost=is_appservice_ghost,
should_issue_refresh_token=should_issue_refresh_token,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
)
return 200, res
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index a66ee4fb3d..1b23fa18cf 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -303,6 +303,7 @@ class LoginRestServlet(RestServlet):
ratelimit: bool = True,
auth_provider_id: Optional[str] = None,
should_issue_refresh_token: bool = False,
+ auth_provider_session_id: Optional[str] = None,
) -> LoginResponse:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -318,10 +319,10 @@ class LoginRestServlet(RestServlet):
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
ratelimit: Whether to ratelimit the login request.
- auth_provider_id: The SSO IdP the user used, if any (just used for the
- prometheus metrics).
+ auth_provider_id: The SSO IdP the user used, if any.
should_issue_refresh_token: True if this login should issue
a refresh token alongside the access token.
+ auth_provider_session_id: The session ID got during login from the SSO IdP.
Returns:
result: Dictionary of account information after successful login.
@@ -354,6 +355,7 @@ class LoginRestServlet(RestServlet):
initial_display_name,
auth_provider_id=auth_provider_id,
should_issue_refresh_token=should_issue_refresh_token,
+ auth_provider_session_id=auth_provider_session_id,
)
result = LoginResponse(
@@ -399,6 +401,7 @@ class LoginRestServlet(RestServlet):
self.auth_handler._sso_login_callback,
auth_provider_id=res.auth_provider_id,
should_issue_refresh_token=should_issue_refresh_token,
+ auth_provider_session_id=res.auth_provider_session_id,
)
async def _do_jwt_login(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9ccc66e589..d5a4a661cd 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore):
return {d["device_id"]: d for d in devices}
+ async def get_devices_by_auth_provider_session_id(
+ self, auth_provider_id: str, auth_provider_session_id: str
+ ) -> List[Dict[str, Any]]:
+ """Retrieve the list of devices associated with a SSO IdP session ID.
+
+ Args:
+ auth_provider_id: The SSO IdP ID as defined in the server config
+ auth_provider_session_id: The session ID within the IdP
+ Returns:
+ A list of dicts containing the device_id and the user_id of each device
+ """
+ return await self.db_pool.simple_select_list(
+ table="device_auth_providers",
+ keyvalues={
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
+ },
+ retcols=("user_id", "device_id"),
+ desc="get_devices_by_auth_provider_session_id",
+ )
+
@trace
async def get_device_updates_by_remote(
self, destination: str, from_stream_id: int, limit: int
@@ -1070,7 +1091,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def store_device(
- self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
+ self,
+ user_id: str,
+ device_id: str,
+ initial_device_display_name: Optional[str],
+ auth_provider_id: Optional[str] = None,
+ auth_provider_session_id: Optional[str] = None,
) -> bool:
"""Ensure the given device is known; add it to the store if not
@@ -1079,6 +1105,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id: id of device
initial_device_display_name: initial displayname of the device.
Ignored if device exists.
+ auth_provider_id: The SSO IdP the user used, if any.
+ auth_provider_session_id: The session ID (sid) got from a OIDC login.
Returns:
Whether the device was inserted or an existing device existed with that ID.
@@ -1115,6 +1143,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if hidden:
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
+ if auth_provider_id and auth_provider_session_id:
+ await self.db_pool.simple_insert(
+ "device_auth_providers",
+ values={
+ "user_id": user_id,
+ "device_id": device_id,
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
+ },
+ desc="store_device_auth_provider",
+ )
+
self.device_id_exists_cache.set(key, True)
return inserted
except StoreError:
@@ -1168,6 +1208,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
keyvalues={"user_id": user_id},
)
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="device_auth_providers",
+ column="device_id",
+ values=device_ids,
+ keyvalues={"user_id": user_id},
+ )
+
await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
diff --git a/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql
new file mode 100644
index 0000000000..a65bfb520d
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql
@@ -0,0 +1,27 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Track the auth provider used by each login as well as the session ID
+CREATE TABLE device_auth_providers (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ auth_provider_id TEXT NOT NULL,
+ auth_provider_session_id TEXT NOT NULL
+);
+
+CREATE INDEX device_auth_providers_devices
+ ON device_auth_providers (user_id, device_id);
+CREATE INDEX device_auth_providers_sessions
+ ON device_auth_providers (auth_provider_id, auth_provider_session_id);
|