summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2019-09-25 11:24:22 +0100
committerRichard van der Hoff <richard@matrix.org>2019-09-25 11:24:22 +0100
commitf8d977f3b78977135ec524f86b4366f697d8be15 (patch)
tree18323cb05cb8c7dbddc632ab42e749dfe162bb43
parentremove errant print (diff)
parentfix flake8 fails (diff)
downloadsynapse-f8d977f3b78977135ec524f86b4366f697d8be15.tar.xz
Merge branch 'rav/refactor_ui_auth' into neilj/fix_check_threepid_for_msisdns
-rw-r--r--changelog.d/6037.feature1
-rw-r--r--changelog.d/6069.bugfix1
-rw-r--r--changelog.d/6092.bugfix1
-rw-r--r--changelog.d/6099.misc1
-rw-r--r--changelog.d/6105.misc1
-rw-r--r--docs/sample_config.yaml26
-rw-r--r--synapse/config/saml2_config.py124
-rw-r--r--synapse/handlers/auth.py146
-rw-r--r--synapse/handlers/saml_handler.py106
-rw-r--r--synapse/handlers/ui_auth/__init__.py22
-rw-r--r--synapse/handlers/ui_auth/checkers.py222
-rw-r--r--synapse/rest/client/v1/login.py14
-rw-r--r--synapse/storage/background_updates.py2
-rw-r--r--synapse/storage/registration.py53
-rw-r--r--synapse/storage/schema/delta/56/user_external_ids.sql24
-rw-r--r--synapse/util/module_loader.py20
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py26
17 files changed, 621 insertions, 169 deletions
diff --git a/changelog.d/6037.feature b/changelog.d/6037.feature
new file mode 100644
index 0000000000..85553d2da0
--- /dev/null
+++ b/changelog.d/6037.feature
@@ -0,0 +1 @@
+Make the process for mapping SAML2 users to matrix IDs more flexible.
diff --git a/changelog.d/6069.bugfix b/changelog.d/6069.bugfix
new file mode 100644
index 0000000000..a437ac41a9
--- /dev/null
+++ b/changelog.d/6069.bugfix
@@ -0,0 +1 @@
+Fix a bug which caused SAML attribute maps to be overridden by defaults.
diff --git a/changelog.d/6092.bugfix b/changelog.d/6092.bugfix
new file mode 100644
index 0000000000..01a7498ec6
--- /dev/null
+++ b/changelog.d/6092.bugfix
@@ -0,0 +1 @@
+Fix the logged number of updated items for the users_set_deactivated_flag background update.
diff --git a/changelog.d/6099.misc b/changelog.d/6099.misc
new file mode 100644
index 0000000000..8415c6759b
--- /dev/null
+++ b/changelog.d/6099.misc
@@ -0,0 +1 @@
+Remove unused parameter to get_user_id_by_threepid.
diff --git a/changelog.d/6105.misc b/changelog.d/6105.misc
new file mode 100644
index 0000000000..2e838a35c6
--- /dev/null
+++ b/changelog.d/6105.misc
@@ -0,0 +1 @@
+Refactor the user-interactive auth handling.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 46af6edf1f..da31728037 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1174,6 +1174,32 @@ saml2_config:
   #
   #saml_session_lifetime: 5m
 
+  # The SAML attribute (after mapping via the attribute maps) to use to derive
+  # the Matrix ID from. 'uid' by default.
+  #
+  #mxid_source_attribute: displayName
+
+  # The mapping system to use for mapping the saml attribute onto a matrix ID.
+  # Options include:
+  #  * 'hexencode' (which maps unpermitted characters to '=xx')
+  #  * 'dotreplace' (which replaces unpermitted characters with '.').
+  # The default is 'hexencode'.
+  #
+  #mxid_mapping: dotreplace
+
+  # In previous versions of synapse, the mapping from SAML attribute to MXID was
+  # always calculated dynamically rather than stored in a table. For backwards-
+  # compatibility, we will look for user_ids matching such a pattern before
+  # creating a new account.
+  #
+  # This setting controls the SAML attribute which will be used for this
+  # backwards-compatibility lookup. Typically it should be 'uid', but if the
+  # attribute maps are changed, it may be necessary to change it.
+  #
+  # The default is 'uid'.
+  #
+  #grandfathered_mxid_source_attribute: upn
+
 
 
 # Enable CAS for registration and login.
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index c46ac087db..ab34b41ca8 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,11 +13,47 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import re
+
 from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.types import (
+    map_username_to_mxid_localpart,
+    mxid_localpart_allowed_characters,
+)
+from synapse.util.module_loader import load_python_module
 
 from ._base import Config, ConfigError
 
 
+def _dict_merge(merge_dict, into_dict):
+    """Do a deep merge of two dicts
+
+    Recursively merges `merge_dict` into `into_dict`:
+      * For keys where both `merge_dict` and `into_dict` have a dict value, the values
+        are recursively merged
+      * For all other keys, the values in `into_dict` (if any) are overwritten with
+        the value from `merge_dict`.
+
+    Args:
+        merge_dict (dict): dict to merge
+        into_dict (dict): target dict
+    """
+    for k, v in merge_dict.items():
+        if k not in into_dict:
+            into_dict[k] = v
+            continue
+
+        current_val = into_dict[k]
+
+        if isinstance(v, dict) and isinstance(current_val, dict):
+            _dict_merge(v, current_val)
+            continue
+
+        # otherwise we just overwrite
+        into_dict[k] = v
+
+
 class SAML2Config(Config):
     def read_config(self, config, **kwargs):
         self.saml2_enabled = False
@@ -36,21 +73,40 @@ class SAML2Config(Config):
 
         self.saml2_enabled = True
 
-        import saml2.config
+        self.saml2_mxid_source_attribute = saml2_config.get(
+            "mxid_source_attribute", "uid"
+        )
 
-        self.saml2_sp_config = saml2.config.SPConfig()
-        self.saml2_sp_config.load(self._default_saml_config_dict())
-        self.saml2_sp_config.load(saml2_config.get("sp_config", {}))
+        self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
+            "grandfathered_mxid_source_attribute", "uid"
+        )
+
+        saml2_config_dict = self._default_saml_config_dict()
+        _dict_merge(
+            merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
+        )
 
         config_path = saml2_config.get("config_path", None)
         if config_path is not None:
-            self.saml2_sp_config.load_file(config_path)
+            mod = load_python_module(config_path)
+            _dict_merge(merge_dict=mod.CONFIG, into_dict=saml2_config_dict)
+
+        import saml2.config
+
+        self.saml2_sp_config = saml2.config.SPConfig()
+        self.saml2_sp_config.load(saml2_config_dict)
 
         # session lifetime: in milliseconds
         self.saml2_session_lifetime = self.parse_duration(
             saml2_config.get("saml_session_lifetime", "5m")
         )
 
+        mapping = saml2_config.get("mxid_mapping", "hexencode")
+        try:
+            self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
+        except KeyError:
+            raise ConfigError("%s is not a known mxid_mapping" % (mapping,))
+
     def _default_saml_config_dict(self):
         import saml2
 
@@ -58,6 +114,13 @@ class SAML2Config(Config):
         if public_baseurl is None:
             raise ConfigError("saml2_config requires a public_baseurl to be set")
 
+        required_attributes = {"uid", self.saml2_mxid_source_attribute}
+
+        optional_attributes = {"displayName"}
+        if self.saml2_grandfathered_mxid_source_attribute:
+            optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
+        optional_attributes -= required_attributes
+
         metadata_url = public_baseurl + "_matrix/saml2/metadata.xml"
         response_url = public_baseurl + "_matrix/saml2/authn_response"
         return {
@@ -69,8 +132,9 @@ class SAML2Config(Config):
                             (response_url, saml2.BINDING_HTTP_POST)
                         ]
                     },
-                    "required_attributes": ["uid"],
-                    "optional_attributes": ["mail", "surname", "givenname"],
+                    "required_attributes": list(required_attributes),
+                    "optional_attributes": list(optional_attributes),
+                    # "name_id_format": saml2.saml.NAMEID_FORMAT_PERSISTENT,
                 }
             },
         }
@@ -146,6 +210,52 @@ class SAML2Config(Config):
           # The default is 5 minutes.
           #
           #saml_session_lifetime: 5m
+
+          # The SAML attribute (after mapping via the attribute maps) to use to derive
+          # the Matrix ID from. 'uid' by default.
+          #
+          #mxid_source_attribute: displayName
+
+          # The mapping system to use for mapping the saml attribute onto a matrix ID.
+          # Options include:
+          #  * 'hexencode' (which maps unpermitted characters to '=xx')
+          #  * 'dotreplace' (which replaces unpermitted characters with '.').
+          # The default is 'hexencode'.
+          #
+          #mxid_mapping: dotreplace
+
+          # In previous versions of synapse, the mapping from SAML attribute to MXID was
+          # always calculated dynamically rather than stored in a table. For backwards-
+          # compatibility, we will look for user_ids matching such a pattern before
+          # creating a new account.
+          #
+          # This setting controls the SAML attribute which will be used for this
+          # backwards-compatibility lookup. Typically it should be 'uid', but if the
+          # attribute maps are changed, it may be necessary to change it.
+          #
+          # The default is 'uid'.
+          #
+          #grandfathered_mxid_source_attribute: upn
         """ % {
             "config_dir_path": config_dir_path
         }
+
+
+DOT_REPLACE_PATTERN = re.compile(
+    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+    username = username.lower()
+    username = DOT_REPLACE_PATTERN.sub(".", username)
+
+    # regular mxids aren't allowed to start with an underscore either
+    username = re.sub("^_", "", username)
+    return username
+
+
+MXID_MAPPER_MAP = {
+    "hexencode": map_username_to_mxid_localpart,
+    "dotreplace": dot_replace_for_mxid,
+}
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 56c5ea526b..f920c2f6c1 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -21,10 +21,8 @@ import unicodedata
 import attr
 import bcrypt
 import pymacaroons
-from canonicaljson import json
 
 from twisted.internet import defer
-from twisted.web.client import PartialDownloadError
 
 import synapse.util.stringutils as stringutils
 from synapse.api.constants import LoginType
@@ -38,7 +36,8 @@ from synapse.api.errors import (
     UserDeactivatedError,
 )
 from synapse.api.ratelimiting import Ratelimiter
-from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
+from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.logging.context import defer_to_thread
 from synapse.module_api import ModuleApi
 from synapse.types import UserID
@@ -58,13 +57,12 @@ class AuthHandler(BaseHandler):
             hs (synapse.server.HomeServer):
         """
         super(AuthHandler, self).__init__(hs)
-        self.checkers = {
-            LoginType.RECAPTCHA: self._check_recaptcha,
-            LoginType.EMAIL_IDENTITY: self._check_email_identity,
-            LoginType.MSISDN: self._check_msisdn,
-            LoginType.DUMMY: self._check_dummy_auth,
-            LoginType.TERMS: self._check_terms_auth,
-        }
+
+        self.checkers = {}  # type: dict[str, UserInteractiveAuthChecker]
+        for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
+            inst = auth_checker_class(hs)
+            self.checkers[inst.AUTH_TYPE] = inst
+
         self.bcrypt_rounds = hs.config.bcrypt_rounds
 
         # This is not a cache per se, but a store of all current sessions that
@@ -292,7 +290,7 @@ class AuthHandler(BaseHandler):
             sess["creds"] = {}
         creds = sess["creds"]
 
-        result = yield self.checkers[stagetype](authdict, clientip)
+        result = yield self.checkers[stagetype].check_auth(authdict, clientip)
         if result:
             creds[stagetype] = result
             self._save_session(sess)
@@ -363,7 +361,7 @@ class AuthHandler(BaseHandler):
         login_type = authdict["type"]
         checker = self.checkers.get(login_type)
         if checker is not None:
-            res = yield checker(authdict, clientip=clientip)
+            res = yield checker.check_auth(authdict, clientip=clientip)
             return res
 
         # build a v1-login-style dict out of the authdict and fall back to the
@@ -376,130 +374,6 @@ class AuthHandler(BaseHandler):
         (canonical_id, callback) = yield self.validate_login(user_id, authdict)
         return canonical_id
 
-    @defer.inlineCallbacks
-    def _check_recaptcha(self, authdict, clientip, **kwargs):
-        try:
-            user_response = authdict["response"]
-        except KeyError:
-            # Client tried to provide captcha but didn't give the parameter:
-            # bad request.
-            raise LoginError(
-                400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
-            )
-
-        logger.info(
-            "Submitting recaptcha response %s with remoteip %s", user_response, clientip
-        )
-
-        # TODO: get this from the homeserver rather than creating a new one for
-        # each request
-        try:
-            client = self.hs.get_simple_http_client()
-            resp_body = yield client.post_urlencoded_get_json(
-                self.hs.config.recaptcha_siteverify_api,
-                args={
-                    "secret": self.hs.config.recaptcha_private_key,
-                    "response": user_response,
-                    "remoteip": clientip,
-                },
-            )
-        except PartialDownloadError as pde:
-            # Twisted is silly
-            data = pde.response
-            resp_body = json.loads(data)
-
-        if "success" in resp_body:
-            # Note that we do NOT check the hostname here: we explicitly
-            # intend the CAPTCHA to be presented by whatever client the
-            # user is using, we just care that they have completed a CAPTCHA.
-            logger.info(
-                "%s reCAPTCHA from hostname %s",
-                "Successful" if resp_body["success"] else "Failed",
-                resp_body.get("hostname"),
-            )
-            if resp_body["success"]:
-                return True
-        raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
-
-    def _check_email_identity(self, authdict, **kwargs):
-        return self._check_threepid("email", authdict, **kwargs)
-
-    def _check_msisdn(self, authdict, **kwargs):
-        return self._check_threepid("msisdn", authdict)
-
-    def _check_dummy_auth(self, authdict, **kwargs):
-        return defer.succeed(True)
-
-    def _check_terms_auth(self, authdict, **kwargs):
-        return defer.succeed(True)
-
-    @defer.inlineCallbacks
-    def _check_threepid(self, medium, authdict, **kwargs):
-        if "threepid_creds" not in authdict:
-            raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
-
-        threepid_creds = authdict["threepid_creds"]
-
-        identity_handler = self.hs.get_handlers().identity_handler
-
-        logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
-
-        # msisdns are currently always ThreepidBehaviour.REMOTE
-        if medium == "msisdn":
-            if self.hs.config.account_threepid_delegate_msisdn:
-                threepid = yield identity_handler.threepid_from_creds(
-                    self.hs.config.account_threepid_delegate_msisdn, threepid_creds
-                )
-            else:
-                raise SynapseError(
-                    400, "SMS delegation is not enabled on this homeserver"
-                )
-        elif medium == "email":
-            if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
-                if medium == "email":
-                    threepid = yield identity_handler.threepid_from_creds(
-                        self.hs.config.account_threepid_delegate_email, threepid_creds
-                    )
-            elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
-                row = yield self.store.get_threepid_validation_session(
-                    medium,
-                    threepid_creds["client_secret"],
-                    sid=threepid_creds["sid"],
-                    validated=True,
-                )
-
-                threepid = (
-                    {
-                        "medium": row["medium"],
-                        "address": row["address"],
-                        "validated_at": row["validated_at"],
-                    }
-                    if row
-                    else None
-                )
-
-                if row:
-                    # Valid threepid returned, delete from the db
-                    yield self.store.delete_threepid_session(threepid_creds["sid"])
-            else:
-                raise SynapseError(400, "Email is not enabled on this homeserver")
-        else:
-            raise SynapseError(400, "Unrecognized threepid medium: %s" % (medium,))
-        if not threepid:
-            raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
-
-        if threepid["medium"] != medium:
-            raise LoginError(
-                401,
-                "Expecting threepid of type '%s', got '%s'"
-                % (medium, threepid["medium"]),
-                errcode=Codes.UNAUTHORIZED,
-            )
-
-        threepid["threepid_creds"] = authdict["threepid_creds"]
-
-        return threepid
-
     def _get_params_recaptcha(self):
         return {"public_key": self.hs.config.recaptcha_public_key}
 
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index a1ce6929cf..cc9e6b9bd0 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -21,6 +21,8 @@ from saml2.client import Saml2Client
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import parse_string
 from synapse.rest.client.v1.login import SSOAuthHandler
+from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.util.async_helpers import Linearizer
 
 logger = logging.getLogger(__name__)
 
@@ -29,12 +31,26 @@ class SamlHandler:
     def __init__(self, hs):
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
         self._sso_auth_handler = SSOAuthHandler(hs)
+        self._registration_handler = hs.get_registration_handler()
+
+        self._clock = hs.get_clock()
+        self._datastore = hs.get_datastore()
+        self._hostname = hs.hostname
+        self._saml2_session_lifetime = hs.config.saml2_session_lifetime
+        self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
+        self._grandfathered_mxid_source_attribute = (
+            hs.config.saml2_grandfathered_mxid_source_attribute
+        )
+        self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+        # identifier for the external_ids table
+        self._auth_provider_id = "saml"
 
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}
 
-        self._clock = hs.get_clock()
-        self._saml2_session_lifetime = hs.config.saml2_session_lifetime
+        # a lock on the mappings
+        self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
 
     def handle_redirect_request(self, client_redirect_url):
         """Handle an incoming request to /login/sso/redirect
@@ -60,7 +76,7 @@ class SamlHandler:
         # this shouldn't happen!
         raise Exception("prepare_for_authenticate didn't return a Location header")
 
-    def handle_saml_response(self, request):
+    async def handle_saml_response(self, request):
         """Handle an incoming request to /_matrix/saml2/authn_response
 
         Args:
@@ -77,6 +93,10 @@ class SamlHandler:
         # the dict.
         self.expire_sessions()
 
+        user_id = await self._map_saml_response_to_user(resp_bytes)
+        self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
+
+    async def _map_saml_response_to_user(self, resp_bytes):
         try:
             saml2_auth = self._saml_client.parse_authn_request_response(
                 resp_bytes,
@@ -91,18 +111,88 @@ class SamlHandler:
             logger.warning("SAML2 response was not signed")
             raise SynapseError(400, "SAML2 response was not signed")
 
-        if "uid" not in saml2_auth.ava:
+        logger.info("SAML2 response: %s", saml2_auth.origxml)
+        logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
+
+        try:
+            remote_user_id = saml2_auth.ava["uid"][0]
+        except KeyError:
             logger.warning("SAML2 response lacks a 'uid' attestation")
             raise SynapseError(400, "uid not in SAML2 response")
 
+        try:
+            mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
+        except KeyError:
+            logger.warning(
+                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
+            )
+            raise SynapseError(
+                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+            )
+
         self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
 
-        username = saml2_auth.ava["uid"][0]
         displayName = saml2_auth.ava.get("displayName", [None])[0]
 
-        return self._sso_auth_handler.on_successful_auth(
-            username, request, relay_state, user_display_name=displayName
-        )
+        with (await self._mapping_lock.queue(self._auth_provider_id)):
+            # first of all, check if we already have a mapping for this user
+            logger.info(
+                "Looking for existing mapping for user %s:%s",
+                self._auth_provider_id,
+                remote_user_id,
+            )
+            registered_user_id = await self._datastore.get_user_by_external_id(
+                self._auth_provider_id, remote_user_id
+            )
+            if registered_user_id is not None:
+                logger.info("Found existing mapping %s", registered_user_id)
+                return registered_user_id
+
+            # backwards-compatibility hack: see if there is an existing user with a
+            # suitable mapping from the uid
+            if (
+                self._grandfathered_mxid_source_attribute
+                and self._grandfathered_mxid_source_attribute in saml2_auth.ava
+            ):
+                attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
+                user_id = UserID(
+                    map_username_to_mxid_localpart(attrval), self._hostname
+                ).to_string()
+                logger.info(
+                    "Looking for existing account based on mapped %s %s",
+                    self._grandfathered_mxid_source_attribute,
+                    user_id,
+                )
+
+                users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+                if users:
+                    registered_user_id = list(users.keys())[0]
+                    logger.info("Grandfathering mapping to %s", registered_user_id)
+                    await self._datastore.record_user_external_id(
+                        self._auth_provider_id, remote_user_id, registered_user_id
+                    )
+                    return registered_user_id
+
+            # figure out a new mxid for this user
+            base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+            suffix = 0
+            while True:
+                localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+                if not await self._datastore.get_users_by_id_case_insensitive(
+                    UserID(localpart, self._hostname).to_string()
+                ):
+                    break
+                suffix += 1
+            logger.info("Allocating mxid for new user with localpart %s", localpart)
+
+            registered_user_id = await self._registration_handler.register_user(
+                localpart=localpart, default_display_name=displayName
+            )
+            await self._datastore.record_user_external_id(
+                self._auth_provider_id, remote_user_id, registered_user_id
+            )
+            return registered_user_id
 
     def expire_sessions(self):
         expire_before = self._clock.time_msec() - self._saml2_session_lifetime
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
new file mode 100644
index 0000000000..824f37f8f8
--- /dev/null
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module implements user-interactive auth verification.
+
+TODO: move more stuff out of AuthHandler in here.
+
+"""
+
+from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS  # noqa: F401
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
new file mode 100644
index 0000000000..9b802ea265
--- /dev/null
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -0,0 +1,222 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from canonicaljson import json
+
+from twisted.internet import defer
+from twisted.web.client import PartialDownloadError
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes, LoginError, SynapseError
+from synapse.config.emailconfig import ThreepidBehaviour
+
+logger = logging.getLogger(__name__)
+
+
+class UserInteractiveAuthChecker:
+    """Abstract base class for an interactive auth checker"""
+
+    def __init__(self, hs):
+        pass
+
+    @defer.inlineCallbacks
+    def check_auth(self, authdict, clientip):
+        """Given the authentication dict from the client, attempt to check this step
+
+        Args:
+            authdict (dict): authentication dictionary from the client
+            clientip (str): The IP address of the client.
+
+        Raises:
+            SynapseError if authentication failed
+
+        Returns:
+            Deferred: the result of authentication (to pass back to the client?)
+        """
+        raise NotImplementedError()
+
+
+class DummyAuthChecker(UserInteractiveAuthChecker):
+    AUTH_TYPE = LoginType.DUMMY
+
+    def check_auth(self, authdict, clientip):
+        return defer.succeed(True)
+
+
+class TermsAuthChecker(UserInteractiveAuthChecker):
+    AUTH_TYPE = LoginType.TERMS
+
+    def check_auth(self, authdict, clientip):
+        return defer.succeed(True)
+
+
+class RecaptchaAuthChecker(UserInteractiveAuthChecker):
+    AUTH_TYPE = LoginType.RECAPTCHA
+
+    def __init__(self, hs):
+        super().__init__(hs)
+        self._http_client = hs.get_simple_http_client()
+        self._url = hs.config.recaptcha_siteverify_api
+        self._secret = hs.config.recaptcha_private_key
+
+    @defer.inlineCallbacks
+    def check_auth(self, authdict, clientip):
+        try:
+            user_response = authdict["response"]
+        except KeyError:
+            # Client tried to provide captcha but didn't give the parameter:
+            # bad request.
+            raise LoginError(
+                400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
+            )
+
+        logger.info(
+            "Submitting recaptcha response %s with remoteip %s", user_response, clientip
+        )
+
+        # TODO: get this from the homeserver rather than creating a new one for
+        # each request
+        try:
+            resp_body = yield self._http_client.post_urlencoded_get_json(
+                self._url,
+                args={
+                    "secret": self._secret,
+                    "response": user_response,
+                    "remoteip": clientip,
+                },
+            )
+        except PartialDownloadError as pde:
+            # Twisted is silly
+            data = pde.response
+            resp_body = json.loads(data)
+
+        if "success" in resp_body:
+            # Note that we do NOT check the hostname here: we explicitly
+            # intend the CAPTCHA to be presented by whatever client the
+            # user is using, we just care that they have completed a CAPTCHA.
+            logger.info(
+                "%s reCAPTCHA from hostname %s",
+                "Successful" if resp_body["success"] else "Failed",
+                resp_body.get("hostname"),
+            )
+            if resp_body["success"]:
+                return True
+        raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+
+
+class _BaseThreepidAuthChecker:
+    def __init__(self, hs):
+        self.hs = hs
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def _check_threepid(self, medium, authdict, **kwargs):
+        if "threepid_creds" not in authdict:
+            raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
+
+        threepid_creds = authdict["threepid_creds"]
+
+        identity_handler = self.hs.get_handlers().identity_handler
+
+        logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
+
+        # msisdns are currently always ThreepidBehaviour.REMOTE
+        if medium == "msisdn":
+            if self.hs.config.account_threepid_delegate_msisdn:
+                threepid = yield identity_handler.threepid_from_creds(
+                    self.hs.config.account_threepid_delegate_msisdn, threepid_creds
+                )
+            else:
+                raise SynapseError(
+                    400, "SMS delegation is not enabled on this homeserver"
+                )
+        elif medium == "email":
+            if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+                if medium == "email":
+                    threepid = yield identity_handler.threepid_from_creds(
+                        self.hs.config.account_threepid_delegate_email, threepid_creds
+                    )
+            elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+                row = yield self.store.get_threepid_validation_session(
+                    medium,
+                    threepid_creds["client_secret"],
+                    sid=threepid_creds["sid"],
+                    validated=True,
+                )
+
+                threepid = (
+                    {
+                        "medium": row["medium"],
+                        "address": row["address"],
+                        "validated_at": row["validated_at"],
+                    }
+                    if row
+                    else None
+                )
+
+                if row:
+                    # Valid threepid returned, delete from the db
+                    yield self.store.delete_threepid_session(threepid_creds["sid"])
+            else:
+                raise SynapseError(400, "Email is not enabled on this homeserver")
+        else:
+            raise SynapseError(400, "Unrecognized threepid medium: %s" % (medium,))
+        if not threepid:
+            raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+
+        if threepid["medium"] != medium:
+            raise LoginError(
+                401,
+                "Expecting threepid of type '%s', got '%s'"
+                % (medium, threepid["medium"]),
+                errcode=Codes.UNAUTHORIZED,
+            )
+
+        threepid["threepid_creds"] = authdict["threepid_creds"]
+
+        return threepid
+
+
+class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
+    AUTH_TYPE = LoginType.EMAIL_IDENTITY
+
+    def __init__(self, hs):
+        UserInteractiveAuthChecker.__init__(self, hs)
+        _BaseThreepidAuthChecker.__init__(self, hs)
+
+    def check_auth(self, authdict, clientip):
+        return self._check_threepid("email", authdict)
+
+
+class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
+    AUTH_TYPE = LoginType.MSISDN
+
+    def __init__(self, hs):
+        UserInteractiveAuthChecker.__init__(self, hs)
+        _BaseThreepidAuthChecker.__init__(self, hs)
+
+    def check_auth(self, authdict, clientip):
+        return self._check_threepid("msisdn", authdict)
+
+
+INTERACTIVE_AUTH_CHECKERS = [
+    DummyAuthChecker,
+    TermsAuthChecker,
+    RecaptchaAuthChecker,
+    EmailIdentityAuthChecker,
+    MsisdnAuthChecker,
+]
+"""A list of UserInteractiveAuthChecker classes"""
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 25a1b67092..9cddbc752a 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -29,6 +29,7 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
     parse_string,
 )
+from synapse.http.site import SynapseRequest
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.well_known import WellKnownBuilder
 from synapse.types import UserID, map_username_to_mxid_localpart
@@ -507,6 +508,19 @@ class SSOAuthHandler(object):
                 localpart=localpart, default_display_name=user_display_name
             )
 
+        self.complete_sso_login(registered_user_id, request, client_redirect_url)
+
+    def complete_sso_login(
+        self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
+    ):
+        """Having figured out a mxid for this user, complete the HTTP request
+
+        Args:
+            registered_user_id:
+            request:
+            client_redirect_url:
+        """
+
         login_token = self._macaroon_gen.generate_short_term_login_token(
             registered_user_id
         )
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index e5f0668f09..9522acd972 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -218,7 +218,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         duration_ms = time_stop - time_start
 
         logger.info(
-            "Updating %r. Updated %r items in %rms."
+            "Running background update %r. Processed %r items in %rms."
             " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
             update_name,
             items_updated,
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 805411a6b2..241a7be51e 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -22,6 +22,7 @@ from six import iterkeys
 from six.moves import range
 
 from twisted.internet import defer
+from twisted.internet.defer import Deferred
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -384,6 +385,26 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         return self.runInteraction("get_users_by_id_case_insensitive", f)
 
+    async def get_user_by_external_id(
+        self, auth_provider: str, external_id: str
+    ) -> str:
+        """Look up a user by their external auth id
+
+        Args:
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+
+        Returns:
+            str|None: the mxid of the user, or None if they are not known
+        """
+        return await self._simple_select_one_onecol(
+            table="user_external_ids",
+            keyvalues={"auth_provider": auth_provider, "external_id": external_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="get_user_by_external_id",
+        )
+
     @defer.inlineCallbacks
     def count_all_users(self):
         """Counts all users registered on the homeserver."""
@@ -495,7 +516,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def get_user_id_by_threepid(self, medium, address, require_verified=False):
+    def get_user_id_by_threepid(self, medium, address):
         """Returns user id from threepid
 
         Args:
@@ -844,7 +865,7 @@ class RegistrationStore(
             rows = self.cursor_to_dict(txn)
 
             if not rows:
-                return True
+                return True, 0
 
             rows_processed_nb = 0
 
@@ -860,18 +881,18 @@ class RegistrationStore(
             )
 
             if batch_size > len(rows):
-                return True
+                return True, len(rows)
             else:
-                return False
+                return False, len(rows)
 
-        end = yield self.runInteraction(
+        end, nb_processed = yield self.runInteraction(
             "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
         )
 
         if end:
             yield self._end_background_update("users_set_deactivated_flag")
 
-        return batch_size
+        return nb_processed
 
     @defer.inlineCallbacks
     def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
@@ -1032,6 +1053,26 @@ class RegistrationStore(
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
+    def record_user_external_id(
+        self, auth_provider: str, external_id: str, user_id: str
+    ) -> Deferred:
+        """Record a mapping from an external user id to a mxid
+
+        Args:
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+            user_id: complete mxid that it is mapped to
+        """
+        return self._simple_insert(
+            table="user_external_ids",
+            values={
+                "auth_provider": auth_provider,
+                "external_id": external_id,
+                "user_id": user_id,
+            },
+            desc="record_user_external_id",
+        )
+
     def user_set_password_hash(self, user_id, password_hash):
         """
         NB. This does *not* evict any cache because the one use for this
diff --git a/synapse/storage/schema/delta/56/user_external_ids.sql b/synapse/storage/schema/delta/56/user_external_ids.sql
new file mode 100644
index 0000000000..91390c4527
--- /dev/null
+++ b/synapse/storage/schema/delta/56/user_external_ids.sql
@@ -0,0 +1,24 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * a table which records mappings from external auth providers to mxids
+ */
+CREATE TABLE IF NOT EXISTS user_external_ids (
+    auth_provider TEXT NOT NULL,
+    external_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    UNIQUE (auth_provider, external_id)
+);
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 522acd5aa8..7ff7eb1e4d 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -14,12 +14,13 @@
 # limitations under the License.
 
 import importlib
+import importlib.util
 
 from synapse.config._base import ConfigError
 
 
 def load_module(provider):
-    """ Loads a module with its config
+    """ Loads a synapse module with its config
     Take a dict with keys 'module' (the module name) and 'config'
     (the config dict).
 
@@ -38,3 +39,20 @@ def load_module(provider):
         raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
 
     return provider_class, provider_config
+
+
+def load_python_module(location: str):
+    """Load a python module, and return a reference to its global namespace
+
+    Args:
+        location (str): path to the module
+
+    Returns:
+        python module object
+    """
+    spec = importlib.util.spec_from_file_location(location, location)
+    if spec is None:
+        raise Exception("Unable to load module at %s" % (location,))
+    mod = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(mod)
+    return mod
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index b9ef46e8fb..b6df1396ad 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -18,11 +18,22 @@ from twisted.internet.defer import succeed
 
 import synapse.rest.admin
 from synapse.api.constants import LoginType
+from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.rest.client.v2_alpha import auth, register
 
 from tests import unittest
 
 
+class DummyRecaptchaChecker(UserInteractiveAuthChecker):
+    def __init__(self, hs):
+        super().__init__(hs)
+        self.recaptcha_attempts = []
+
+    def check_auth(self, authdict, clientip):
+        self.recaptcha_attempts.append((authdict, clientip))
+        return succeed(True)
+
+
 class FallbackAuthTests(unittest.HomeserverTestCase):
 
     servlets = [
@@ -44,15 +55,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         return hs
 
     def prepare(self, reactor, clock, hs):
+        self.recaptcha_checker = DummyRecaptchaChecker(hs)
         auth_handler = hs.get_auth_handler()
-
-        self.recaptcha_attempts = []
-
-        def _recaptcha(authdict, clientip):
-            self.recaptcha_attempts.append((authdict, clientip))
-            return succeed(True)
-
-        auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha
+        auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
 
     @unittest.INFO
     def test_fallback_captcha(self):
@@ -89,8 +94,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         self.assertEqual(request.code, 200)
 
         # The recaptcha handler is called with the response given
-        self.assertEqual(len(self.recaptcha_attempts), 1)
-        self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
+        attempts = self.recaptcha_checker.recaptcha_attempts
+        self.assertEqual(len(attempts), 1)
+        self.assertEqual(attempts[0][0]["response"], "a")
 
         # also complete the dummy auth
         request, channel = self.make_request(