diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 341135822e..b1a5df9638 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,14 +13,157 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import random
from typing import TYPE_CHECKING, List, Tuple
+from synapse.replication.http.account_data import (
+ ReplicationAddTagRestServlet,
+ ReplicationRemoveTagRestServlet,
+ ReplicationRoomAccountDataRestServlet,
+ ReplicationUserAccountDataRestServlet,
+)
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
+class AccountDataHandler:
+ def __init__(self, hs: "HomeServer"):
+ self._store = hs.get_datastore()
+ self._instance_name = hs.get_instance_name()
+ self._notifier = hs.get_notifier()
+
+ self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs)
+ self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs)
+ self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
+ self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
+ self._account_data_writers = hs.config.worker.writers.account_data
+
+ async def add_account_data_to_room(
+ self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
+ ) -> int:
+ """Add some account_data to a room for a user.
+
+ Args:
+ user_id: The user to add a tag for.
+ room_id: The room to add a tag for.
+ account_data_type: The type of account_data to add.
+ content: A json object to associate with the tag.
+
+ Returns:
+ The maximum stream ID.
+ """
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.add_account_data_to_room(
+ user_id, room_id, account_data_type, content
+ )
+
+ self._notifier.on_new_event(
+ "account_data_key", max_stream_id, users=[user_id]
+ )
+
+ return max_stream_id
+ else:
+ response = await self._room_data_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ room_id=room_id,
+ account_data_type=account_data_type,
+ content=content,
+ )
+ return response["max_stream_id"]
+
+ async def add_account_data_for_user(
+ self, user_id: str, account_data_type: str, content: JsonDict
+ ) -> int:
+ """Add some account_data to a room for a user.
+
+ Args:
+ user_id: The user to add a tag for.
+ account_data_type: The type of account_data to add.
+ content: A json object to associate with the tag.
+
+ Returns:
+ The maximum stream ID.
+ """
+
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.add_account_data_for_user(
+ user_id, account_data_type, content
+ )
+
+ self._notifier.on_new_event(
+ "account_data_key", max_stream_id, users=[user_id]
+ )
+ return max_stream_id
+ else:
+ response = await self._user_data_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ account_data_type=account_data_type,
+ content=content,
+ )
+ return response["max_stream_id"]
+
+ async def add_tag_to_room(
+ self, user_id: str, room_id: str, tag: str, content: JsonDict
+ ) -> int:
+ """Add a tag to a room for a user.
+
+ Args:
+ user_id: The user to add a tag for.
+ room_id: The room to add a tag for.
+ tag: The tag name to add.
+ content: A json object to associate with the tag.
+
+ Returns:
+ The next account data ID.
+ """
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.add_tag_to_room(
+ user_id, room_id, tag, content
+ )
+
+ self._notifier.on_new_event(
+ "account_data_key", max_stream_id, users=[user_id]
+ )
+ return max_stream_id
+ else:
+ response = await self._add_tag_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ room_id=room_id,
+ tag=tag,
+ content=content,
+ )
+ return response["max_stream_id"]
+
+ async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
+ """Remove a tag from a room for a user.
+
+ Returns:
+ The next account data ID.
+ """
+ if self._instance_name in self._account_data_writers:
+ max_stream_id = await self._store.remove_tag_from_room(
+ user_id, room_id, tag
+ )
+
+ self._notifier.on_new_event(
+ "account_data_key", max_stream_id, users=[user_id]
+ )
+ return max_stream_id
+ else:
+ response = await self._remove_tag_client(
+ instance_name=random.choice(self._account_data_writers),
+ user_id=user_id,
+ room_id=room_id,
+ tag=tag,
+ )
+ return response["max_stream_id"]
+
+
class AccountDataEventSource:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index f4434673dc..0e98db22b3 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -49,8 +49,13 @@ from synapse.api.errors import (
UserDeactivatedError,
)
from synapse.api.ratelimiting import Ratelimiter
-from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.ui_auth import (
+ INTERACTIVE_AUTH_CHECKERS,
+ UIAuthSessionDataConstants,
+)
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
+from synapse.http import get_request_user_agent
from synapse.http.server import finish_request, respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
@@ -62,8 +67,6 @@ from synapse.util.async_helpers import maybe_awaitable
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
-from ._base import BaseHandler
-
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@@ -260,10 +263,6 @@ class AuthHandler(BaseHandler):
# authenticating for an operation to occur on their account.
self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
- # The following template is shown after a successful user interactive
- # authentication session. It tells the user they can close the window.
- self._sso_auth_success_template = hs.config.sso_auth_success_template
-
# The following template is shown during the SSO authentication process if
# the account is deactivated.
self._sso_account_deactivated_template = (
@@ -284,7 +283,6 @@ class AuthHandler(BaseHandler):
requester: Requester,
request: SynapseRequest,
request_body: Dict[str, Any],
- clientip: str,
description: str,
) -> Tuple[dict, Optional[str]]:
"""
@@ -301,8 +299,6 @@ class AuthHandler(BaseHandler):
request_body: The body of the request sent by the client
- clientip: The IP address of the client.
-
description: A human readable string to be displayed to the user that
describes the operation happening on their account.
@@ -338,10 +334,10 @@ class AuthHandler(BaseHandler):
request_body.pop("auth", None)
return request_body, None
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts
- self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
+ self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
# build a list of supported flows
supported_ui_auth_types = await self._get_available_ui_auth_types(
@@ -349,13 +345,16 @@ class AuthHandler(BaseHandler):
)
flows = [[login_type] for login_type in supported_ui_auth_types]
+ def get_new_session_data() -> JsonDict:
+ return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id}
+
try:
result, params, session_id = await self.check_ui_auth(
- flows, request, request_body, clientip, description
+ flows, request, request_body, description, get_new_session_data,
)
except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
- self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
+ self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
raise
# find the completed login type
@@ -363,14 +362,14 @@ class AuthHandler(BaseHandler):
if login_type not in result:
continue
- user_id = result[login_type]
+ validated_user_id = result[login_type]
break
else:
# this can't happen
raise Exception("check_auth returned True but no successful login type")
# check that the UI auth matched the access token
- if user_id != requester.user.to_string():
+ if validated_user_id != requester_user_id:
raise AuthError(403, "Invalid auth")
# Note that the access token has been validated.
@@ -402,13 +401,9 @@ class AuthHandler(BaseHandler):
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
# from sso to mxid.
- if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
- if await self.store.get_external_ids_by_user(user.to_string()):
- ui_auth_types.add(LoginType.SSO)
-
- # Our CAS impl does not (yet) correctly register users in user_external_ids,
- # so always offer that if it's available.
- if self.hs.config.cas.cas_enabled:
+ if await self.hs.get_sso_handler().get_identity_providers_for_user(
+ user.to_string()
+ ):
ui_auth_types.add(LoginType.SSO)
return ui_auth_types
@@ -426,8 +421,8 @@ class AuthHandler(BaseHandler):
flows: List[List[str]],
request: SynapseRequest,
clientdict: Dict[str, Any],
- clientip: str,
description: str,
+ get_new_session_data: Optional[Callable[[], JsonDict]] = None,
) -> Tuple[dict, dict, str]:
"""
Takes a dictionary sent by the client in the login / registration
@@ -448,11 +443,16 @@ class AuthHandler(BaseHandler):
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
- clientip: The IP address of the client.
-
description: A human readable string to be displayed to the user that
describes the operation happening on their account.
+ get_new_session_data:
+ an optional callback which will be called when starting a new session.
+ it should return data to be stored as part of the session.
+
+ The keys of the returned data should be entries in
+ UIAuthSessionDataConstants.
+
Returns:
A tuple of (creds, params, session_id).
@@ -480,10 +480,15 @@ class AuthHandler(BaseHandler):
# If there's no session ID, create a new session.
if not sid:
+ new_session_data = get_new_session_data() if get_new_session_data else {}
+
session = await self.store.create_ui_auth_session(
clientdict, uri, method, description
)
+ for k, v in new_session_data.items():
+ await self.set_session_data(session.session_id, k, v)
+
else:
try:
session = await self.store.get_ui_auth_session(sid)
@@ -539,7 +544,8 @@ class AuthHandler(BaseHandler):
# authentication flow.
await self.store.set_ui_auth_clientdict(sid, clientdict)
- user_agent = request.get_user_agent("")
+ user_agent = get_request_user_agent(request)
+ clientip = request.getClientIP()
await self.store.add_user_agent_ip_to_ui_auth_session(
session.session_id, user_agent, clientip
@@ -644,7 +650,8 @@ class AuthHandler(BaseHandler):
Args:
session_id: The ID of this session as returned from check_auth
- key: The key to store the data under
+ key: The key to store the data under. An entry from
+ UIAuthSessionDataConstants.
value: The data to store
"""
try:
@@ -660,7 +667,8 @@ class AuthHandler(BaseHandler):
Args:
session_id: The ID of this session as returned from check_auth
- key: The key to store the data under
+ key: The key the data was stored under. An entry from
+ UIAuthSessionDataConstants.
default: Value to return if the key has not been set
"""
try:
@@ -1334,12 +1342,12 @@ class AuthHandler(BaseHandler):
else:
return False
- async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
+ async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str:
"""
Get the HTML for the SSO redirect confirmation page.
Args:
- redirect_url: The URL to redirect to the SSO provider.
+ request: The incoming HTTP request
session_id: The user interactive authentication session ID.
Returns:
@@ -1349,30 +1357,38 @@ class AuthHandler(BaseHandler):
session = await self.store.get_ui_auth_session(session_id)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
- return self._sso_auth_confirm_template.render(
- description=session.description, redirect_url=redirect_url,
+
+ user_id_to_verify = await self.get_session_data(
+ session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+ ) # type: str
+
+ idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
+ user_id_to_verify
)
- async def complete_sso_ui_auth(
- self, registered_user_id: str, session_id: str, request: Request,
- ):
- """Having figured out a mxid for this user, complete the HTTP request
+ if not idps:
+ # we checked that the user had some remote identities before offering an SSO
+ # flow, so either it's been deleted or the client has requested SSO despite
+ # it not being offered.
+ raise SynapseError(400, "User has no SSO identities")
- Args:
- registered_user_id: The registered user ID to complete SSO login for.
- session_id: The ID of the user-interactive auth session.
- request: The request to complete.
- """
- # Mark the stage of the authentication as successful.
- # Save the user who authenticated with SSO, this will be used to ensure
- # that the account be modified is also the person who logged in.
- await self.store.mark_ui_auth_stage_complete(
- session_id, LoginType.SSO, registered_user_id
+ # for now, just pick one
+ idp_id, sso_auth_provider = next(iter(idps.items()))
+ if len(idps) > 0:
+ logger.warning(
+ "User %r has previously logged in with multiple SSO IdPs; arbitrarily "
+ "picking %r",
+ user_id_to_verify,
+ idp_id,
+ )
+
+ redirect_url = await sso_auth_provider.handle_redirect_request(
+ request, None, session_id
)
- # Render the HTML and return.
- html = self._sso_auth_success_template
- respond_with_html(request, 200, html)
+ return self._sso_auth_confirm_template.render(
+ description=session.description, redirect_url=redirect_url,
+ )
async def complete_sso_login(
self,
@@ -1488,8 +1504,8 @@ class AuthHandler(BaseHandler):
@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
url_parts = list(urllib.parse.urlparse(url))
- query = dict(urllib.parse.parse_qsl(url_parts[4]))
- query.update({param_name: param})
+ query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
+ query.append((param_name, param))
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index f3430c6713..0f342c607b 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -80,6 +80,10 @@ class CasHandler:
# user-facing name of this auth provider
self.idp_name = "CAS"
+ # we do not currently support icons for CAS auth, but this is required by
+ # the SsoIdentityProvider protocol type.
+ self.idp_icon = None
+
self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index e808142365..c4a3b26a84 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID, create_requester
+from synapse.types import Requester, UserID, create_requester
from ._base import BaseHandler
@@ -38,6 +38,7 @@ class DeactivateAccountHandler(BaseHandler):
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_identity_handler()
+ self._profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
self._server_name = hs.hostname
@@ -52,16 +53,23 @@ class DeactivateAccountHandler(BaseHandler):
self._account_validity_enabled = hs.config.account_validity.enabled
async def deactivate_account(
- self, user_id: str, erase_data: bool, id_server: Optional[str] = None
+ self,
+ user_id: str,
+ erase_data: bool,
+ requester: Requester,
+ id_server: Optional[str] = None,
+ by_admin: bool = False,
) -> bool:
"""Deactivate a user's account
Args:
user_id: ID of user to be deactivated
erase_data: whether to GDPR-erase the user's data
+ requester: The user attempting to make this change.
id_server: Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known).
+ by_admin: Whether this change was made by an administrator.
Returns:
True if identity server supports removing threepids, otherwise False.
@@ -121,6 +129,12 @@ class DeactivateAccountHandler(BaseHandler):
# Mark the user as erased, if they asked for that
if erase_data:
+ user = UserID.from_string(user_id)
+ # Remove avatar URL from this user
+ await self._profile_handler.set_avatar_url(user, requester, "", by_admin)
+ # Remove displayname from this user
+ await self._profile_handler.set_displayname(user, requester, "", by_admin)
+
logger.info("Marking %s as erased", user_id)
await self.store.mark_user_erased(user_id)
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index fc974a82e8..0c7737e09d 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -163,7 +163,7 @@ class DeviceMessageHandler:
await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
# Immediately attempt a resync in the background
- run_in_background(self._user_device_resync, sender_user_id)
+ run_in_background(self._user_device_resync, user_id=sender_user_id)
async def send_device_message(
self,
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index c05036ad1f..f61844d688 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -476,8 +476,6 @@ class IdentityHandler(BaseHandler):
except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
- assert self.hs.config.public_baseurl
-
# we need to tell the client to send the token back to us, since it doesn't
# otherwise know where to send it, so add submit_url response parameter
# (see also MSC2078)
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 6835c6c462..1607e12935 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -14,7 +14,7 @@
# limitations under the License.
import inspect
import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
from urllib.parse import urlencode
import attr
@@ -35,7 +35,7 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
-from synapse.handlers._base import BaseHandler
+from synapse.config.oidc_config import OidcProviderConfig
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
@@ -71,6 +71,144 @@ JWK = Dict[str, str]
JWKS = TypedDict("JWKS", {"keys": List[JWK]})
+class OidcHandler:
+ """Handles requests related to the OpenID Connect login flow.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ self._sso_handler = hs.get_sso_handler()
+
+ provider_confs = hs.config.oidc.oidc_providers
+ # we should not have been instantiated if there is no configured provider.
+ assert provider_confs
+
+ self._token_generator = OidcSessionTokenGenerator(hs)
+ self._providers = {
+ p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
+ } # type: Dict[str, OidcProvider]
+
+ async def load_metadata(self) -> None:
+ """Validate the config and load the metadata from the remote endpoint.
+
+ Called at startup to ensure we have everything we need.
+ """
+ for idp_id, p in self._providers.items():
+ try:
+ await p.load_metadata()
+ await p.load_jwks()
+ except Exception as e:
+ raise Exception(
+ "Error while initialising OIDC provider %r" % (idp_id,)
+ ) from e
+
+ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+ """Handle an incoming request to /_synapse/oidc/callback
+
+ Since we might want to display OIDC-related errors in a user-friendly
+ way, we don't raise SynapseError from here. Instead, we call
+ ``self._sso_handler.render_error`` which displays an HTML page for the error.
+
+ Most of the OpenID Connect logic happens here:
+
+ - first, we check if there was any error returned by the provider and
+ display it
+ - then we fetch the session cookie, decode and verify it
+ - the ``state`` query parameter should match with the one stored in the
+ session cookie
+
+ Once we know the session is legit, we then delegate to the OIDC Provider
+ implementation, which will exchange the code with the provider and complete the
+ login/authentication.
+
+ Args:
+ request: the incoming request from the browser.
+ """
+
+ # The provider might redirect with an error.
+ # In that case, just display it as-is.
+ if b"error" in request.args:
+ # error response from the auth server. see:
+ # https://tools.ietf.org/html/rfc6749#section-4.1.2.1
+ # https://openid.net/specs/openid-connect-core-1_0.html#AuthError
+ error = request.args[b"error"][0].decode()
+ description = request.args.get(b"error_description", [b""])[0].decode()
+
+ # Most of the errors returned by the provider could be due by
+ # either the provider misbehaving or Synapse being misconfigured.
+ # The only exception of that is "access_denied", where the user
+ # probably cancelled the login flow. In other cases, log those errors.
+ if error != "access_denied":
+ logger.error("Error from the OIDC provider: %s %s", error, description)
+
+ self._sso_handler.render_error(request, error, description)
+ return
+
+ # otherwise, it is presumably a successful response. see:
+ # https://tools.ietf.org/html/rfc6749#section-4.1.2
+
+ # Fetch the session cookie
+ session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
+ if session is None:
+ logger.info("No session cookie found")
+ self._sso_handler.render_error(
+ request, "missing_session", "No session cookie found"
+ )
+ return
+
+ # Remove the cookie. There is a good chance that if the callback failed
+ # once, it will fail next time and the code will already be exchanged.
+ # Removing it early avoids spamming the provider with token requests.
+ request.addCookie(
+ SESSION_COOKIE_NAME,
+ b"",
+ path="/_synapse/oidc",
+ expires="Thu, Jan 01 1970 00:00:00 UTC",
+ httpOnly=True,
+ sameSite="lax",
+ )
+
+ # Check for the state query parameter
+ if b"state" not in request.args:
+ logger.info("State parameter is missing")
+ self._sso_handler.render_error(
+ request, "invalid_request", "State parameter is missing"
+ )
+ return
+
+ state = request.args[b"state"][0].decode()
+
+ # Deserialize the session token and verify it.
+ try:
+ session_data = self._token_generator.verify_oidc_session_token(
+ session, state
+ )
+ except (MacaroonDeserializationException, ValueError) as e:
+ logger.exception("Invalid session")
+ self._sso_handler.render_error(request, "invalid_session", str(e))
+ return
+ except MacaroonInvalidSignatureException as e:
+ logger.exception("Could not verify session")
+ self._sso_handler.render_error(request, "mismatching_session", str(e))
+ return
+
+ oidc_provider = self._providers.get(session_data.idp_id)
+ if not oidc_provider:
+ logger.error("OIDC session uses unknown IdP %r", oidc_provider)
+ self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
+ return
+
+ if b"code" not in request.args:
+ logger.info("Code parameter is missing")
+ self._sso_handler.render_error(
+ request, "invalid_request", "Code parameter is missing"
+ )
+ return
+
+ code = request.args[b"code"][0].decode()
+
+ await oidc_provider.handle_oidc_callback(request, session_data, code)
+
+
class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint
"""
@@ -85,44 +223,56 @@ class OidcError(Exception):
return self.error
-class OidcHandler(BaseHandler):
- """Handles requests related to the OpenID Connect login flow.
+class OidcProvider:
+ """Wraps the config for a single OIDC IdentityProvider
+
+ Provides methods for handling redirect requests and callbacks via that particular
+ IdP.
"""
- def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ def __init__(
+ self,
+ hs: "HomeServer",
+ token_generator: "OidcSessionTokenGenerator",
+ provider: OidcProviderConfig,
+ ):
+ self._store = hs.get_datastore()
+
+ self._token_generator = token_generator
+
self._callback_url = hs.config.oidc_callback_url # type: str
- self._scopes = hs.config.oidc_scopes # type: List[str]
- self._user_profile_method = hs.config.oidc_user_profile_method # type: str
+
+ self._scopes = provider.scopes
+ self._user_profile_method = provider.user_profile_method
self._client_auth = ClientAuth(
- hs.config.oidc_client_id,
- hs.config.oidc_client_secret,
- hs.config.oidc_client_auth_method,
+ provider.client_id, provider.client_secret, provider.client_auth_method,
) # type: ClientAuth
- self._client_auth_method = hs.config.oidc_client_auth_method # type: str
+ self._client_auth_method = provider.client_auth_method
self._provider_metadata = OpenIDProviderMetadata(
- issuer=hs.config.oidc_issuer,
- authorization_endpoint=hs.config.oidc_authorization_endpoint,
- token_endpoint=hs.config.oidc_token_endpoint,
- userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
- jwks_uri=hs.config.oidc_jwks_uri,
+ issuer=provider.issuer,
+ authorization_endpoint=provider.authorization_endpoint,
+ token_endpoint=provider.token_endpoint,
+ userinfo_endpoint=provider.userinfo_endpoint,
+ jwks_uri=provider.jwks_uri,
) # type: OpenIDProviderMetadata
- self._provider_needs_discovery = hs.config.oidc_discover # type: bool
- self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
- hs.config.oidc_user_mapping_provider_config
- ) # type: OidcMappingProvider
- self._skip_verification = hs.config.oidc_skip_verification # type: bool
- self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
+ self._provider_needs_discovery = provider.discover
+ self._user_mapping_provider = provider.user_mapping_provider_class(
+ provider.user_mapping_provider_config
+ )
+ self._skip_verification = provider.skip_verification
+ self._allow_existing_users = provider.allow_existing_users
self._http_client = hs.get_proxied_http_client()
self._server_name = hs.config.server_name # type: str
- self._macaroon_secret_key = hs.config.macaroon_secret_key
# identifier for the external_ids table
- self.idp_id = "oidc"
+ self.idp_id = provider.idp_id
# user-facing name of this auth provider
- self.idp_name = "OIDC"
+ self.idp_name = provider.idp_name
+
+ # MXC URI for icon for this auth provider
+ self.idp_icon = provider.idp_icon
self._sso_handler = hs.get_sso_handler()
@@ -519,11 +669,14 @@ class OidcHandler(BaseHandler):
if not client_redirect_url:
client_redirect_url = b""
- cookie = self._generate_oidc_session_token(
+ cookie = self._token_generator.generate_oidc_session_token(
state=state,
- nonce=nonce,
- client_redirect_url=client_redirect_url.decode(),
- ui_auth_session_id=ui_auth_session_id,
+ session_data=OidcSessionData(
+ idp_id=self.idp_id,
+ nonce=nonce,
+ client_redirect_url=client_redirect_url.decode(),
+ ui_auth_session_id=ui_auth_session_id,
+ ),
)
request.addCookie(
SESSION_COOKIE_NAME,
@@ -546,22 +699,16 @@ class OidcHandler(BaseHandler):
nonce=nonce,
)
- async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+ async def handle_oidc_callback(
+ self, request: SynapseRequest, session_data: "OidcSessionData", code: str
+ ) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
- Since we might want to display OIDC-related errors in a user-friendly
- way, we don't raise SynapseError from here. Instead, we call
- ``self._sso_handler.render_error`` which displays an HTML page for the error.
+ By this time we have already validated the session on the synapse side, and
+ now need to do the provider-specific operations. This includes:
- Most of the OpenID Connect logic happens here:
-
- - first, we check if there was any error returned by the provider and
- display it
- - then we fetch the session cookie, decode and verify it
- - the ``state`` query parameter should match with the one stored in the
- session cookie
- - once we known this session is legit, exchange the code with the
- provider using the ``token_endpoint`` (see ``_exchange_code``)
+ - exchange the code with the provider using the ``token_endpoint`` (see
+ ``_exchange_code``)
- once we have the token, use it to either extract the UserInfo from
the ``id_token`` (``_parse_id_token``), or use the ``access_token``
to fetch UserInfo from the ``userinfo_endpoint``
@@ -571,88 +718,12 @@ class OidcHandler(BaseHandler):
Args:
request: the incoming request from the browser.
+ session_data: the session data, extracted from our cookie
+ code: The authorization code we got from the callback.
"""
-
- # The provider might redirect with an error.
- # In that case, just display it as-is.
- if b"error" in request.args:
- # error response from the auth server. see:
- # https://tools.ietf.org/html/rfc6749#section-4.1.2.1
- # https://openid.net/specs/openid-connect-core-1_0.html#AuthError
- error = request.args[b"error"][0].decode()
- description = request.args.get(b"error_description", [b""])[0].decode()
-
- # Most of the errors returned by the provider could be due by
- # either the provider misbehaving or Synapse being misconfigured.
- # The only exception of that is "access_denied", where the user
- # probably cancelled the login flow. In other cases, log those errors.
- if error != "access_denied":
- logger.error("Error from the OIDC provider: %s %s", error, description)
-
- self._sso_handler.render_error(request, error, description)
- return
-
- # otherwise, it is presumably a successful response. see:
- # https://tools.ietf.org/html/rfc6749#section-4.1.2
-
- # Fetch the session cookie
- session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
- if session is None:
- logger.info("No session cookie found")
- self._sso_handler.render_error(
- request, "missing_session", "No session cookie found"
- )
- return
-
- # Remove the cookie. There is a good chance that if the callback failed
- # once, it will fail next time and the code will already be exchanged.
- # Removing it early avoids spamming the provider with token requests.
- request.addCookie(
- SESSION_COOKIE_NAME,
- b"",
- path="/_synapse/oidc",
- expires="Thu, Jan 01 1970 00:00:00 UTC",
- httpOnly=True,
- sameSite="lax",
- )
-
- # Check for the state query parameter
- if b"state" not in request.args:
- logger.info("State parameter is missing")
- self._sso_handler.render_error(
- request, "invalid_request", "State parameter is missing"
- )
- return
-
- state = request.args[b"state"][0].decode()
-
- # Deserialize the session token and verify it.
- try:
- (
- nonce,
- client_redirect_url,
- ui_auth_session_id,
- ) = self._verify_oidc_session_token(session, state)
- except MacaroonDeserializationException as e:
- logger.exception("Invalid session")
- self._sso_handler.render_error(request, "invalid_session", str(e))
- return
- except MacaroonInvalidSignatureException as e:
- logger.exception("Could not verify session")
- self._sso_handler.render_error(request, "mismatching_session", str(e))
- return
-
# Exchange the code with the provider
- if b"code" not in request.args:
- logger.info("Code parameter is missing")
- self._sso_handler.render_error(
- request, "invalid_request", "Code parameter is missing"
- )
- return
-
- logger.debug("Exchanging code")
- code = request.args[b"code"][0].decode()
try:
+ logger.debug("Exchanging code")
token = await self._exchange_code(code)
except OidcError as e:
logger.exception("Could not exchange code")
@@ -674,14 +745,14 @@ class OidcHandler(BaseHandler):
else:
logger.debug("Extracting userinfo from id_token")
try:
- userinfo = await self._parse_id_token(token, nonce=nonce)
+ userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
except Exception as e:
logger.exception("Invalid id_token")
self._sso_handler.render_error(request, "invalid_token", str(e))
return
# first check if we're doing a UIA
- if ui_auth_session_id:
+ if session_data.ui_auth_session_id:
try:
remote_user_id = self._remote_id_from_userinfo(userinfo)
except Exception as e:
@@ -690,7 +761,7 @@ class OidcHandler(BaseHandler):
return
return await self._sso_handler.complete_sso_ui_auth_request(
- self.idp_id, remote_user_id, ui_auth_session_id, request
+ self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
)
# otherwise, it's a login
@@ -698,133 +769,12 @@ class OidcHandler(BaseHandler):
# Call the mapper to register/login the user
try:
await self._complete_oidc_login(
- userinfo, token, request, client_redirect_url
+ userinfo, token, request, session_data.client_redirect_url
)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))
- def _generate_oidc_session_token(
- self,
- state: str,
- nonce: str,
- client_redirect_url: str,
- ui_auth_session_id: Optional[str],
- duration_in_ms: int = (60 * 60 * 1000),
- ) -> str:
- """Generates a signed token storing data about an OIDC session.
-
- When Synapse initiates an authorization flow, it creates a random state
- and a random nonce. Those parameters are given to the provider and
- should be verified when the client comes back from the provider.
- It is also used to store the client_redirect_url, which is used to
- complete the SSO login flow.
-
- Args:
- state: The ``state`` parameter passed to the OIDC provider.
- nonce: The ``nonce`` parameter passed to the OIDC provider.
- client_redirect_url: The URL the client gave when it initiated the
- flow.
- ui_auth_session_id: The session ID of the ongoing UI Auth (or
- None if this is a login).
- duration_in_ms: An optional duration for the token in milliseconds.
- Defaults to an hour.
-
- Returns:
- A signed macaroon token with the session information.
- """
- macaroon = pymacaroons.Macaroon(
- location=self._server_name, identifier="key", key=self._macaroon_secret_key,
- )
- macaroon.add_first_party_caveat("gen = 1")
- macaroon.add_first_party_caveat("type = session")
- macaroon.add_first_party_caveat("state = %s" % (state,))
- macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
- macaroon.add_first_party_caveat(
- "client_redirect_url = %s" % (client_redirect_url,)
- )
- if ui_auth_session_id:
- macaroon.add_first_party_caveat(
- "ui_auth_session_id = %s" % (ui_auth_session_id,)
- )
- now = self.clock.time_msec()
- expiry = now + duration_in_ms
- macaroon.add_first_party_caveat("time < %d" % (expiry,))
-
- return macaroon.serialize()
-
- def _verify_oidc_session_token(
- self, session: bytes, state: str
- ) -> Tuple[str, str, Optional[str]]:
- """Verifies and extract an OIDC session token.
-
- This verifies that a given session token was issued by this homeserver
- and extract the nonce and client_redirect_url caveats.
-
- Args:
- session: The session token to verify
- state: The state the OIDC provider gave back
-
- Returns:
- The nonce, client_redirect_url, and ui_auth_session_id for this session
- """
- macaroon = pymacaroons.Macaroon.deserialize(session)
-
- v = pymacaroons.Verifier()
- v.satisfy_exact("gen = 1")
- v.satisfy_exact("type = session")
- v.satisfy_exact("state = %s" % (state,))
- v.satisfy_general(lambda c: c.startswith("nonce = "))
- v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
- # Sometimes there's a UI auth session ID, it seems to be OK to attempt
- # to always satisfy this.
- v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
- v.satisfy_general(self._verify_expiry)
-
- v.verify(macaroon, self._macaroon_secret_key)
-
- # Extract the `nonce`, `client_redirect_url`, and maybe the
- # `ui_auth_session_id` from the token.
- nonce = self._get_value_from_macaroon(macaroon, "nonce")
- client_redirect_url = self._get_value_from_macaroon(
- macaroon, "client_redirect_url"
- )
- try:
- ui_auth_session_id = self._get_value_from_macaroon(
- macaroon, "ui_auth_session_id"
- ) # type: Optional[str]
- except ValueError:
- ui_auth_session_id = None
-
- return nonce, client_redirect_url, ui_auth_session_id
-
- def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
- """Extracts a caveat value from a macaroon token.
-
- Args:
- macaroon: the token
- key: the key of the caveat to extract
-
- Returns:
- The extracted value
-
- Raises:
- Exception: if the caveat was not in the macaroon
- """
- prefix = key + " = "
- for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith(prefix):
- return caveat.caveat_id[len(prefix) :]
- raise ValueError("No %s caveat in macaroon" % (key,))
-
- def _verify_expiry(self, caveat: str) -> bool:
- prefix = "time < "
- if not caveat.startswith(prefix):
- return False
- expiry = int(caveat[len(prefix) :])
- now = self.clock.time_msec()
- return now < expiry
-
async def _complete_oidc_login(
self,
userinfo: UserInfo,
@@ -901,8 +851,8 @@ class OidcHandler(BaseHandler):
# and attempt to match it.
attributes = await oidc_response_to_user_attributes(failures=0)
- user_id = UserID(attributes.localpart, self.server_name).to_string()
- users = await self.store.get_users_by_id_case_insensitive(user_id)
+ user_id = UserID(attributes.localpart, self._server_name).to_string()
+ users = await self._store.get_users_by_id_case_insensitive(user_id)
if users:
# If an existing matrix ID is returned, then use it.
if len(users) == 1:
@@ -954,6 +904,157 @@ class OidcHandler(BaseHandler):
return str(remote_user_id)
+class OidcSessionTokenGenerator:
+ """Methods for generating and checking OIDC Session cookies."""
+
+ def __init__(self, hs: "HomeServer"):
+ self._clock = hs.get_clock()
+ self._server_name = hs.hostname
+ self._macaroon_secret_key = hs.config.key.macaroon_secret_key
+
+ def generate_oidc_session_token(
+ self,
+ state: str,
+ session_data: "OidcSessionData",
+ duration_in_ms: int = (60 * 60 * 1000),
+ ) -> str:
+ """Generates a signed token storing data about an OIDC session.
+
+ When Synapse initiates an authorization flow, it creates a random state
+ and a random nonce. Those parameters are given to the provider and
+ should be verified when the client comes back from the provider.
+ It is also used to store the client_redirect_url, which is used to
+ complete the SSO login flow.
+
+ Args:
+ state: The ``state`` parameter passed to the OIDC provider.
+ session_data: data to include in the session token.
+ duration_in_ms: An optional duration for the token in milliseconds.
+ Defaults to an hour.
+
+ Returns:
+ A signed macaroon token with the session information.
+ """
+ macaroon = pymacaroons.Macaroon(
+ location=self._server_name, identifier="key", key=self._macaroon_secret_key,
+ )
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = session")
+ macaroon.add_first_party_caveat("state = %s" % (state,))
+ macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,))
+ macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
+ macaroon.add_first_party_caveat(
+ "client_redirect_url = %s" % (session_data.client_redirect_url,)
+ )
+ if session_data.ui_auth_session_id:
+ macaroon.add_first_party_caveat(
+ "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
+ )
+ now = self._clock.time_msec()
+ expiry = now + duration_in_ms
+ macaroon.add_first_party_caveat("time < %d" % (expiry,))
+
+ return macaroon.serialize()
+
+ def verify_oidc_session_token(
+ self, session: bytes, state: str
+ ) -> "OidcSessionData":
+ """Verifies and extract an OIDC session token.
+
+ This verifies that a given session token was issued by this homeserver
+ and extract the nonce and client_redirect_url caveats.
+
+ Args:
+ session: The session token to verify
+ state: The state the OIDC provider gave back
+
+ Returns:
+ The data extracted from the session cookie
+
+ Raises:
+ ValueError if an expected caveat is missing from the macaroon.
+ """
+ macaroon = pymacaroons.Macaroon.deserialize(session)
+
+ v = pymacaroons.Verifier()
+ v.satisfy_exact("gen = 1")
+ v.satisfy_exact("type = session")
+ v.satisfy_exact("state = %s" % (state,))
+ v.satisfy_general(lambda c: c.startswith("nonce = "))
+ v.satisfy_general(lambda c: c.startswith("idp_id = "))
+ v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+ # Sometimes there's a UI auth session ID, it seems to be OK to attempt
+ # to always satisfy this.
+ v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
+ v.satisfy_general(self._verify_expiry)
+
+ v.verify(macaroon, self._macaroon_secret_key)
+
+ # Extract the session data from the token.
+ nonce = self._get_value_from_macaroon(macaroon, "nonce")
+ idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
+ client_redirect_url = self._get_value_from_macaroon(
+ macaroon, "client_redirect_url"
+ )
+ try:
+ ui_auth_session_id = self._get_value_from_macaroon(
+ macaroon, "ui_auth_session_id"
+ ) # type: Optional[str]
+ except ValueError:
+ ui_auth_session_id = None
+
+ return OidcSessionData(
+ nonce=nonce,
+ idp_id=idp_id,
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=ui_auth_session_id,
+ )
+
+ def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
+ """Extracts a caveat value from a macaroon token.
+
+ Args:
+ macaroon: the token
+ key: the key of the caveat to extract
+
+ Returns:
+ The extracted value
+
+ Raises:
+ ValueError: if the caveat was not in the macaroon
+ """
+ prefix = key + " = "
+ for caveat in macaroon.caveats:
+ if caveat.caveat_id.startswith(prefix):
+ return caveat.caveat_id[len(prefix) :]
+ raise ValueError("No %s caveat in macaroon" % (key,))
+
+ def _verify_expiry(self, caveat: str) -> bool:
+ prefix = "time < "
+ if not caveat.startswith(prefix):
+ return False
+ expiry = int(caveat[len(prefix) :])
+ now = self._clock.time_msec()
+ return now < expiry
+
+
+@attr.s(frozen=True, slots=True)
+class OidcSessionData:
+ """The attributes which are stored in a OIDC session cookie"""
+
+ # the Identity Provider being used
+ idp_id = attr.ib(type=str)
+
+ # The `nonce` parameter passed to the OIDC provider.
+ nonce = attr.ib(type=str)
+
+ # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
+ client_redirect_url = attr.ib(type=str)
+
+ # The session ID of the ongoing UI Auth (None if this is a login)
+ ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+
+
UserAttributeDict = TypedDict(
"UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 36f9ee4b71..c02b951031 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -286,13 +286,19 @@ class ProfileHandler(BaseHandler):
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
+ avatar_url_to_set = new_avatar_url # type: Optional[str]
+ if new_avatar_url == "":
+ avatar_url_to_set = None
+
# Same like set_displayname
if by_admin:
requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity
)
- await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+ await self.store.set_profile_avatar_url(
+ target_user.localpart, avatar_url_to_set
+ )
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index a7550806e6..6bb2fd936b 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler):
super().__init__(hs)
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
+ self.account_data_handler = hs.get_account_data_handler()
self.read_marker_linearizer = Linearizer(name="read_marker")
- self.notifier = hs.get_notifier()
async def received_client_read_marker(
self, room_id: str, user_id: str, event_id: str
@@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler):
if should_update:
content = {"event_id": event_id}
- max_id = await self.store.add_account_data_to_room(
+ await self.account_data_handler.add_account_data_to_room(
user_id, room_id, "m.fully_read", content
)
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index a9abdf42e0..cc21fc2284 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -32,10 +32,26 @@ class ReceiptsHandler(BaseHandler):
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
self.hs = hs
- self.federation = hs.get_federation_sender()
- hs.get_federation_registry().register_edu_handler(
- "m.receipt", self._received_remote_receipt
- )
+
+ # We only need to poke the federation sender explicitly if its on the
+ # same instance. Other federation sender instances will get notified by
+ # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+ # in the receipts stream.
+ self.federation_sender = None
+ if hs.should_send_federation():
+ self.federation_sender = hs.get_federation_sender()
+
+ # If we can handle the receipt EDUs we do so, otherwise we route them
+ # to the appropriate worker.
+ if hs.get_instance_name() in hs.config.worker.writers.receipts:
+ hs.get_federation_registry().register_edu_handler(
+ "m.receipt", self._received_remote_receipt
+ )
+ else:
+ hs.get_federation_registry().register_instances_for_edu(
+ "m.receipt", hs.config.worker.writers.receipts,
+ )
+
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
@@ -125,7 +141,8 @@ class ReceiptsHandler(BaseHandler):
if not is_new:
return
- await self.federation.send_read_receipt(receipt)
+ if self.federation_sender:
+ await self.federation_sender.send_read_receipt(receipt)
class ReceiptEventSource:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 3bece6d668..ee27d99135 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -38,7 +38,6 @@ from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
-from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
from synapse.types import (
JsonDict,
@@ -55,6 +54,7 @@ from synapse.types import (
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.stringutils import parse_and_validate_server_name
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index cb5a29bc7e..e001e418f9 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -63,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.account_data_handler = hs.get_account_data_handler()
self.member_linearizer = Linearizer(name="member")
@@ -253,7 +254,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
direct_rooms[key].append(new_room_id)
# Save back to user's m.direct account data
- await self.store.add_account_data_for_user(
+ await self.account_data_handler.add_account_data_for_user(
user_id, AccountDataTypes.DIRECT, direct_rooms
)
break
@@ -263,7 +264,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Copy each room tag to the new room
for tag, tag_content in room_tags.items():
- await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
+ await self.account_data_handler.add_tag_to_room(
+ user_id, new_room_id, tag, tag_content
+ )
async def update_membership(
self,
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index a8376543c9..38461cf79d 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -78,6 +78,10 @@ class SamlHandler(BaseHandler):
# user-facing name of this auth provider
self.idp_name = "SAML"
+ # we do not currently support icons for SAML auth, but this is required by
+ # the SsoIdentityProvider protocol type.
+ self.idp_icon = None
+
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 2da1ea2223..d493327a10 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -22,7 +22,10 @@ from typing_extensions import NoReturn, Protocol
from twisted.web.http import Request
+from synapse.api.constants import LoginType
from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
+from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
@@ -72,6 +75,11 @@ class SsoIdentityProvider(Protocol):
def idp_name(self) -> str:
"""User-facing name for this provider"""
+ @property
+ def idp_icon(self) -> Optional[str]:
+ """Optional MXC URI for user-facing icon"""
+ return None
+
@abc.abstractmethod
async def handle_redirect_request(
self,
@@ -145,8 +153,13 @@ class SsoHandler:
self._store = hs.get_datastore()
self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
- self._error_template = hs.config.sso_error_template
self._auth_handler = hs.get_auth_handler()
+ self._error_template = hs.config.sso_error_template
+ self._bad_user_template = hs.config.sso_auth_bad_user_template
+
+ # The following template is shown after a successful user interactive
+ # authentication session. It tells the user they can close the window.
+ self._sso_auth_success_template = hs.config.sso_auth_success_template
# a lock on the mappings
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
@@ -166,6 +179,37 @@ class SsoHandler:
"""Get the configured identity providers"""
return self._identity_providers
+ async def get_identity_providers_for_user(
+ self, user_id: str
+ ) -> Mapping[str, SsoIdentityProvider]:
+ """Get the SsoIdentityProviders which a user has used
+
+ Given a user id, get the identity providers that that user has used to log in
+ with in the past (and thus could use to re-identify themselves for UI Auth).
+
+ Args:
+ user_id: MXID of user to look up
+
+ Raises:
+ a map of idp_id to SsoIdentityProvider
+ """
+ external_ids = await self._store.get_external_ids_by_user(user_id)
+
+ valid_idps = {}
+ for idp_id, _ in external_ids:
+ idp = self._identity_providers.get(idp_id)
+ if not idp:
+ logger.warning(
+ "User %r has an SSO mapping for IdP %r, but this is no longer "
+ "configured.",
+ user_id,
+ idp_id,
+ )
+ else:
+ valid_idps[idp_id] = idp
+
+ return valid_idps
+
def render_error(
self,
request: Request,
@@ -362,7 +406,7 @@ class SsoHandler:
attributes,
auth_provider_id,
remote_user_id,
- request.get_user_agent(""),
+ get_request_user_agent(request),
request.getClientIP(),
)
@@ -545,19 +589,45 @@ class SsoHandler:
auth_provider_id, remote_user_id,
)
+ user_id_to_verify = await self._auth_handler.get_session_data(
+ ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+ ) # type: str
+
if not user_id:
logger.warning(
"Remote user %s/%s has not previously logged in here: UIA will fail",
auth_provider_id,
remote_user_id,
)
- # Let the UIA flow handle this the same as if they presented creds for a
- # different user.
- user_id = ""
+ elif user_id != user_id_to_verify:
+ logger.warning(
+ "Remote user %s/%s mapped onto incorrect user %s: UIA will fail",
+ auth_provider_id,
+ remote_user_id,
+ user_id,
+ )
+ else:
+ # success!
+ # Mark the stage of the authentication as successful.
+ await self._store.mark_ui_auth_stage_complete(
+ ui_auth_session_id, LoginType.SSO, user_id
+ )
+
+ # Render the HTML confirmation page and return.
+ html = self._sso_auth_success_template
+ respond_with_html(request, 200, html)
+ return
+
+ # the user_id didn't match: mark the stage of the authentication as unsuccessful
+ await self._store.mark_ui_auth_stage_complete(
+ ui_auth_session_id, LoginType.SSO, ""
+ )
- await self._auth_handler.complete_sso_ui_auth(
- user_id, ui_auth_session_id, request
+ # render an error page.
+ html = self._bad_user_template.render(
+ server_name=self._server_name, user_id_to_verify=user_id_to_verify,
)
+ respond_with_html(request, 200, html)
async def check_username_availability(
self, localpart: str, session_id: str,
@@ -628,7 +698,7 @@ class SsoHandler:
attributes,
session.auth_provider_id,
session.remote_user_id,
- request.get_user_agent(""),
+ get_request_user_agent(request),
request.getClientIP(),
)
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
index 824f37f8f8..a68d5e790e 100644
--- a/synapse/handlers/ui_auth/__init__.py
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -20,3 +20,18 @@ TODO: move more stuff out of AuthHandler in here.
"""
from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401
+
+
+class UIAuthSessionDataConstants:
+ """Constants for use with AuthHandler.set_session_data"""
+
+ # used during registration and password reset to store a hashed copy of the
+ # password, so that the client does not need to submit it each time.
+ PASSWORD_HASH = "password_hash"
+
+ # used during registration to store the mxid of the registered user
+ REGISTERED_USER_ID = "registered_user_id"
+
+ # used by validate_user_via_ui_auth to store the mxid of the user we are validating
+ # for.
+ REQUEST_USER_ID = "request_user_id"
|