diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index d50d485e06..6d42a1aed8 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -18,6 +18,7 @@
"""Utilities for interacting with Identity Servers"""
import logging
+import urllib
from canonicaljson import json
@@ -31,6 +32,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.http.client import SimpleHttpClient
from synapse.util.stringutils import random_string
from ._base import BaseHandler
@@ -42,7 +44,12 @@ class IdentityHandler(BaseHandler):
def __init__(self, hs):
super(IdentityHandler, self).__init__(hs)
- self.http_client = hs.get_simple_http_client()
+ self.http_client = SimpleHttpClient(hs)
+ # We create a blacklisting instance of SimpleHttpClient for contacting identity
+ # servers specified by clients
+ self.blacklisting_http_client = SimpleHttpClient(
+ hs, ip_blacklist=hs.config.federation_ip_range_blacklist
+ )
self.federation_http_client = hs.get_http_client()
self.hs = hs
@@ -143,7 +150,9 @@ class IdentityHandler(BaseHandler):
bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
try:
- data = yield self.http_client.post_json_get_json(
+ # Use the blacklisting http client as this call is only to identity servers
+ # provided by a client
+ data = yield self.blacklisting_http_client.post_json_get_json(
bind_url, bind_data, headers=headers
)
@@ -246,7 +255,11 @@ class IdentityHandler(BaseHandler):
headers = {b"Authorization": auth_headers}
try:
- yield self.http_client.post_json_get_json(url, content, headers)
+ # Use the blacklisting http client as this call is only to identity servers
+ # provided by a client
+ yield self.blacklisting_http_client.post_json_get_json(
+ url, content, headers
+ )
changed = True
except HttpResponseException as e:
changed = False
@@ -316,6 +329,15 @@ class IdentityHandler(BaseHandler):
# Generate a session id
session_id = random_string(16)
+ if next_link:
+ # Manipulate the next_link to add the sid, because the caller won't get
+ # it until we send a response, by which time we've sent the mail.
+ if "?" in next_link:
+ next_link += "&"
+ else:
+ next_link += "?"
+ next_link += "sid=" + urllib.parse.quote(session_id)
+
# Generate a new validation token
token = random_string(32)
@@ -440,13 +462,23 @@ class IdentityHandler(BaseHandler):
id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
params,
)
- return data
except HttpResponseException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e.to_synapse_error()
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
+ assert self.hs.config.public_baseurl
+
+ # we need to tell the client to send the token back to us, since it doesn't
+ # otherwise know where to send it, so add submit_url response parameter
+ # (see also MSC2078)
+ data["submit_url"] = (
+ self.hs.config.public_baseurl
+ + "_matrix/client/unstable/add_threepid/msisdn/submit_token"
+ )
+ return data
+
@defer.inlineCallbacks
def validate_threepid_session(self, client_secret, sid):
"""Validates a threepid session with only the client secret and session ID
@@ -491,6 +523,40 @@ class IdentityHandler(BaseHandler):
return validation_session
+ @defer.inlineCallbacks
+ def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
+ """Proxy a POST submitToken request to an identity server for verification purposes
+
+ Args:
+ id_server (str): The identity server URL to contact
+
+ client_secret (str): Secret provided by the client
+
+ sid (str): The ID of the session
+
+ token (str): The verification token
+
+ Raises:
+ SynapseError: If we failed to contact the identity server
+
+ Returns:
+ Deferred[dict]: The response dict from the identity server
+ """
+ body = {"client_secret": client_secret, "sid": sid, "token": token}
+
+ try:
+ return (
+ yield self.http_client.post_json_get_json(
+ id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
+ body,
+ )
+ )
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+ except HttpResponseException as e:
+ logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
+ raise SynapseError(400, "Error contacting the identity server")
+
def create_id_access_token_header(id_access_token):
"""Create an Authorization header for passing to SimpleHttpClient as the header value
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 39df0f128d..94cd0cf3ef 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -31,6 +31,7 @@ from synapse import types
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, HttpResponseException, SynapseError
from synapse.handlers.identity import LookupAlgorithm, create_id_access_token_header
+from synapse.http.client import SimpleHttpClient
from synapse.types import RoomID, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@@ -62,7 +63,11 @@ class RoomMemberHandler(object):
self.auth = hs.get_auth()
self.state_handler = hs.get_state_handler()
self.config = hs.config
- self.simple_http_client = hs.get_simple_http_client()
+ # We create a blacklisting instance of SimpleHttpClient for contacting identity
+ # servers specified by clients
+ self.simple_http_client = SimpleHttpClient(
+ hs, ip_blacklist=hs.config.federation_ip_range_blacklist
+ )
self.federation_handler = hs.get_handlers().federation_handler
self.directory_handler = hs.get_handlers().directory_handler
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index a1ce6929cf..cc9e6b9bd0 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -21,6 +21,8 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_string
from synapse.rest.client.v1.login import SSOAuthHandler
+from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@@ -29,12 +31,26 @@ class SamlHandler:
def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._sso_auth_handler = SSOAuthHandler(hs)
+ self._registration_handler = hs.get_registration_handler()
+
+ self._clock = hs.get_clock()
+ self._datastore = hs.get_datastore()
+ self._hostname = hs.hostname
+ self._saml2_session_lifetime = hs.config.saml2_session_lifetime
+ self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
+ self._grandfathered_mxid_source_attribute = (
+ hs.config.saml2_grandfathered_mxid_source_attribute
+ )
+ self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+ # identifier for the external_ids table
+ self._auth_provider_id = "saml"
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {}
- self._clock = hs.get_clock()
- self._saml2_session_lifetime = hs.config.saml2_session_lifetime
+ # a lock on the mappings
+ self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
def handle_redirect_request(self, client_redirect_url):
"""Handle an incoming request to /login/sso/redirect
@@ -60,7 +76,7 @@ class SamlHandler:
# this shouldn't happen!
raise Exception("prepare_for_authenticate didn't return a Location header")
- def handle_saml_response(self, request):
+ async def handle_saml_response(self, request):
"""Handle an incoming request to /_matrix/saml2/authn_response
Args:
@@ -77,6 +93,10 @@ class SamlHandler:
# the dict.
self.expire_sessions()
+ user_id = await self._map_saml_response_to_user(resp_bytes)
+ self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
+
+ async def _map_saml_response_to_user(self, resp_bytes):
try:
saml2_auth = self._saml_client.parse_authn_request_response(
resp_bytes,
@@ -91,18 +111,88 @@ class SamlHandler:
logger.warning("SAML2 response was not signed")
raise SynapseError(400, "SAML2 response was not signed")
- if "uid" not in saml2_auth.ava:
+ logger.info("SAML2 response: %s", saml2_auth.origxml)
+ logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
+
+ try:
+ remote_user_id = saml2_auth.ava["uid"][0]
+ except KeyError:
logger.warning("SAML2 response lacks a 'uid' attestation")
raise SynapseError(400, "uid not in SAML2 response")
+ try:
+ mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
+ except KeyError:
+ logger.warning(
+ "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
+ )
+ raise SynapseError(
+ 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+ )
+
self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
- username = saml2_auth.ava["uid"][0]
displayName = saml2_auth.ava.get("displayName", [None])[0]
- return self._sso_auth_handler.on_successful_auth(
- username, request, relay_state, user_display_name=displayName
- )
+ with (await self._mapping_lock.queue(self._auth_provider_id)):
+ # first of all, check if we already have a mapping for this user
+ logger.info(
+ "Looking for existing mapping for user %s:%s",
+ self._auth_provider_id,
+ remote_user_id,
+ )
+ registered_user_id = await self._datastore.get_user_by_external_id(
+ self._auth_provider_id, remote_user_id
+ )
+ if registered_user_id is not None:
+ logger.info("Found existing mapping %s", registered_user_id)
+ return registered_user_id
+
+ # backwards-compatibility hack: see if there is an existing user with a
+ # suitable mapping from the uid
+ if (
+ self._grandfathered_mxid_source_attribute
+ and self._grandfathered_mxid_source_attribute in saml2_auth.ava
+ ):
+ attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
+ user_id = UserID(
+ map_username_to_mxid_localpart(attrval), self._hostname
+ ).to_string()
+ logger.info(
+ "Looking for existing account based on mapped %s %s",
+ self._grandfathered_mxid_source_attribute,
+ user_id,
+ )
+
+ users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+ if users:
+ registered_user_id = list(users.keys())[0]
+ logger.info("Grandfathering mapping to %s", registered_user_id)
+ await self._datastore.record_user_external_id(
+ self._auth_provider_id, remote_user_id, registered_user_id
+ )
+ return registered_user_id
+
+ # figure out a new mxid for this user
+ base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+ suffix = 0
+ while True:
+ localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+ if not await self._datastore.get_users_by_id_case_insensitive(
+ UserID(localpart, self._hostname).to_string()
+ ):
+ break
+ suffix += 1
+ logger.info("Allocating mxid for new user with localpart %s", localpart)
+
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart, default_display_name=displayName
+ )
+ await self._datastore.record_user_external_id(
+ self._auth_provider_id, remote_user_id, registered_user_id
+ )
+ return registered_user_id
def expire_sessions(self):
expire_before = self._clock.time_msec() - self._saml2_session_lifetime
|