diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index fbfbd44a2e..0aae929ecc 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,14 +18,12 @@ import logging
import time
import unicodedata
import urllib.parse
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import attr
import bcrypt # type: ignore[import]
import pymacaroons
-from twisted.internet import defer
-
import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType
from synapse.api.errors import (
@@ -170,15 +168,14 @@ class AuthHandler(BaseHandler):
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
- @defer.inlineCallbacks
- def validate_user_via_ui_auth(
+ async def validate_user_via_ui_auth(
self,
requester: Requester,
request: SynapseRequest,
request_body: Dict[str, Any],
clientip: str,
description: str,
- ):
+ ) -> dict:
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -199,7 +196,7 @@ class AuthHandler(BaseHandler):
describes the operation happening on their account.
Returns:
- defer.Deferred[dict]: the parameters for this request (which may
+ The parameters for this request (which may
have been given only in a previous call).
Raises:
@@ -229,7 +226,7 @@ class AuthHandler(BaseHandler):
flows = [[login_type] for login_type in self._supported_ui_auth_types]
try:
- result, params, _ = yield self.check_auth(
+ result, params, _ = await self.check_auth(
flows, request, request_body, clientip, description
)
except LoginError:
@@ -268,15 +265,14 @@ class AuthHandler(BaseHandler):
"""
return self.checkers.keys()
- @defer.inlineCallbacks
- def check_auth(
+ async def check_auth(
self,
flows: List[List[str]],
request: SynapseRequest,
clientdict: Dict[str, Any],
clientip: str,
description: str,
- ):
+ ) -> Tuple[dict, dict, str]:
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the User-Interactive Auth flow.
@@ -306,8 +302,7 @@ class AuthHandler(BaseHandler):
describes the operation happening on their account.
Returns:
- defer.Deferred[dict, dict, str]: a deferred tuple of
- (creds, params, session_id).
+ A tuple of (creds, params, session_id).
'creds' contains the authenticated credentials of each stage.
@@ -380,7 +375,7 @@ class AuthHandler(BaseHandler):
if "type" in authdict:
login_type = authdict["type"] # type: str
try:
- result = yield self._check_auth_dict(authdict, clientip)
+ result = await self._check_auth_dict(authdict, clientip)
if result:
creds[login_type] = result
self._save_session(session)
@@ -419,8 +414,9 @@ class AuthHandler(BaseHandler):
ret.update(errordict)
raise InteractiveAuthIncompleteError(ret)
- @defer.inlineCallbacks
- def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str):
+ async def add_oob_auth(
+ self, stagetype: str, authdict: Dict[str, Any], clientip: str
+ ) -> bool:
"""
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.
@@ -435,7 +431,7 @@ class AuthHandler(BaseHandler):
sess["creds"] = {}
creds = sess["creds"]
- result = yield self.checkers[stagetype].check_auth(authdict, clientip)
+ result = await self.checkers[stagetype].check_auth(authdict, clientip)
if result:
creds[stagetype] = result
self._save_session(sess)
@@ -489,8 +485,9 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id)
return sess.setdefault("serverdict", {}).get(key, default)
- @defer.inlineCallbacks
- def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str):
+ async def _check_auth_dict(
+ self, authdict: Dict[str, Any], clientip: str
+ ) -> Union[Dict[str, Any], str]:
"""Attempt to validate the auth dict provided by a client
Args:
@@ -498,7 +495,7 @@ class AuthHandler(BaseHandler):
clientip: IP address of the client
Returns:
- Deferred: result of the stage verification.
+ Result of the stage verification.
Raises:
StoreError if there was a problem accessing the database
@@ -508,7 +505,7 @@ class AuthHandler(BaseHandler):
login_type = authdict["type"]
checker = self.checkers.get(login_type)
if checker is not None:
- res = yield checker.check_auth(authdict, clientip=clientip)
+ res = await checker.check_auth(authdict, clientip=clientip)
return res
# build a v1-login-style dict out of the authdict and fall back to the
@@ -518,7 +515,7 @@ class AuthHandler(BaseHandler):
if user_id is None:
raise SynapseError(400, "", Codes.MISSING_PARAM)
- (canonical_id, callback) = yield self.validate_login(user_id, authdict)
+ (canonical_id, callback) = await self.validate_login(user_id, authdict)
return canonical_id
def _get_params_recaptcha(self) -> dict:
@@ -584,8 +581,7 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
- @defer.inlineCallbacks
- def get_access_token_for_user_id(
+ async def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
):
"""
@@ -615,10 +611,10 @@ class AuthHandler(BaseHandler):
)
logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
- yield self.auth.check_auth_blocking(user_id)
+ await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id)
- yield self.store.add_access_token_to_user(
+ await self.store.add_access_token_to_user(
user_id, access_token, device_id, valid_until_ms
)
@@ -628,15 +624,14 @@ class AuthHandler(BaseHandler):
# device, so we double-check it here.
if device_id is not None:
try:
- yield self.store.get_device(user_id, device_id)
+ await self.store.get_device(user_id, device_id)
except StoreError:
- yield self.store.delete_access_token(access_token)
+ await self.store.delete_access_token(access_token)
raise StoreError(400, "Login raced against device deletion")
return access_token
- @defer.inlineCallbacks
- def check_user_exists(self, user_id: str):
+ async def check_user_exists(self, user_id: str) -> Optional[str]:
"""
Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches.
@@ -645,25 +640,25 @@ class AuthHandler(BaseHandler):
user_id: complete @user:id
Returns:
- defer.Deferred: (unicode) canonical_user_id, or None if zero or
- multiple matches
+ The canonical_user_id, or None if zero or multiple matches
"""
- res = yield self._find_user_id_and_pwd_hash(user_id)
+ res = await self._find_user_id_and_pwd_hash(user_id)
if res is not None:
return res[0]
return None
- @defer.inlineCallbacks
- def _find_user_id_and_pwd_hash(self, user_id: str):
+ async def _find_user_id_and_pwd_hash(
+ self, user_id: str
+ ) -> Optional[Tuple[str, str]]:
"""Checks to see if a user with the given id exists. Will check case
insensitively, but will return None if there are multiple inexact
matches.
Returns:
- tuple: A 2-tuple of `(canonical_user_id, password_hash)`
- None: if there is not exactly one match
+ A 2-tuple of `(canonical_user_id, password_hash)` or `None`
+ if there is not exactly one match
"""
- user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
+ user_infos = await self.store.get_users_by_id_case_insensitive(user_id)
result = None
if not user_infos:
@@ -696,8 +691,9 @@ class AuthHandler(BaseHandler):
"""
return self._supported_login_types
- @defer.inlineCallbacks
- def validate_login(self, username: str, login_submission: Dict[str, Any]):
+ async def validate_login(
+ self, username: str, login_submission: Dict[str, Any]
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate
@@ -708,7 +704,7 @@ class AuthHandler(BaseHandler):
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
Returns:
- Deferred[str, func]: canonical user id, and optional callback
+ A tuple of the canonical user id, and optional callback
to be called once the access token and device id are issued
Raises:
StoreError if there was a problem accessing the database
@@ -737,7 +733,7 @@ class AuthHandler(BaseHandler):
for provider in self.password_providers:
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
known_login_type = True
- is_valid = yield provider.check_password(qualified_user_id, password)
+ is_valid = await provider.check_password(qualified_user_id, password)
if is_valid:
return qualified_user_id, None
@@ -769,7 +765,7 @@ class AuthHandler(BaseHandler):
% (login_type, missing_fields),
)
- result = yield provider.check_auth(username, login_type, login_dict)
+ result = await provider.check_auth(username, login_type, login_dict)
if result:
if isinstance(result, str):
result = (result, None)
@@ -778,8 +774,8 @@ class AuthHandler(BaseHandler):
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
known_login_type = True
- canonical_user_id = yield self._check_local_password(
- qualified_user_id, password
+ canonical_user_id = await self._check_local_password(
+ qualified_user_id, password # type: ignore
)
if canonical_user_id:
@@ -792,8 +788,9 @@ class AuthHandler(BaseHandler):
# login, it turns all LoginErrors into a 401 anyway.
raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
- @defer.inlineCallbacks
- def check_password_provider_3pid(self, medium: str, address: str, password: str):
+ async def check_password_provider_3pid(
+ self, medium: str, address: str, password: str
+ ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
"""Check if a password provider is able to validate a thirdparty login
Args:
@@ -802,9 +799,8 @@ class AuthHandler(BaseHandler):
password: The password of the user.
Returns:
- Deferred[(str|None, func|None)]: A tuple of `(user_id,
- callback)`. If authentication is successful, `user_id` is a `str`
- containing the authenticated, canonical user ID. `callback` is
+ A tuple of `(user_id, callback)`. If authentication is successful,
+ `user_id`is the authenticated, canonical user ID. `callback` is
then either a function to be later run after the server has
completed login/registration, or `None`. If authentication was
unsuccessful, `user_id` and `callback` are both `None`.
@@ -816,7 +812,7 @@ class AuthHandler(BaseHandler):
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
- result = yield provider.check_3pid_auth(medium, address, password)
+ result = await provider.check_3pid_auth(medium, address, password)
if result:
# Check if the return value is a str or a tuple
if isinstance(result, str):
@@ -826,8 +822,7 @@ class AuthHandler(BaseHandler):
return None, None
- @defer.inlineCallbacks
- def _check_local_password(self, user_id: str, password: str):
+ async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
"""Authenticate a user against the local password database.
user_id is checked case insensitively, but will return None if there are
@@ -837,28 +832,26 @@ class AuthHandler(BaseHandler):
user_id: complete @user:id
password: the provided password
Returns:
- Deferred[unicode] the canonical_user_id, or Deferred[None] if
- unknown user/bad password
+ The canonical_user_id, or None if unknown user/bad password
"""
- lookupres = yield self._find_user_id_and_pwd_hash(user_id)
+ lookupres = await self._find_user_id_and_pwd_hash(user_id)
if not lookupres:
return None
(user_id, password_hash) = lookupres
# If the password hash is None, the account has likely been deactivated
if not password_hash:
- deactivated = yield self.store.get_user_deactivated_status(user_id)
+ deactivated = await self.store.get_user_deactivated_status(user_id)
if deactivated:
raise UserDeactivatedError("This account has been deactivated")
- result = yield self.validate_hash(password, password_hash)
+ result = await self.validate_hash(password, password_hash)
if not result:
logger.warning("Failed password login for user %s", user_id)
return None
return user_id
- @defer.inlineCallbacks
- def validate_short_term_login_token_and_get_user_id(self, login_token: str):
+ async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
auth_api = self.hs.get_auth()
user_id = None
try:
@@ -868,26 +861,23 @@ class AuthHandler(BaseHandler):
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
- yield self.auth.check_auth_blocking(user_id)
+ await self.auth.check_auth_blocking(user_id)
return user_id
- @defer.inlineCallbacks
- def delete_access_token(self, access_token: str):
+ async def delete_access_token(self, access_token: str):
"""Invalidate a single access token
Args:
access_token: access token to be deleted
- Returns:
- Deferred
"""
- user_info = yield self.auth.get_user_by_access_token(access_token)
- yield self.store.delete_access_token(access_token)
+ user_info = await self.auth.get_user_by_access_token(access_token)
+ await self.store.delete_access_token(access_token)
# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
- yield provider.on_logged_out(
+ await provider.on_logged_out(
user_id=str(user_info["user"]),
device_id=user_info["device_id"],
access_token=access_token,
@@ -895,12 +885,11 @@ class AuthHandler(BaseHandler):
# delete pushers associated with this access token
if user_info["token_id"] is not None:
- yield self.hs.get_pusherpool().remove_pushers_by_access_token(
+ await self.hs.get_pusherpool().remove_pushers_by_access_token(
str(user_info["user"]), (user_info["token_id"],)
)
- @defer.inlineCallbacks
- def delete_access_tokens_for_user(
+ async def delete_access_tokens_for_user(
self,
user_id: str,
except_token_id: Optional[str] = None,
@@ -914,10 +903,8 @@ class AuthHandler(BaseHandler):
device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
- Returns:
- Deferred
"""
- tokens_and_devices = yield self.store.user_delete_access_tokens(
+ tokens_and_devices = await self.store.user_delete_access_tokens(
user_id, except_token_id=except_token_id, device_id=device_id
)
@@ -925,17 +912,18 @@ class AuthHandler(BaseHandler):
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
for token, token_id, device_id in tokens_and_devices:
- yield provider.on_logged_out(
+ await provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
)
# delete pushers associated with the access tokens
- yield self.hs.get_pusherpool().remove_pushers_by_access_token(
+ await self.hs.get_pusherpool().remove_pushers_by_access_token(
user_id, (token_id for _, token_id, _ in tokens_and_devices)
)
- @defer.inlineCallbacks
- def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int):
+ async def add_threepid(
+ self, user_id: str, medium: str, address: str, validated_at: int
+ ):
# check if medium has a valid value
if medium not in ["email", "msisdn"]:
raise SynapseError(
@@ -956,14 +944,13 @@ class AuthHandler(BaseHandler):
if medium == "email":
address = address.lower()
- yield self.store.user_add_threepid(
+ await self.store.user_add_threepid(
user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
)
- @defer.inlineCallbacks
- def delete_threepid(
+ async def delete_threepid(
self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
- ):
+ ) -> bool:
"""Attempts to unbind the 3pid on the identity servers and deletes it
from the local database.
@@ -976,7 +963,7 @@ class AuthHandler(BaseHandler):
identity server specified when binding (if known).
Returns:
- Deferred[bool]: Returns True if successfully unbound the 3pid on
+ Returns True if successfully unbound the 3pid on
the identity server, False if identity server doesn't support the
unbind API.
"""
@@ -986,11 +973,11 @@ class AuthHandler(BaseHandler):
address = address.lower()
identity_handler = self.hs.get_handlers().identity_handler
- result = yield identity_handler.try_unbind_threepid(
+ result = await identity_handler.try_unbind_threepid(
user_id, {"medium": medium, "address": address, "id_server": id_server}
)
- yield self.store.user_delete_threepid(user_id, medium, address)
+ await self.store.user_delete_threepid(user_id, medium, address)
return result
def _save_session(self, session: Dict[str, Any]) -> None:
@@ -1000,14 +987,14 @@ class AuthHandler(BaseHandler):
session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session
- def hash(self, password: str):
+ async def hash(self, password: str) -> str:
"""Computes a secure hash of password.
Args:
password: Password to hash.
Returns:
- Deferred(unicode): Hashed password.
+ Hashed password.
"""
def _do_hash():
@@ -1019,9 +1006,11 @@ class AuthHandler(BaseHandler):
bcrypt.gensalt(self.bcrypt_rounds),
).decode("ascii")
- return defer_to_thread(self.hs.get_reactor(), _do_hash)
+ return await defer_to_thread(self.hs.get_reactor(), _do_hash)
- def validate_hash(self, password: str, stored_hash: bytes):
+ async def validate_hash(
+ self, password: str, stored_hash: Union[bytes, str]
+ ) -> bool:
"""Validates that self.hash(password) == stored_hash.
Args:
@@ -1029,7 +1018,7 @@ class AuthHandler(BaseHandler):
stored_hash: Expected hash value.
Returns:
- Deferred(bool): Whether self.hash(password) == stored_hash.
+ Whether self.hash(password) == stored_hash.
"""
def _do_validate_hash():
@@ -1045,9 +1034,9 @@ class AuthHandler(BaseHandler):
if not isinstance(stored_hash, bytes):
stored_hash = stored_hash.encode("ascii")
- return defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
+ return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
else:
- return defer.succeed(False)
+ return False
def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
"""
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 993499f446..9bd941b5a0 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -338,8 +338,10 @@ class DeviceHandler(DeviceWorkerHandler):
else:
raise
- yield self._auth_handler.delete_access_tokens_for_user(
- user_id, device_id=device_id
+ yield defer.ensureDeferred(
+ self._auth_handler.delete_access_tokens_for_user(
+ user_id, device_id=device_id
+ )
)
yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
@@ -391,8 +393,10 @@ class DeviceHandler(DeviceWorkerHandler):
# Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
- yield self._auth_handler.delete_access_tokens_for_user(
- user_id, device_id=device_id
+ yield defer.ensureDeferred(
+ self._auth_handler.delete_access_tokens_for_user(
+ user_id, device_id=device_id
+ )
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 7ffc194f0c..3a65b46ecd 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -166,7 +166,9 @@ class RegistrationHandler(BaseHandler):
yield self.auth.check_auth_blocking(threepid=threepid)
password_hash = None
if password:
- password_hash = yield self._auth_handler.hash(password)
+ password_hash = yield defer.ensureDeferred(
+ self._auth_handler.hash(password)
+ )
if localpart is not None:
yield self.check_username(localpart, guest_access_token=guest_access_token)
@@ -540,8 +542,10 @@ class RegistrationHandler(BaseHandler):
user_id, ["guest = true"]
)
else:
- access_token = yield self._auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id, valid_until_ms=valid_until_ms
+ access_token = yield defer.ensureDeferred(
+ self._auth_handler.get_access_token_for_user_id(
+ user_id, device_id=device_id, valid_until_ms=valid_until_ms
+ )
)
return (device_id, access_token)
@@ -617,8 +621,13 @@ class RegistrationHandler(BaseHandler):
logger.info("Can't add incomplete 3pid")
return
- yield self._auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
+ yield defer.ensureDeferred(
+ self._auth_handler.add_threepid(
+ user_id,
+ threepid["medium"],
+ threepid["address"],
+ threepid["validated_at"],
+ )
)
# And we add an email pusher for them by default, but only
@@ -670,6 +679,11 @@ class RegistrationHandler(BaseHandler):
return None
raise
- yield self._auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
+ yield defer.ensureDeferred(
+ self._auth_handler.add_threepid(
+ user_id,
+ threepid["medium"],
+ threepid["address"],
+ threepid["validated_at"],
+ )
)
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 7d1263caf2..63d8f9aa0d 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -15,8 +15,6 @@
import logging
from typing import Optional
-from twisted.internet import defer
-
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.types import Requester
@@ -34,8 +32,7 @@ class SetPasswordHandler(BaseHandler):
self._device_handler = hs.get_device_handler()
self._password_policy_handler = hs.get_password_policy_handler()
- @defer.inlineCallbacks
- def set_password(
+ async def set_password(
self,
user_id: str,
new_password: str,
@@ -46,10 +43,10 @@ class SetPasswordHandler(BaseHandler):
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
self._password_policy_handler.validate_password(new_password)
- password_hash = yield self._auth_handler.hash(new_password)
+ password_hash = await self._auth_handler.hash(new_password)
try:
- yield self.store.user_set_password_hash(user_id, password_hash)
+ await self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
@@ -61,12 +58,12 @@ class SetPasswordHandler(BaseHandler):
except_access_token_id = requester.access_token_id if requester else None
# First delete all of their other devices.
- yield self._device_handler.delete_all_devices_for_user(
+ await self._device_handler.delete_all_devices_for_user(
user_id, except_device_id=except_device_id
)
# and now delete any access tokens which weren't associated with
# devices (or were associated with this device).
- yield self._auth_handler.delete_access_tokens_for_user(
+ await self._auth_handler.delete_access_tokens_for_user(
user_id, except_token_id=except_access_token_id
)
|