diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index fbbf6fd834..bcd4249e09 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -29,6 +29,7 @@ from typing import (
Mapping,
Optional,
Tuple,
+ Type,
Union,
cast,
)
@@ -439,7 +440,7 @@ class AuthHandler(BaseHandler):
return ui_auth_types
- def get_enabled_auth_types(self):
+ def get_enabled_auth_types(self) -> Iterable[str]:
"""Return the enabled user-interactive authentication types
Returns the UI-Auth types which are supported by the homeserver's current
@@ -702,7 +703,7 @@ class AuthHandler(BaseHandler):
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
- async def _expire_old_sessions(self):
+ async def _expire_old_sessions(self) -> None:
"""
Invalidate any user interactive authentication sessions that have expired.
"""
@@ -1347,12 +1348,12 @@ class AuthHandler(BaseHandler):
try:
res = self.macaroon_gen.verify_short_term_login_token(login_token)
except Exception:
- raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
+ raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
await self.auth.check_auth_blocking(res.user_id)
return res
- async def delete_access_token(self, access_token: str):
+ async def delete_access_token(self, access_token: str) -> None:
"""Invalidate a single access token
Args:
@@ -1381,7 +1382,7 @@ class AuthHandler(BaseHandler):
user_id: str,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
- ):
+ ) -> None:
"""Invalidate access tokens belonging to a user
Args:
@@ -1409,7 +1410,7 @@ class AuthHandler(BaseHandler):
async def add_threepid(
self, user_id: str, medium: str, address: str, validated_at: int
- ):
+ ) -> None:
# check if medium has a valid value
if medium not in ["email", "msisdn"]:
raise SynapseError(
@@ -1480,7 +1481,7 @@ class AuthHandler(BaseHandler):
Hashed password.
"""
- def _do_hash():
+ def _do_hash() -> str:
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
@@ -1504,7 +1505,7 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash.
"""
- def _do_validate_hash(checked_hash: bytes):
+ def _do_validate_hash(checked_hash: bytes) -> bool:
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
@@ -1581,7 +1582,7 @@ class AuthHandler(BaseHandler):
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
- ):
+ ) -> None:
"""Having figured out a mxid for this user, complete the HTTP request
Args:
@@ -1627,7 +1628,7 @@ class AuthHandler(BaseHandler):
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
user_profile_data: Optional[ProfileInfo] = None,
- ):
+ ) -> None:
"""
The synchronous portion of complete_sso_login.
@@ -1726,7 +1727,7 @@ class AuthHandler(BaseHandler):
del self._extra_attributes[user_id]
@staticmethod
- def add_query_param_to_url(url: str, param_name: str, param: Any):
+ def add_query_param_to_url(url: str, param_name: str, param: Any) -> str:
url_parts = list(urllib.parse.urlparse(url))
query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
query.append((param_name, param))
@@ -1734,9 +1735,9 @@ class AuthHandler(BaseHandler):
return urllib.parse.urlunparse(url_parts)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class MacaroonGenerator:
- hs = attr.ib()
+ hs: "HomeServer"
def generate_guest_access_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
@@ -1816,7 +1817,9 @@ class PasswordProvider:
"""
@classmethod
- def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
+ def load(
+ cls, module: Type, config: JsonDict, module_api: ModuleApi
+ ) -> "PasswordProvider":
try:
pp = module(config=config, account_handler=module_api)
except Exception as e:
@@ -1824,7 +1827,7 @@ class PasswordProvider:
raise
return cls(pp, module_api)
- def __init__(self, pp, module_api: ModuleApi):
+ def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
self._pp = pp
self._module_api = module_api
@@ -1838,7 +1841,7 @@ class PasswordProvider:
if g:
self._supported_login_types.update(g())
- def __str__(self):
+ def __str__(self) -> str:
return str(self._pp)
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
@@ -1876,19 +1879,19 @@ class PasswordProvider:
"""
# first grandfather in a call to check_password
if login_type == LoginType.PASSWORD:
- g = getattr(self._pp, "check_password", None)
- if g:
+ check_password = getattr(self._pp, "check_password", None)
+ if check_password:
qualified_user_id = self._module_api.get_qualified_user_id(username)
- is_valid = await self._pp.check_password(
+ is_valid = await check_password(
qualified_user_id, login_dict["password"]
)
if is_valid:
return qualified_user_id, None
- g = getattr(self._pp, "check_auth", None)
- if not g:
+ check_auth = getattr(self._pp, "check_auth", None)
+ if not check_auth:
return None
- result = await g(username, login_type, login_dict)
+ result = await check_auth(username, login_type, login_dict)
# Check if the return value is a str or a tuple
if isinstance(result, str):
|