From eed7c5b89eee6951ac17861b1695817470bace36 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 15 Apr 2020 12:40:18 -0400 Subject: Convert auth handler to async/await (#7261) --- synapse/handlers/auth.py | 173 ++++++++++++++++++--------------------- synapse/handlers/device.py | 12 ++- synapse/handlers/register.py | 28 +++++-- synapse/handlers/set_password.py | 13 ++- synapse/module_api/__init__.py | 6 +- 5 files changed, 119 insertions(+), 113 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index fbfbd44a2e..0aae929ecc 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,14 +18,12 @@ import logging import time import unicodedata import urllib.parse -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import attr import bcrypt # type: ignore[import] import pymacaroons -from twisted.internet import defer - import synapse.util.stringutils as stringutils from synapse.api.constants import LoginType from synapse.api.errors import ( @@ -170,15 +168,14 @@ class AuthHandler(BaseHandler): # cast to tuple for use with str.startswith self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) - @defer.inlineCallbacks - def validate_user_via_ui_auth( + async def validate_user_via_ui_auth( self, requester: Requester, request: SynapseRequest, request_body: Dict[str, Any], clientip: str, description: str, - ): + ) -> dict: """ Checks that the user is who they claim to be, via a UI auth. @@ -199,7 +196,7 @@ class AuthHandler(BaseHandler): describes the operation happening on their account. Returns: - defer.Deferred[dict]: the parameters for this request (which may + The parameters for this request (which may have been given only in a previous call). Raises: @@ -229,7 +226,7 @@ class AuthHandler(BaseHandler): flows = [[login_type] for login_type in self._supported_ui_auth_types] try: - result, params, _ = yield self.check_auth( + result, params, _ = await self.check_auth( flows, request, request_body, clientip, description ) except LoginError: @@ -268,15 +265,14 @@ class AuthHandler(BaseHandler): """ return self.checkers.keys() - @defer.inlineCallbacks - def check_auth( + async def check_auth( self, flows: List[List[str]], request: SynapseRequest, clientdict: Dict[str, Any], clientip: str, description: str, - ): + ) -> Tuple[dict, dict, str]: """ Takes a dictionary sent by the client in the login / registration protocol and handles the User-Interactive Auth flow. @@ -306,8 +302,7 @@ class AuthHandler(BaseHandler): describes the operation happening on their account. Returns: - defer.Deferred[dict, dict, str]: a deferred tuple of - (creds, params, session_id). + A tuple of (creds, params, session_id). 'creds' contains the authenticated credentials of each stage. @@ -380,7 +375,7 @@ class AuthHandler(BaseHandler): if "type" in authdict: login_type = authdict["type"] # type: str try: - result = yield self._check_auth_dict(authdict, clientip) + result = await self._check_auth_dict(authdict, clientip) if result: creds[login_type] = result self._save_session(session) @@ -419,8 +414,9 @@ class AuthHandler(BaseHandler): ret.update(errordict) raise InteractiveAuthIncompleteError(ret) - @defer.inlineCallbacks - def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str): + async def add_oob_auth( + self, stagetype: str, authdict: Dict[str, Any], clientip: str + ) -> bool: """ Adds the result of out-of-band authentication into an existing auth session. Currently used for adding the result of fallback auth. @@ -435,7 +431,7 @@ class AuthHandler(BaseHandler): sess["creds"] = {} creds = sess["creds"] - result = yield self.checkers[stagetype].check_auth(authdict, clientip) + result = await self.checkers[stagetype].check_auth(authdict, clientip) if result: creds[stagetype] = result self._save_session(sess) @@ -489,8 +485,9 @@ class AuthHandler(BaseHandler): sess = self._get_session_info(session_id) return sess.setdefault("serverdict", {}).get(key, default) - @defer.inlineCallbacks - def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str): + async def _check_auth_dict( + self, authdict: Dict[str, Any], clientip: str + ) -> Union[Dict[str, Any], str]: """Attempt to validate the auth dict provided by a client Args: @@ -498,7 +495,7 @@ class AuthHandler(BaseHandler): clientip: IP address of the client Returns: - Deferred: result of the stage verification. + Result of the stage verification. Raises: StoreError if there was a problem accessing the database @@ -508,7 +505,7 @@ class AuthHandler(BaseHandler): login_type = authdict["type"] checker = self.checkers.get(login_type) if checker is not None: - res = yield checker.check_auth(authdict, clientip=clientip) + res = await checker.check_auth(authdict, clientip=clientip) return res # build a v1-login-style dict out of the authdict and fall back to the @@ -518,7 +515,7 @@ class AuthHandler(BaseHandler): if user_id is None: raise SynapseError(400, "", Codes.MISSING_PARAM) - (canonical_id, callback) = yield self.validate_login(user_id, authdict) + (canonical_id, callback) = await self.validate_login(user_id, authdict) return canonical_id def _get_params_recaptcha(self) -> dict: @@ -584,8 +581,7 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] - @defer.inlineCallbacks - def get_access_token_for_user_id( + async def get_access_token_for_user_id( self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] ): """ @@ -615,10 +611,10 @@ class AuthHandler(BaseHandler): ) logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry) - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) access_token = self.macaroon_gen.generate_access_token(user_id) - yield self.store.add_access_token_to_user( + await self.store.add_access_token_to_user( user_id, access_token, device_id, valid_until_ms ) @@ -628,15 +624,14 @@ class AuthHandler(BaseHandler): # device, so we double-check it here. if device_id is not None: try: - yield self.store.get_device(user_id, device_id) + await self.store.get_device(user_id, device_id) except StoreError: - yield self.store.delete_access_token(access_token) + await self.store.delete_access_token(access_token) raise StoreError(400, "Login raced against device deletion") return access_token - @defer.inlineCallbacks - def check_user_exists(self, user_id: str): + async def check_user_exists(self, user_id: str) -> Optional[str]: """ Checks to see if a user with the given id exists. Will check case insensitively, but return None if there are multiple inexact matches. @@ -645,25 +640,25 @@ class AuthHandler(BaseHandler): user_id: complete @user:id Returns: - defer.Deferred: (unicode) canonical_user_id, or None if zero or - multiple matches + The canonical_user_id, or None if zero or multiple matches """ - res = yield self._find_user_id_and_pwd_hash(user_id) + res = await self._find_user_id_and_pwd_hash(user_id) if res is not None: return res[0] return None - @defer.inlineCallbacks - def _find_user_id_and_pwd_hash(self, user_id: str): + async def _find_user_id_and_pwd_hash( + self, user_id: str + ) -> Optional[Tuple[str, str]]: """Checks to see if a user with the given id exists. Will check case insensitively, but will return None if there are multiple inexact matches. Returns: - tuple: A 2-tuple of `(canonical_user_id, password_hash)` - None: if there is not exactly one match + A 2-tuple of `(canonical_user_id, password_hash)` or `None` + if there is not exactly one match """ - user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) + user_infos = await self.store.get_users_by_id_case_insensitive(user_id) result = None if not user_infos: @@ -696,8 +691,9 @@ class AuthHandler(BaseHandler): """ return self._supported_login_types - @defer.inlineCallbacks - def validate_login(self, username: str, login_submission: Dict[str, Any]): + async def validate_login( + self, username: str, login_submission: Dict[str, Any] + ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: """Authenticates the user for the /login API Also used by the user-interactive auth flow to validate @@ -708,7 +704,7 @@ class AuthHandler(BaseHandler): login_submission: the whole of the login submission (including 'type' and other relevant fields) Returns: - Deferred[str, func]: canonical user id, and optional callback + A tuple of the canonical user id, and optional callback to be called once the access token and device id are issued Raises: StoreError if there was a problem accessing the database @@ -737,7 +733,7 @@ class AuthHandler(BaseHandler): for provider in self.password_providers: if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: known_login_type = True - is_valid = yield provider.check_password(qualified_user_id, password) + is_valid = await provider.check_password(qualified_user_id, password) if is_valid: return qualified_user_id, None @@ -769,7 +765,7 @@ class AuthHandler(BaseHandler): % (login_type, missing_fields), ) - result = yield provider.check_auth(username, login_type, login_dict) + result = await provider.check_auth(username, login_type, login_dict) if result: if isinstance(result, str): result = (result, None) @@ -778,8 +774,8 @@ class AuthHandler(BaseHandler): if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: known_login_type = True - canonical_user_id = yield self._check_local_password( - qualified_user_id, password + canonical_user_id = await self._check_local_password( + qualified_user_id, password # type: ignore ) if canonical_user_id: @@ -792,8 +788,9 @@ class AuthHandler(BaseHandler): # login, it turns all LoginErrors into a 401 anyway. raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) - @defer.inlineCallbacks - def check_password_provider_3pid(self, medium: str, address: str, password: str): + async def check_password_provider_3pid( + self, medium: str, address: str, password: str + ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]: """Check if a password provider is able to validate a thirdparty login Args: @@ -802,9 +799,8 @@ class AuthHandler(BaseHandler): password: The password of the user. Returns: - Deferred[(str|None, func|None)]: A tuple of `(user_id, - callback)`. If authentication is successful, `user_id` is a `str` - containing the authenticated, canonical user ID. `callback` is + A tuple of `(user_id, callback)`. If authentication is successful, + `user_id`is the authenticated, canonical user ID. `callback` is then either a function to be later run after the server has completed login/registration, or `None`. If authentication was unsuccessful, `user_id` and `callback` are both `None`. @@ -816,7 +812,7 @@ class AuthHandler(BaseHandler): # success, to a str (which is the user_id) or a tuple of # (user_id, callback_func), where callback_func should be run # after we've finished everything else - result = yield provider.check_3pid_auth(medium, address, password) + result = await provider.check_3pid_auth(medium, address, password) if result: # Check if the return value is a str or a tuple if isinstance(result, str): @@ -826,8 +822,7 @@ class AuthHandler(BaseHandler): return None, None - @defer.inlineCallbacks - def _check_local_password(self, user_id: str, password: str): + async def _check_local_password(self, user_id: str, password: str) -> Optional[str]: """Authenticate a user against the local password database. user_id is checked case insensitively, but will return None if there are @@ -837,28 +832,26 @@ class AuthHandler(BaseHandler): user_id: complete @user:id password: the provided password Returns: - Deferred[unicode] the canonical_user_id, or Deferred[None] if - unknown user/bad password + The canonical_user_id, or None if unknown user/bad password """ - lookupres = yield self._find_user_id_and_pwd_hash(user_id) + lookupres = await self._find_user_id_and_pwd_hash(user_id) if not lookupres: return None (user_id, password_hash) = lookupres # If the password hash is None, the account has likely been deactivated if not password_hash: - deactivated = yield self.store.get_user_deactivated_status(user_id) + deactivated = await self.store.get_user_deactivated_status(user_id) if deactivated: raise UserDeactivatedError("This account has been deactivated") - result = yield self.validate_hash(password, password_hash) + result = await self.validate_hash(password, password_hash) if not result: logger.warning("Failed password login for user %s", user_id) return None return user_id - @defer.inlineCallbacks - def validate_short_term_login_token_and_get_user_id(self, login_token: str): + async def validate_short_term_login_token_and_get_user_id(self, login_token: str): auth_api = self.hs.get_auth() user_id = None try: @@ -868,26 +861,23 @@ class AuthHandler(BaseHandler): except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) return user_id - @defer.inlineCallbacks - def delete_access_token(self, access_token: str): + async def delete_access_token(self, access_token: str): """Invalidate a single access token Args: access_token: access token to be deleted - Returns: - Deferred """ - user_info = yield self.auth.get_user_by_access_token(access_token) - yield self.store.delete_access_token(access_token) + user_info = await self.auth.get_user_by_access_token(access_token) + await self.store.delete_access_token(access_token) # see if any of our auth providers want to know about this for provider in self.password_providers: if hasattr(provider, "on_logged_out"): - yield provider.on_logged_out( + await provider.on_logged_out( user_id=str(user_info["user"]), device_id=user_info["device_id"], access_token=access_token, @@ -895,12 +885,11 @@ class AuthHandler(BaseHandler): # delete pushers associated with this access token if user_info["token_id"] is not None: - yield self.hs.get_pusherpool().remove_pushers_by_access_token( + await self.hs.get_pusherpool().remove_pushers_by_access_token( str(user_info["user"]), (user_info["token_id"],) ) - @defer.inlineCallbacks - def delete_access_tokens_for_user( + async def delete_access_tokens_for_user( self, user_id: str, except_token_id: Optional[str] = None, @@ -914,10 +903,8 @@ class AuthHandler(BaseHandler): device_id: ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted - Returns: - Deferred """ - tokens_and_devices = yield self.store.user_delete_access_tokens( + tokens_and_devices = await self.store.user_delete_access_tokens( user_id, except_token_id=except_token_id, device_id=device_id ) @@ -925,17 +912,18 @@ class AuthHandler(BaseHandler): for provider in self.password_providers: if hasattr(provider, "on_logged_out"): for token, token_id, device_id in tokens_and_devices: - yield provider.on_logged_out( + await provider.on_logged_out( user_id=user_id, device_id=device_id, access_token=token ) # delete pushers associated with the access tokens - yield self.hs.get_pusherpool().remove_pushers_by_access_token( + await self.hs.get_pusherpool().remove_pushers_by_access_token( user_id, (token_id for _, token_id, _ in tokens_and_devices) ) - @defer.inlineCallbacks - def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int): + async def add_threepid( + self, user_id: str, medium: str, address: str, validated_at: int + ): # check if medium has a valid value if medium not in ["email", "msisdn"]: raise SynapseError( @@ -956,14 +944,13 @@ class AuthHandler(BaseHandler): if medium == "email": address = address.lower() - yield self.store.user_add_threepid( + await self.store.user_add_threepid( user_id, medium, address, validated_at, self.hs.get_clock().time_msec() ) - @defer.inlineCallbacks - def delete_threepid( + async def delete_threepid( self, user_id: str, medium: str, address: str, id_server: Optional[str] = None - ): + ) -> bool: """Attempts to unbind the 3pid on the identity servers and deletes it from the local database. @@ -976,7 +963,7 @@ class AuthHandler(BaseHandler): identity server specified when binding (if known). Returns: - Deferred[bool]: Returns True if successfully unbound the 3pid on + Returns True if successfully unbound the 3pid on the identity server, False if identity server doesn't support the unbind API. """ @@ -986,11 +973,11 @@ class AuthHandler(BaseHandler): address = address.lower() identity_handler = self.hs.get_handlers().identity_handler - result = yield identity_handler.try_unbind_threepid( + result = await identity_handler.try_unbind_threepid( user_id, {"medium": medium, "address": address, "id_server": id_server} ) - yield self.store.user_delete_threepid(user_id, medium, address) + await self.store.user_delete_threepid(user_id, medium, address) return result def _save_session(self, session: Dict[str, Any]) -> None: @@ -1000,14 +987,14 @@ class AuthHandler(BaseHandler): session["last_used"] = self.hs.get_clock().time_msec() self.sessions[session["id"]] = session - def hash(self, password: str): + async def hash(self, password: str) -> str: """Computes a secure hash of password. Args: password: Password to hash. Returns: - Deferred(unicode): Hashed password. + Hashed password. """ def _do_hash(): @@ -1019,9 +1006,11 @@ class AuthHandler(BaseHandler): bcrypt.gensalt(self.bcrypt_rounds), ).decode("ascii") - return defer_to_thread(self.hs.get_reactor(), _do_hash) + return await defer_to_thread(self.hs.get_reactor(), _do_hash) - def validate_hash(self, password: str, stored_hash: bytes): + async def validate_hash( + self, password: str, stored_hash: Union[bytes, str] + ) -> bool: """Validates that self.hash(password) == stored_hash. Args: @@ -1029,7 +1018,7 @@ class AuthHandler(BaseHandler): stored_hash: Expected hash value. Returns: - Deferred(bool): Whether self.hash(password) == stored_hash. + Whether self.hash(password) == stored_hash. """ def _do_validate_hash(): @@ -1045,9 +1034,9 @@ class AuthHandler(BaseHandler): if not isinstance(stored_hash, bytes): stored_hash = stored_hash.encode("ascii") - return defer_to_thread(self.hs.get_reactor(), _do_validate_hash) + return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash) else: - return defer.succeed(False) + return False def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: """ diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 993499f446..9bd941b5a0 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -338,8 +338,10 @@ class DeviceHandler(DeviceWorkerHandler): else: raise - yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id + yield defer.ensureDeferred( + self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id + ) ) yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) @@ -391,8 +393,10 @@ class DeviceHandler(DeviceWorkerHandler): # Delete access tokens and e2e keys for each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: - yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id + yield defer.ensureDeferred( + self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id + ) ) yield self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 7ffc194f0c..3a65b46ecd 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -166,7 +166,9 @@ class RegistrationHandler(BaseHandler): yield self.auth.check_auth_blocking(threepid=threepid) password_hash = None if password: - password_hash = yield self._auth_handler.hash(password) + password_hash = yield defer.ensureDeferred( + self._auth_handler.hash(password) + ) if localpart is not None: yield self.check_username(localpart, guest_access_token=guest_access_token) @@ -540,8 +542,10 @@ class RegistrationHandler(BaseHandler): user_id, ["guest = true"] ) else: - access_token = yield self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, valid_until_ms=valid_until_ms + access_token = yield defer.ensureDeferred( + self._auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, valid_until_ms=valid_until_ms + ) ) return (device_id, access_token) @@ -617,8 +621,13 @@ class RegistrationHandler(BaseHandler): logger.info("Can't add incomplete 3pid") return - yield self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"] + yield defer.ensureDeferred( + self._auth_handler.add_threepid( + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], + ) ) # And we add an email pusher for them by default, but only @@ -670,6 +679,11 @@ class RegistrationHandler(BaseHandler): return None raise - yield self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"] + yield defer.ensureDeferred( + self._auth_handler.add_threepid( + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], + ) ) diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 7d1263caf2..63d8f9aa0d 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -15,8 +15,6 @@ import logging from typing import Optional -from twisted.internet import defer - from synapse.api.errors import Codes, StoreError, SynapseError from synapse.types import Requester @@ -34,8 +32,7 @@ class SetPasswordHandler(BaseHandler): self._device_handler = hs.get_device_handler() self._password_policy_handler = hs.get_password_policy_handler() - @defer.inlineCallbacks - def set_password( + async def set_password( self, user_id: str, new_password: str, @@ -46,10 +43,10 @@ class SetPasswordHandler(BaseHandler): raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) self._password_policy_handler.validate_password(new_password) - password_hash = yield self._auth_handler.hash(new_password) + password_hash = await self._auth_handler.hash(new_password) try: - yield self.store.user_set_password_hash(user_id, password_hash) + await self.store.user_set_password_hash(user_id, password_hash) except StoreError as e: if e.code == 404: raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) @@ -61,12 +58,12 @@ class SetPasswordHandler(BaseHandler): except_access_token_id = requester.access_token_id if requester else None # First delete all of their other devices. - yield self._device_handler.delete_all_devices_for_user( + await self._device_handler.delete_all_devices_for_user( user_id, except_device_id=except_device_id ) # and now delete any access tokens which weren't associated with # devices (or were associated with this device). - yield self._auth_handler.delete_access_tokens_for_user( + await self._auth_handler.delete_access_tokens_for_user( user_id, except_token_id=except_access_token_id ) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index afc3598e11..d678c0eb9b 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -86,7 +86,7 @@ class ModuleApi(object): Deferred[str|None]: Canonical (case-corrected) user_id, or None if the user is not registered. """ - return self._auth_handler.check_user_exists(user_id) + return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id)) @defer.inlineCallbacks def register(self, localpart, displayname=None, emails=[]): @@ -196,7 +196,9 @@ class ModuleApi(object): yield self._hs.get_device_handler().delete_device(user_id, device_id) else: # no associated device. Just delete the access token. - yield self._auth_handler.delete_access_token(access_token) + yield defer.ensureDeferred( + self._auth_handler.delete_access_token(access_token) + ) def run_db_interaction(self, desc, func, *args, **kwargs): """Run a function with a database connection -- cgit 1.4.1