summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py266
1 files changed, 121 insertions, 145 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index a0cf37a9f9..97b21c4093 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -134,13 +134,9 @@ class AuthHandler(BaseHandler):
         """
 
         # build a list of supported flows
-        flows = [
-            [login_type] for login_type in self._supported_login_types
-        ]
+        flows = [[login_type] for login_type in self._supported_login_types]
 
-        result, params, _ = yield self.check_auth(
-            flows, request_body, clientip,
-        )
+        result, params, _ = yield self.check_auth(flows, request_body, clientip)
 
         # find the completed login type
         for login_type in self._supported_login_types:
@@ -151,9 +147,7 @@ class AuthHandler(BaseHandler):
             break
         else:
             # this can't happen
-            raise Exception(
-                "check_auth returned True but no successful login type",
-            )
+            raise Exception("check_auth returned True but no successful login type")
 
         # check that the UI auth matched the access token
         if user_id != requester.user.to_string():
@@ -215,11 +209,11 @@ class AuthHandler(BaseHandler):
 
         authdict = None
         sid = None
-        if clientdict and 'auth' in clientdict:
-            authdict = clientdict['auth']
-            del clientdict['auth']
-            if 'session' in authdict:
-                sid = authdict['session']
+        if clientdict and "auth" in clientdict:
+            authdict = clientdict["auth"]
+            del clientdict["auth"]
+            if "session" in authdict:
+                sid = authdict["session"]
         session = self._get_session_info(sid)
 
         if len(clientdict) > 0:
@@ -232,27 +226,27 @@ class AuthHandler(BaseHandler):
             # on a home server.
             # Revisit: Assumimg the REST APIs do sensible validation, the data
             # isn't arbintrary.
-            session['clientdict'] = clientdict
+            session["clientdict"] = clientdict
             self._save_session(session)
-        elif 'clientdict' in session:
-            clientdict = session['clientdict']
+        elif "clientdict" in session:
+            clientdict = session["clientdict"]
 
         if not authdict:
             raise InteractiveAuthIncompleteError(
-                self._auth_dict_for_flows(flows, session),
+                self._auth_dict_for_flows(flows, session)
             )
 
-        if 'creds' not in session:
-            session['creds'] = {}
-        creds = session['creds']
+        if "creds" not in session:
+            session["creds"] = {}
+        creds = session["creds"]
 
         # check auth type currently being presented
         errordict = {}
-        if 'type' in authdict:
-            login_type = authdict['type']
+        if "type" in authdict:
+            login_type = authdict["type"]
             try:
                 result = yield self._check_auth_dict(
-                    authdict, clientip, password_servlet=password_servlet,
+                    authdict, clientip, password_servlet=password_servlet
                 )
                 if result:
                     creds[login_type] = result
@@ -281,16 +275,15 @@ class AuthHandler(BaseHandler):
                 # and is not sensitive).
                 logger.info(
                     "Auth completed with creds: %r. Client dict has keys: %r",
-                    creds, list(clientdict)
+                    creds,
+                    list(clientdict),
                 )
-                defer.returnValue((creds, clientdict, session['id']))
+                defer.returnValue((creds, clientdict, session["id"]))
 
         ret = self._auth_dict_for_flows(flows, session)
-        ret['completed'] = list(creds)
+        ret["completed"] = list(creds)
         ret.update(errordict)
-        raise InteractiveAuthIncompleteError(
-            ret,
-        )
+        raise InteractiveAuthIncompleteError(ret)
 
     @defer.inlineCallbacks
     def add_oob_auth(self, stagetype, authdict, clientip):
@@ -300,15 +293,13 @@ class AuthHandler(BaseHandler):
         """
         if stagetype not in self.checkers:
             raise LoginError(400, "", Codes.MISSING_PARAM)
-        if 'session' not in authdict:
+        if "session" not in authdict:
             raise LoginError(400, "", Codes.MISSING_PARAM)
 
-        sess = self._get_session_info(
-            authdict['session']
-        )
-        if 'creds' not in sess:
-            sess['creds'] = {}
-        creds = sess['creds']
+        sess = self._get_session_info(authdict["session"])
+        if "creds" not in sess:
+            sess["creds"] = {}
+        creds = sess["creds"]
 
         result = yield self.checkers[stagetype](authdict, clientip)
         if result:
@@ -329,10 +320,10 @@ class AuthHandler(BaseHandler):
                 not send a session ID, returns None.
         """
         sid = None
-        if clientdict and 'auth' in clientdict:
-            authdict = clientdict['auth']
-            if 'session' in authdict:
-                sid = authdict['session']
+        if clientdict and "auth" in clientdict:
+            authdict = clientdict["auth"]
+            if "session" in authdict:
+                sid = authdict["session"]
         return sid
 
     def set_session_data(self, session_id, key, value):
@@ -347,7 +338,7 @@ class AuthHandler(BaseHandler):
             value (any): The data to store
         """
         sess = self._get_session_info(session_id)
-        sess.setdefault('serverdict', {})[key] = value
+        sess.setdefault("serverdict", {})[key] = value
         self._save_session(sess)
 
     def get_session_data(self, session_id, key, default=None):
@@ -360,7 +351,7 @@ class AuthHandler(BaseHandler):
             default (any): Value to return if the key has not been set
         """
         sess = self._get_session_info(session_id)
-        return sess.setdefault('serverdict', {}).get(key, default)
+        return sess.setdefault("serverdict", {}).get(key, default)
 
     @defer.inlineCallbacks
     def _check_auth_dict(self, authdict, clientip, password_servlet=False):
@@ -378,15 +369,13 @@ class AuthHandler(BaseHandler):
             SynapseError if there was a problem with the request
             LoginError if there was an authentication problem.
         """
-        login_type = authdict['type']
+        login_type = authdict["type"]
         checker = self.checkers.get(login_type)
         if checker is not None:
             # XXX: Temporary workaround for having Synapse handle password resets
             # See AuthHandler.check_auth for further details
             res = yield checker(
-                authdict,
-                clientip=clientip,
-                password_servlet=password_servlet,
+                authdict, clientip=clientip, password_servlet=password_servlet
             )
             defer.returnValue(res)
 
@@ -408,13 +397,11 @@ class AuthHandler(BaseHandler):
             # Client tried to provide captcha but didn't give the parameter:
             # bad request.
             raise LoginError(
-                400, "Captcha response is required",
-                errcode=Codes.CAPTCHA_NEEDED
+                400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
             )
 
         logger.info(
-            "Submitting recaptcha response %s with remoteip %s",
-            user_response, clientip
+            "Submitting recaptcha response %s with remoteip %s", user_response, clientip
         )
 
         # TODO: get this from the homeserver rather than creating a new one for
@@ -424,34 +411,34 @@ class AuthHandler(BaseHandler):
             resp_body = yield client.post_urlencoded_get_json(
                 self.hs.config.recaptcha_siteverify_api,
                 args={
-                    'secret': self.hs.config.recaptcha_private_key,
-                    'response': user_response,
-                    'remoteip': clientip,
-                }
+                    "secret": self.hs.config.recaptcha_private_key,
+                    "response": user_response,
+                    "remoteip": clientip,
+                },
             )
         except PartialDownloadError as pde:
             # Twisted is silly
             data = pde.response
             resp_body = json.loads(data)
 
-        if 'success' in resp_body:
+        if "success" in resp_body:
             # Note that we do NOT check the hostname here: we explicitly
             # intend the CAPTCHA to be presented by whatever client the
             # user is using, we just care that they have completed a CAPTCHA.
             logger.info(
                 "%s reCAPTCHA from hostname %s",
-                "Successful" if resp_body['success'] else "Failed",
-                resp_body.get('hostname')
+                "Successful" if resp_body["success"] else "Failed",
+                resp_body.get("hostname"),
             )
-            if resp_body['success']:
+            if resp_body["success"]:
                 defer.returnValue(True)
         raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
 
     def _check_email_identity(self, authdict, **kwargs):
-        return self._check_threepid('email', authdict, **kwargs)
+        return self._check_threepid("email", authdict, **kwargs)
 
     def _check_msisdn(self, authdict, **kwargs):
-        return self._check_threepid('msisdn', authdict)
+        return self._check_threepid("msisdn", authdict)
 
     def _check_dummy_auth(self, authdict, **kwargs):
         return defer.succeed(True)
@@ -461,10 +448,10 @@ class AuthHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _check_threepid(self, medium, authdict, password_servlet=False, **kwargs):
-        if 'threepid_creds' not in authdict:
+        if "threepid_creds" not in authdict:
             raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
 
-        threepid_creds = authdict['threepid_creds']
+        threepid_creds = authdict["threepid_creds"]
 
         identity_handler = self.hs.get_handlers().identity_handler
 
@@ -482,31 +469,36 @@ class AuthHandler(BaseHandler):
                 validated=True,
             )
 
-            threepid = {
-                "medium": row["medium"],
-                "address": row["address"],
-                "validated_at": row["validated_at"],
-            } if row else None
+            threepid = (
+                {
+                    "medium": row["medium"],
+                    "address": row["address"],
+                    "validated_at": row["validated_at"],
+                }
+                if row
+                else None
+            )
 
             if row:
                 # Valid threepid returned, delete from the db
                 yield self.store.delete_threepid_session(threepid_creds["sid"])
         else:
-            raise SynapseError(400, "Password resets are not enabled on this homeserver")
+            raise SynapseError(
+                400, "Password resets are not enabled on this homeserver"
+            )
 
         if not threepid:
             raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
 
-        if threepid['medium'] != medium:
+        if threepid["medium"] != medium:
             raise LoginError(
                 401,
-                "Expecting threepid of type '%s', got '%s'" % (
-                    medium, threepid['medium'],
-                ),
-                errcode=Codes.UNAUTHORIZED
+                "Expecting threepid of type '%s', got '%s'"
+                % (medium, threepid["medium"]),
+                errcode=Codes.UNAUTHORIZED,
             )
 
-        threepid['threepid_creds'] = authdict['threepid_creds']
+        threepid["threepid_creds"] = authdict["threepid_creds"]
 
         defer.returnValue(threepid)
 
@@ -520,13 +512,14 @@ class AuthHandler(BaseHandler):
                     "version": self.hs.config.user_consent_version,
                     "en": {
                         "name": self.hs.config.user_consent_policy_name,
-                        "url": "%s_matrix/consent?v=%s" % (
+                        "url": "%s_matrix/consent?v=%s"
+                        % (
                             self.hs.config.public_baseurl,
                             self.hs.config.user_consent_version,
                         ),
                     },
-                },
-            },
+                }
+            }
         }
 
     def _auth_dict_for_flows(self, flows, session):
@@ -547,9 +540,9 @@ class AuthHandler(BaseHandler):
                     params[stage] = get_params[stage]()
 
         return {
-            "session": session['id'],
+            "session": session["id"],
             "flows": [{"stages": f} for f in public_flows],
-            "params": params
+            "params": params,
         }
 
     def _get_session_info(self, session_id):
@@ -560,9 +553,7 @@ class AuthHandler(BaseHandler):
             # create a new session
             while session_id is None or session_id in self.sessions:
                 session_id = stringutils.random_string(24)
-            self.sessions[session_id] = {
-                "id": session_id,
-            }
+            self.sessions[session_id] = {"id": session_id}
 
         return self.sessions[session_id]
 
@@ -652,7 +643,8 @@ class AuthHandler(BaseHandler):
             logger.warn(
                 "Attempted to login as %s but it matches more than one user "
                 "inexactly: %r",
-                user_id, user_infos.keys()
+                user_id,
+                user_infos.keys(),
             )
         defer.returnValue(result)
 
@@ -690,12 +682,10 @@ class AuthHandler(BaseHandler):
                 user is too high too proceed.
         """
 
-        if username.startswith('@'):
+        if username.startswith("@"):
             qualified_user_id = username
         else:
-            qualified_user_id = UserID(
-                username, self.hs.hostname
-            ).to_string()
+            qualified_user_id = UserID(username, self.hs.hostname).to_string()
 
         self.ratelimit_login_per_account(qualified_user_id)
 
@@ -713,17 +703,15 @@ class AuthHandler(BaseHandler):
                 raise SynapseError(400, "Missing parameter: password")
 
         for provider in self.password_providers:
-            if (hasattr(provider, "check_password")
-                    and login_type == LoginType.PASSWORD):
+            if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
                 known_login_type = True
-                is_valid = yield provider.check_password(
-                    qualified_user_id, password,
-                )
+                is_valid = yield provider.check_password(qualified_user_id, password)
                 if is_valid:
                     defer.returnValue((qualified_user_id, None))
 
-            if (not hasattr(provider, "get_supported_login_types")
-                    or not hasattr(provider, "check_auth")):
+            if not hasattr(provider, "get_supported_login_types") or not hasattr(
+                provider, "check_auth"
+            ):
                 # this password provider doesn't understand custom login types
                 continue
 
@@ -744,15 +732,12 @@ class AuthHandler(BaseHandler):
                     login_dict[f] = login_submission[f]
             if missing_fields:
                 raise SynapseError(
-                    400, "Missing parameters for login type %s: %s" % (
-                        login_type,
-                        missing_fields,
-                    ),
+                    400,
+                    "Missing parameters for login type %s: %s"
+                    % (login_type, missing_fields),
                 )
 
-            result = yield provider.check_auth(
-                username, login_type, login_dict,
-            )
+            result = yield provider.check_auth(username, login_type, login_dict)
             if result:
                 if isinstance(result, str):
                     result = (result, None)
@@ -762,7 +747,7 @@ class AuthHandler(BaseHandler):
             known_login_type = True
 
             canonical_user_id = yield self._check_local_password(
-                qualified_user_id, password,
+                qualified_user_id, password
             )
 
             if canonical_user_id:
@@ -773,7 +758,8 @@ class AuthHandler(BaseHandler):
 
         # unknown username or invalid password.
         self._failed_attempts_ratelimiter.ratelimit(
-            qualified_user_id.lower(), time_now_s=self._clock.time(),
+            qualified_user_id.lower(),
+            time_now_s=self._clock.time(),
             rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
             burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
             update=True,
@@ -781,10 +767,7 @@ class AuthHandler(BaseHandler):
 
         # We raise a 403 here, but note that if we're doing user-interactive
         # login, it turns all LoginErrors into a 401 anyway.
-        raise LoginError(
-            403, "Invalid password",
-            errcode=Codes.FORBIDDEN
-        )
+        raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
 
     @defer.inlineCallbacks
     def check_password_provider_3pid(self, medium, address, password):
@@ -810,9 +793,7 @@ class AuthHandler(BaseHandler):
                 # success, to a str (which is the user_id) or a tuple of
                 # (user_id, callback_func), where callback_func should be run
                 # after we've finished everything else
-                result = yield provider.check_3pid_auth(
-                    medium, address, password,
-                )
+                result = yield provider.check_3pid_auth(medium, address, password)
                 if result:
                     # Check if the return value is a str or a tuple
                     if isinstance(result, str):
@@ -853,8 +834,7 @@ class AuthHandler(BaseHandler):
     @defer.inlineCallbacks
     def issue_access_token(self, user_id, device_id=None):
         access_token = self.macaroon_gen.generate_access_token(user_id)
-        yield self.store.add_access_token_to_user(user_id, access_token,
-                                                  device_id)
+        yield self.store.add_access_token_to_user(user_id, access_token, device_id)
         defer.returnValue(access_token)
 
     @defer.inlineCallbacks
@@ -896,12 +876,13 @@ class AuthHandler(BaseHandler):
         # delete pushers associated with this access token
         if user_info["token_id"] is not None:
             yield self.hs.get_pusherpool().remove_pushers_by_access_token(
-                str(user_info["user"]), (user_info["token_id"], )
+                str(user_info["user"]), (user_info["token_id"],)
             )
 
     @defer.inlineCallbacks
-    def delete_access_tokens_for_user(self, user_id, except_token_id=None,
-                                      device_id=None):
+    def delete_access_tokens_for_user(
+        self, user_id, except_token_id=None, device_id=None
+    ):
         """Invalidate access tokens belonging to a user
 
         Args:
@@ -915,7 +896,7 @@ class AuthHandler(BaseHandler):
             Deferred
         """
         tokens_and_devices = yield self.store.user_delete_access_tokens(
-            user_id, except_token_id=except_token_id, device_id=device_id,
+            user_id, except_token_id=except_token_id, device_id=device_id
         )
 
         # see if any of our auth providers want to know about this
@@ -923,14 +904,12 @@ class AuthHandler(BaseHandler):
             if hasattr(provider, "on_logged_out"):
                 for token, token_id, device_id in tokens_and_devices:
                     yield provider.on_logged_out(
-                        user_id=user_id,
-                        device_id=device_id,
-                        access_token=token,
+                        user_id=user_id, device_id=device_id, access_token=token
                     )
 
         # delete pushers associated with the access tokens
         yield self.hs.get_pusherpool().remove_pushers_by_access_token(
-            user_id, (token_id for _, token_id, _ in tokens_and_devices),
+            user_id, (token_id for _, token_id, _ in tokens_and_devices)
         )
 
     @defer.inlineCallbacks
@@ -944,12 +923,11 @@ class AuthHandler(BaseHandler):
         # of specific types of threepid (and fixes the fact that checking
         # for the presence of an email address during password reset was
         # case sensitive).
-        if medium == 'email':
+        if medium == "email":
             address = address.lower()
 
         yield self.store.user_add_threepid(
-            user_id, medium, address, validated_at,
-            self.hs.get_clock().time_msec()
+            user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
         )
 
     @defer.inlineCallbacks
@@ -973,22 +951,15 @@ class AuthHandler(BaseHandler):
         """
 
         # 'Canonicalise' email addresses as per above
-        if medium == 'email':
+        if medium == "email":
             address = address.lower()
 
         identity_handler = self.hs.get_handlers().identity_handler
         result = yield identity_handler.try_unbind_threepid(
-            user_id,
-            {
-                'medium': medium,
-                'address': address,
-                'id_server': id_server,
-            },
+            user_id, {"medium": medium, "address": address, "id_server": id_server}
         )
 
-        yield self.store.user_delete_threepid(
-            user_id, medium, address,
-        )
+        yield self.store.user_delete_threepid(user_id, medium, address)
         defer.returnValue(result)
 
     def _save_session(self, session):
@@ -1006,14 +977,15 @@ class AuthHandler(BaseHandler):
         Returns:
             Deferred(unicode): Hashed password.
         """
+
         def _do_hash():
             # Normalise the Unicode in the password
             pw = unicodedata.normalize("NFKC", password)
 
             return bcrypt.hashpw(
-                pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
+                pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
                 bcrypt.gensalt(self.bcrypt_rounds),
-            ).decode('ascii')
+            ).decode("ascii")
 
         return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash)
 
@@ -1027,18 +999,19 @@ class AuthHandler(BaseHandler):
         Returns:
             Deferred(bool): Whether self.hash(password) == stored_hash.
         """
+
         def _do_validate_hash():
             # Normalise the Unicode in the password
             pw = unicodedata.normalize("NFKC", password)
 
             return bcrypt.checkpw(
-                pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
-                stored_hash
+                pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
+                stored_hash,
             )
 
         if stored_hash:
             if not isinstance(stored_hash, bytes):
-                stored_hash = stored_hash.encode('ascii')
+                stored_hash = stored_hash.encode("ascii")
 
             return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
         else:
@@ -1058,14 +1031,16 @@ class AuthHandler(BaseHandler):
                 for this user is too high too proceed.
         """
         self._failed_attempts_ratelimiter.ratelimit(
-            user_id.lower(), time_now_s=self._clock.time(),
+            user_id.lower(),
+            time_now_s=self._clock.time(),
             rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
             burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
             update=False,
         )
 
         self._account_ratelimiter.ratelimit(
-            user_id.lower(), time_now_s=self._clock.time(),
+            user_id.lower(),
+            time_now_s=self._clock.time(),
             rate_hz=self.hs.config.rc_login_account.per_second,
             burst_count=self.hs.config.rc_login_account.burst_count,
             update=True,
@@ -1083,9 +1058,9 @@ class MacaroonGenerator(object):
         macaroon.add_first_party_caveat("type = access")
         # Include a nonce, to make sure that each login gets a different
         # access token.
-        macaroon.add_first_party_caveat("nonce = %s" % (
-            stringutils.random_string_with_symbols(16),
-        ))
+        macaroon.add_first_party_caveat(
+            "nonce = %s" % (stringutils.random_string_with_symbols(16),)
+        )
         for caveat in extra_caveats:
             macaroon.add_first_party_caveat(caveat)
         return macaroon.serialize()
@@ -1116,7 +1091,8 @@ class MacaroonGenerator(object):
         macaroon = pymacaroons.Macaroon(
             location=self.hs.config.server_name,
             identifier="key",
-            key=self.hs.config.macaroon_secret_key)
+            key=self.hs.config.macaroon_secret_key,
+        )
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
         return macaroon