diff options
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r-- | synapse/handlers/auth.py | 45 |
1 files changed, 24 insertions, 21 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 3ea6270083..bcd4249e09 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -29,6 +29,7 @@ from typing import ( Mapping, Optional, Tuple, + Type, Union, cast, ) @@ -439,7 +440,7 @@ class AuthHandler(BaseHandler): return ui_auth_types - def get_enabled_auth_types(self): + def get_enabled_auth_types(self) -> Iterable[str]: """Return the enabled user-interactive authentication types Returns the UI-Auth types which are supported by the homeserver's current @@ -702,7 +703,7 @@ class AuthHandler(BaseHandler): except StoreError: raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) - async def _expire_old_sessions(self): + async def _expire_old_sessions(self) -> None: """ Invalidate any user interactive authentication sessions that have expired. """ @@ -1352,7 +1353,7 @@ class AuthHandler(BaseHandler): await self.auth.check_auth_blocking(res.user_id) return res - async def delete_access_token(self, access_token: str): + async def delete_access_token(self, access_token: str) -> None: """Invalidate a single access token Args: @@ -1381,7 +1382,7 @@ class AuthHandler(BaseHandler): user_id: str, except_token_id: Optional[int] = None, device_id: Optional[str] = None, - ): + ) -> None: """Invalidate access tokens belonging to a user Args: @@ -1409,7 +1410,7 @@ class AuthHandler(BaseHandler): async def add_threepid( self, user_id: str, medium: str, address: str, validated_at: int - ): + ) -> None: # check if medium has a valid value if medium not in ["email", "msisdn"]: raise SynapseError( @@ -1480,7 +1481,7 @@ class AuthHandler(BaseHandler): Hashed password. """ - def _do_hash(): + def _do_hash() -> str: # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) @@ -1504,7 +1505,7 @@ class AuthHandler(BaseHandler): Whether self.hash(password) == stored_hash. """ - def _do_validate_hash(checked_hash: bytes): + def _do_validate_hash(checked_hash: bytes) -> bool: # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) @@ -1581,7 +1582,7 @@ class AuthHandler(BaseHandler): client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, new_user: bool = False, - ): + ) -> None: """Having figured out a mxid for this user, complete the HTTP request Args: @@ -1627,7 +1628,7 @@ class AuthHandler(BaseHandler): extra_attributes: Optional[JsonDict] = None, new_user: bool = False, user_profile_data: Optional[ProfileInfo] = None, - ): + ) -> None: """ The synchronous portion of complete_sso_login. @@ -1726,7 +1727,7 @@ class AuthHandler(BaseHandler): del self._extra_attributes[user_id] @staticmethod - def add_query_param_to_url(url: str, param_name: str, param: Any): + def add_query_param_to_url(url: str, param_name: str, param: Any) -> str: url_parts = list(urllib.parse.urlparse(url)) query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True) query.append((param_name, param)) @@ -1734,9 +1735,9 @@ class AuthHandler(BaseHandler): return urllib.parse.urlunparse(url_parts) -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class MacaroonGenerator: - hs = attr.ib() + hs: "HomeServer" def generate_guest_access_token(self, user_id: str) -> str: macaroon = self._generate_base_macaroon(user_id) @@ -1816,7 +1817,9 @@ class PasswordProvider: """ @classmethod - def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider": + def load( + cls, module: Type, config: JsonDict, module_api: ModuleApi + ) -> "PasswordProvider": try: pp = module(config=config, account_handler=module_api) except Exception as e: @@ -1824,7 +1827,7 @@ class PasswordProvider: raise return cls(pp, module_api) - def __init__(self, pp, module_api: ModuleApi): + def __init__(self, pp: "PasswordProvider", module_api: ModuleApi): self._pp = pp self._module_api = module_api @@ -1838,7 +1841,7 @@ class PasswordProvider: if g: self._supported_login_types.update(g()) - def __str__(self): + def __str__(self) -> str: return str(self._pp) def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: @@ -1876,19 +1879,19 @@ class PasswordProvider: """ # first grandfather in a call to check_password if login_type == LoginType.PASSWORD: - g = getattr(self._pp, "check_password", None) - if g: + check_password = getattr(self._pp, "check_password", None) + if check_password: qualified_user_id = self._module_api.get_qualified_user_id(username) - is_valid = await self._pp.check_password( + is_valid = await check_password( qualified_user_id, login_dict["password"] ) if is_valid: return qualified_user_id, None - g = getattr(self._pp, "check_auth", None) - if not g: + check_auth = getattr(self._pp, "check_auth", None) + if not check_auth: return None - result = await g(username, login_type, login_dict) + result = await check_auth(username, login_type, login_dict) # Check if the return value is a str or a tuple if isinstance(result, str): |