summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py45
1 files changed, 24 insertions, 21 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3ea6270083..bcd4249e09 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -29,6 +29,7 @@ from typing import (
     Mapping,
     Optional,
     Tuple,
+    Type,
     Union,
     cast,
 )
@@ -439,7 +440,7 @@ class AuthHandler(BaseHandler):
 
         return ui_auth_types
 
-    def get_enabled_auth_types(self):
+    def get_enabled_auth_types(self) -> Iterable[str]:
         """Return the enabled user-interactive authentication types
 
         Returns the UI-Auth types which are supported by the homeserver's current
@@ -702,7 +703,7 @@ class AuthHandler(BaseHandler):
         except StoreError:
             raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
 
-    async def _expire_old_sessions(self):
+    async def _expire_old_sessions(self) -> None:
         """
         Invalidate any user interactive authentication sessions that have expired.
         """
@@ -1352,7 +1353,7 @@ class AuthHandler(BaseHandler):
         await self.auth.check_auth_blocking(res.user_id)
         return res
 
-    async def delete_access_token(self, access_token: str):
+    async def delete_access_token(self, access_token: str) -> None:
         """Invalidate a single access token
 
         Args:
@@ -1381,7 +1382,7 @@ class AuthHandler(BaseHandler):
         user_id: str,
         except_token_id: Optional[int] = None,
         device_id: Optional[str] = None,
-    ):
+    ) -> None:
         """Invalidate access tokens belonging to a user
 
         Args:
@@ -1409,7 +1410,7 @@ class AuthHandler(BaseHandler):
 
     async def add_threepid(
         self, user_id: str, medium: str, address: str, validated_at: int
-    ):
+    ) -> None:
         # check if medium has a valid value
         if medium not in ["email", "msisdn"]:
             raise SynapseError(
@@ -1480,7 +1481,7 @@ class AuthHandler(BaseHandler):
             Hashed password.
         """
 
-        def _do_hash():
+        def _do_hash() -> str:
             # Normalise the Unicode in the password
             pw = unicodedata.normalize("NFKC", password)
 
@@ -1504,7 +1505,7 @@ class AuthHandler(BaseHandler):
             Whether self.hash(password) == stored_hash.
         """
 
-        def _do_validate_hash(checked_hash: bytes):
+        def _do_validate_hash(checked_hash: bytes) -> bool:
             # Normalise the Unicode in the password
             pw = unicodedata.normalize("NFKC", password)
 
@@ -1581,7 +1582,7 @@ class AuthHandler(BaseHandler):
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
         new_user: bool = False,
-    ):
+    ) -> None:
         """Having figured out a mxid for this user, complete the HTTP request
 
         Args:
@@ -1627,7 +1628,7 @@ class AuthHandler(BaseHandler):
         extra_attributes: Optional[JsonDict] = None,
         new_user: bool = False,
         user_profile_data: Optional[ProfileInfo] = None,
-    ):
+    ) -> None:
         """
         The synchronous portion of complete_sso_login.
 
@@ -1726,7 +1727,7 @@ class AuthHandler(BaseHandler):
             del self._extra_attributes[user_id]
 
     @staticmethod
-    def add_query_param_to_url(url: str, param_name: str, param: Any):
+    def add_query_param_to_url(url: str, param_name: str, param: Any) -> str:
         url_parts = list(urllib.parse.urlparse(url))
         query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
         query.append((param_name, param))
@@ -1734,9 +1735,9 @@ class AuthHandler(BaseHandler):
         return urllib.parse.urlunparse(url_parts)
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class MacaroonGenerator:
-    hs = attr.ib()
+    hs: "HomeServer"
 
     def generate_guest_access_token(self, user_id: str) -> str:
         macaroon = self._generate_base_macaroon(user_id)
@@ -1816,7 +1817,9 @@ class PasswordProvider:
     """
 
     @classmethod
-    def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
+    def load(
+        cls, module: Type, config: JsonDict, module_api: ModuleApi
+    ) -> "PasswordProvider":
         try:
             pp = module(config=config, account_handler=module_api)
         except Exception as e:
@@ -1824,7 +1827,7 @@ class PasswordProvider:
             raise
         return cls(pp, module_api)
 
-    def __init__(self, pp, module_api: ModuleApi):
+    def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
         self._pp = pp
         self._module_api = module_api
 
@@ -1838,7 +1841,7 @@ class PasswordProvider:
         if g:
             self._supported_login_types.update(g())
 
-    def __str__(self):
+    def __str__(self) -> str:
         return str(self._pp)
 
     def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
@@ -1876,19 +1879,19 @@ class PasswordProvider:
         """
         # first grandfather in a call to check_password
         if login_type == LoginType.PASSWORD:
-            g = getattr(self._pp, "check_password", None)
-            if g:
+            check_password = getattr(self._pp, "check_password", None)
+            if check_password:
                 qualified_user_id = self._module_api.get_qualified_user_id(username)
-                is_valid = await self._pp.check_password(
+                is_valid = await check_password(
                     qualified_user_id, login_dict["password"]
                 )
                 if is_valid:
                     return qualified_user_id, None
 
-        g = getattr(self._pp, "check_auth", None)
-        if not g:
+        check_auth = getattr(self._pp, "check_auth", None)
+        if not check_auth:
             return None
-        result = await g(username, login_type, login_dict)
+        result = await check_auth(username, login_type, login_dict)
 
         # Check if the return value is a str or a tuple
         if isinstance(result, str):