diff options
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r-- | synapse/handlers/auth.py | 173 |
1 files changed, 26 insertions, 147 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0f3ebf7ef8..333eb30625 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -21,10 +21,8 @@ import unicodedata import attr import bcrypt import pymacaroons -from canonicaljson import json from twisted.internet import defer -from twisted.web.client import PartialDownloadError import synapse.util.stringutils as stringutils from synapse.api.constants import LoginType @@ -38,6 +36,8 @@ from synapse.api.errors import ( UserDeactivatedError, ) from synapse.api.ratelimiting import Ratelimiter +from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS +from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.logging.context import defer_to_thread from synapse.module_api import ModuleApi from synapse.types import UserID @@ -57,13 +57,13 @@ class AuthHandler(BaseHandler): hs (synapse.server.HomeServer): """ super(AuthHandler, self).__init__(hs) - self.checkers = { - LoginType.RECAPTCHA: self._check_recaptcha, - LoginType.EMAIL_IDENTITY: self._check_email_identity, - LoginType.MSISDN: self._check_msisdn, - LoginType.DUMMY: self._check_dummy_auth, - LoginType.TERMS: self._check_terms_auth, - } + + self.checkers = {} # type: dict[str, UserInteractiveAuthChecker] + for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: + inst = auth_checker_class(hs) + if inst.is_enabled(): + self.checkers[inst.AUTH_TYPE] = inst + self.bcrypt_rounds = hs.config.bcrypt_rounds # This is not a cache per se, but a store of all current sessions that @@ -157,8 +157,16 @@ class AuthHandler(BaseHandler): return params + def get_enabled_auth_types(self): + """Return the enabled user-interactive authentication types + + Returns the UI-Auth types which are supported by the homeserver's current + config. + """ + return self.checkers.keys() + @defer.inlineCallbacks - def check_auth(self, flows, clientdict, clientip, password_servlet=False): + def check_auth(self, flows, clientdict, clientip): """ Takes a dictionary sent by the client in the login / registration protocol and handles the User-Interactive Auth flow. @@ -182,16 +190,6 @@ class AuthHandler(BaseHandler): clientip (str): The IP address of the client. - password_servlet (bool): Whether the request originated from - PasswordRestServlet. - XXX: This is a temporary hack to distinguish between checking - for threepid validations locally (in the case of password - resets) and using the identity server (in the case of binding - a 3PID during registration). Once we start using the - homeserver for both tasks, this distinction will no longer be - necessary. - - Returns: defer.Deferred[dict, dict, str]: a deferred tuple of (creds, params, session_id). @@ -247,9 +245,7 @@ class AuthHandler(BaseHandler): if "type" in authdict: login_type = authdict["type"] try: - result = yield self._check_auth_dict( - authdict, clientip, password_servlet=password_servlet - ) + result = yield self._check_auth_dict(authdict, clientip) if result: creds[login_type] = result self._save_session(session) @@ -280,7 +276,7 @@ class AuthHandler(BaseHandler): creds, list(clientdict), ) - return (creds, clientdict, session["id"]) + return creds, clientdict, session["id"] ret = self._auth_dict_for_flows(flows, session) ret["completed"] = list(creds) @@ -303,7 +299,7 @@ class AuthHandler(BaseHandler): sess["creds"] = {} creds = sess["creds"] - result = yield self.checkers[stagetype](authdict, clientip) + result = yield self.checkers[stagetype].check_auth(authdict, clientip) if result: creds[stagetype] = result self._save_session(sess) @@ -356,7 +352,7 @@ class AuthHandler(BaseHandler): return sess.setdefault("serverdict", {}).get(key, default) @defer.inlineCallbacks - def _check_auth_dict(self, authdict, clientip, password_servlet=False): + def _check_auth_dict(self, authdict, clientip): """Attempt to validate the auth dict provided by a client Args: @@ -374,11 +370,7 @@ class AuthHandler(BaseHandler): login_type = authdict["type"] checker = self.checkers.get(login_type) if checker is not None: - # XXX: Temporary workaround for having Synapse handle password resets - # See AuthHandler.check_auth for further details - res = yield checker( - authdict, clientip=clientip, password_servlet=password_servlet - ) + res = yield checker.check_auth(authdict, clientip=clientip) return res # build a v1-login-style dict out of the authdict and fall back to the @@ -391,119 +383,6 @@ class AuthHandler(BaseHandler): (canonical_id, callback) = yield self.validate_login(user_id, authdict) return canonical_id - @defer.inlineCallbacks - def _check_recaptcha(self, authdict, clientip, **kwargs): - try: - user_response = authdict["response"] - except KeyError: - # Client tried to provide captcha but didn't give the parameter: - # bad request. - raise LoginError( - 400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED - ) - - logger.info( - "Submitting recaptcha response %s with remoteip %s", user_response, clientip - ) - - # TODO: get this from the homeserver rather than creating a new one for - # each request - try: - client = self.hs.get_simple_http_client() - resp_body = yield client.post_urlencoded_get_json( - self.hs.config.recaptcha_siteverify_api, - args={ - "secret": self.hs.config.recaptcha_private_key, - "response": user_response, - "remoteip": clientip, - }, - ) - except PartialDownloadError as pde: - # Twisted is silly - data = pde.response - resp_body = json.loads(data) - - if "success" in resp_body: - # Note that we do NOT check the hostname here: we explicitly - # intend the CAPTCHA to be presented by whatever client the - # user is using, we just care that they have completed a CAPTCHA. - logger.info( - "%s reCAPTCHA from hostname %s", - "Successful" if resp_body["success"] else "Failed", - resp_body.get("hostname"), - ) - if resp_body["success"]: - return True - raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) - - def _check_email_identity(self, authdict, **kwargs): - return self._check_threepid("email", authdict, **kwargs) - - def _check_msisdn(self, authdict, **kwargs): - return self._check_threepid("msisdn", authdict) - - def _check_dummy_auth(self, authdict, **kwargs): - return defer.succeed(True) - - def _check_terms_auth(self, authdict, **kwargs): - return defer.succeed(True) - - @defer.inlineCallbacks - def _check_threepid(self, medium, authdict, password_servlet=False, **kwargs): - if "threepid_creds" not in authdict: - raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) - - threepid_creds = authdict["threepid_creds"] - - identity_handler = self.hs.get_handlers().identity_handler - - logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,)) - if ( - not password_servlet - or self.hs.config.email_password_reset_behaviour == "remote" - ): - threepid = yield identity_handler.threepid_from_creds(threepid_creds) - elif self.hs.config.email_password_reset_behaviour == "local": - row = yield self.store.get_threepid_validation_session( - medium, - threepid_creds["client_secret"], - sid=threepid_creds["sid"], - validated=True, - ) - - threepid = ( - { - "medium": row["medium"], - "address": row["address"], - "validated_at": row["validated_at"], - } - if row - else None - ) - - if row: - # Valid threepid returned, delete from the db - yield self.store.delete_threepid_session(threepid_creds["sid"]) - else: - raise SynapseError( - 400, "Password resets are not enabled on this homeserver" - ) - - if not threepid: - raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) - - if threepid["medium"] != medium: - raise LoginError( - 401, - "Expecting threepid of type '%s', got '%s'" - % (medium, threepid["medium"]), - errcode=Codes.UNAUTHORIZED, - ) - - threepid["threepid_creds"] = authdict["threepid_creds"] - - return threepid - def _get_params_recaptcha(self): return {"public_key": self.hs.config.recaptcha_public_key} @@ -722,7 +601,7 @@ class AuthHandler(BaseHandler): known_login_type = True is_valid = yield provider.check_password(qualified_user_id, password) if is_valid: - return (qualified_user_id, None) + return qualified_user_id, None if not hasattr(provider, "get_supported_login_types") or not hasattr( provider, "check_auth" @@ -766,7 +645,7 @@ class AuthHandler(BaseHandler): ) if canonical_user_id: - return (canonical_user_id, None) + return canonical_user_id, None if not known_login_type: raise SynapseError(400, "Unknown login type %s" % login_type) @@ -816,7 +695,7 @@ class AuthHandler(BaseHandler): result = (result, None) return result - return (None, None) + return None, None @defer.inlineCallbacks def _check_local_password(self, user_id, password): |