summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py173
1 files changed, 26 insertions, 147 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 0f3ebf7ef8..333eb30625 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -21,10 +21,8 @@ import unicodedata
 import attr
 import bcrypt
 import pymacaroons
-from canonicaljson import json
 
 from twisted.internet import defer
-from twisted.web.client import PartialDownloadError
 
 import synapse.util.stringutils as stringutils
 from synapse.api.constants import LoginType
@@ -38,6 +36,8 @@ from synapse.api.errors import (
     UserDeactivatedError,
 )
 from synapse.api.ratelimiting import Ratelimiter
+from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
+from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.logging.context import defer_to_thread
 from synapse.module_api import ModuleApi
 from synapse.types import UserID
@@ -57,13 +57,13 @@ class AuthHandler(BaseHandler):
             hs (synapse.server.HomeServer):
         """
         super(AuthHandler, self).__init__(hs)
-        self.checkers = {
-            LoginType.RECAPTCHA: self._check_recaptcha,
-            LoginType.EMAIL_IDENTITY: self._check_email_identity,
-            LoginType.MSISDN: self._check_msisdn,
-            LoginType.DUMMY: self._check_dummy_auth,
-            LoginType.TERMS: self._check_terms_auth,
-        }
+
+        self.checkers = {}  # type: dict[str, UserInteractiveAuthChecker]
+        for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
+            inst = auth_checker_class(hs)
+            if inst.is_enabled():
+                self.checkers[inst.AUTH_TYPE] = inst
+
         self.bcrypt_rounds = hs.config.bcrypt_rounds
 
         # This is not a cache per se, but a store of all current sessions that
@@ -157,8 +157,16 @@ class AuthHandler(BaseHandler):
 
         return params
 
+    def get_enabled_auth_types(self):
+        """Return the enabled user-interactive authentication types
+
+        Returns the UI-Auth types which are supported by the homeserver's current
+        config.
+        """
+        return self.checkers.keys()
+
     @defer.inlineCallbacks
-    def check_auth(self, flows, clientdict, clientip, password_servlet=False):
+    def check_auth(self, flows, clientdict, clientip):
         """
         Takes a dictionary sent by the client in the login / registration
         protocol and handles the User-Interactive Auth flow.
@@ -182,16 +190,6 @@ class AuthHandler(BaseHandler):
 
             clientip (str): The IP address of the client.
 
-            password_servlet (bool): Whether the request originated from
-                PasswordRestServlet.
-                XXX: This is a temporary hack to distinguish between checking
-                for threepid validations locally (in the case of password
-                resets) and using the identity server (in the case of binding
-                a 3PID during registration). Once we start using the
-                homeserver for both tasks, this distinction will no longer be
-                necessary.
-
-
         Returns:
             defer.Deferred[dict, dict, str]: a deferred tuple of
                 (creds, params, session_id).
@@ -247,9 +245,7 @@ class AuthHandler(BaseHandler):
         if "type" in authdict:
             login_type = authdict["type"]
             try:
-                result = yield self._check_auth_dict(
-                    authdict, clientip, password_servlet=password_servlet
-                )
+                result = yield self._check_auth_dict(authdict, clientip)
                 if result:
                     creds[login_type] = result
                     self._save_session(session)
@@ -280,7 +276,7 @@ class AuthHandler(BaseHandler):
                     creds,
                     list(clientdict),
                 )
-                return (creds, clientdict, session["id"])
+                return creds, clientdict, session["id"]
 
         ret = self._auth_dict_for_flows(flows, session)
         ret["completed"] = list(creds)
@@ -303,7 +299,7 @@ class AuthHandler(BaseHandler):
             sess["creds"] = {}
         creds = sess["creds"]
 
-        result = yield self.checkers[stagetype](authdict, clientip)
+        result = yield self.checkers[stagetype].check_auth(authdict, clientip)
         if result:
             creds[stagetype] = result
             self._save_session(sess)
@@ -356,7 +352,7 @@ class AuthHandler(BaseHandler):
         return sess.setdefault("serverdict", {}).get(key, default)
 
     @defer.inlineCallbacks
-    def _check_auth_dict(self, authdict, clientip, password_servlet=False):
+    def _check_auth_dict(self, authdict, clientip):
         """Attempt to validate the auth dict provided by a client
 
         Args:
@@ -374,11 +370,7 @@ class AuthHandler(BaseHandler):
         login_type = authdict["type"]
         checker = self.checkers.get(login_type)
         if checker is not None:
-            # XXX: Temporary workaround for having Synapse handle password resets
-            # See AuthHandler.check_auth for further details
-            res = yield checker(
-                authdict, clientip=clientip, password_servlet=password_servlet
-            )
+            res = yield checker.check_auth(authdict, clientip=clientip)
             return res
 
         # build a v1-login-style dict out of the authdict and fall back to the
@@ -391,119 +383,6 @@ class AuthHandler(BaseHandler):
         (canonical_id, callback) = yield self.validate_login(user_id, authdict)
         return canonical_id
 
-    @defer.inlineCallbacks
-    def _check_recaptcha(self, authdict, clientip, **kwargs):
-        try:
-            user_response = authdict["response"]
-        except KeyError:
-            # Client tried to provide captcha but didn't give the parameter:
-            # bad request.
-            raise LoginError(
-                400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
-            )
-
-        logger.info(
-            "Submitting recaptcha response %s with remoteip %s", user_response, clientip
-        )
-
-        # TODO: get this from the homeserver rather than creating a new one for
-        # each request
-        try:
-            client = self.hs.get_simple_http_client()
-            resp_body = yield client.post_urlencoded_get_json(
-                self.hs.config.recaptcha_siteverify_api,
-                args={
-                    "secret": self.hs.config.recaptcha_private_key,
-                    "response": user_response,
-                    "remoteip": clientip,
-                },
-            )
-        except PartialDownloadError as pde:
-            # Twisted is silly
-            data = pde.response
-            resp_body = json.loads(data)
-
-        if "success" in resp_body:
-            # Note that we do NOT check the hostname here: we explicitly
-            # intend the CAPTCHA to be presented by whatever client the
-            # user is using, we just care that they have completed a CAPTCHA.
-            logger.info(
-                "%s reCAPTCHA from hostname %s",
-                "Successful" if resp_body["success"] else "Failed",
-                resp_body.get("hostname"),
-            )
-            if resp_body["success"]:
-                return True
-        raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
-
-    def _check_email_identity(self, authdict, **kwargs):
-        return self._check_threepid("email", authdict, **kwargs)
-
-    def _check_msisdn(self, authdict, **kwargs):
-        return self._check_threepid("msisdn", authdict)
-
-    def _check_dummy_auth(self, authdict, **kwargs):
-        return defer.succeed(True)
-
-    def _check_terms_auth(self, authdict, **kwargs):
-        return defer.succeed(True)
-
-    @defer.inlineCallbacks
-    def _check_threepid(self, medium, authdict, password_servlet=False, **kwargs):
-        if "threepid_creds" not in authdict:
-            raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
-
-        threepid_creds = authdict["threepid_creds"]
-
-        identity_handler = self.hs.get_handlers().identity_handler
-
-        logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
-        if (
-            not password_servlet
-            or self.hs.config.email_password_reset_behaviour == "remote"
-        ):
-            threepid = yield identity_handler.threepid_from_creds(threepid_creds)
-        elif self.hs.config.email_password_reset_behaviour == "local":
-            row = yield self.store.get_threepid_validation_session(
-                medium,
-                threepid_creds["client_secret"],
-                sid=threepid_creds["sid"],
-                validated=True,
-            )
-
-            threepid = (
-                {
-                    "medium": row["medium"],
-                    "address": row["address"],
-                    "validated_at": row["validated_at"],
-                }
-                if row
-                else None
-            )
-
-            if row:
-                # Valid threepid returned, delete from the db
-                yield self.store.delete_threepid_session(threepid_creds["sid"])
-        else:
-            raise SynapseError(
-                400, "Password resets are not enabled on this homeserver"
-            )
-
-        if not threepid:
-            raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
-
-        if threepid["medium"] != medium:
-            raise LoginError(
-                401,
-                "Expecting threepid of type '%s', got '%s'"
-                % (medium, threepid["medium"]),
-                errcode=Codes.UNAUTHORIZED,
-            )
-
-        threepid["threepid_creds"] = authdict["threepid_creds"]
-
-        return threepid
-
     def _get_params_recaptcha(self):
         return {"public_key": self.hs.config.recaptcha_public_key}
 
@@ -722,7 +601,7 @@ class AuthHandler(BaseHandler):
                 known_login_type = True
                 is_valid = yield provider.check_password(qualified_user_id, password)
                 if is_valid:
-                    return (qualified_user_id, None)
+                    return qualified_user_id, None
 
             if not hasattr(provider, "get_supported_login_types") or not hasattr(
                 provider, "check_auth"
@@ -766,7 +645,7 @@ class AuthHandler(BaseHandler):
             )
 
             if canonical_user_id:
-                return (canonical_user_id, None)
+                return canonical_user_id, None
 
         if not known_login_type:
             raise SynapseError(400, "Unknown login type %s" % login_type)
@@ -816,7 +695,7 @@ class AuthHandler(BaseHandler):
                         result = (result, None)
                     return result
 
-        return (None, None)
+        return None, None
 
     @defer.inlineCallbacks
     def _check_local_password(self, user_id, password):