diff --git a/changelog.d/7063.misc b/changelog.d/7063.misc
new file mode 100644
index 0000000000..e7b1cd3cd8
--- /dev/null
+++ b/changelog.d/7063.misc
@@ -0,0 +1 @@
+Add type annotations and comments to the auth handler.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 7ca90f91c4..7860f9625e 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,10 +18,10 @@ import logging
import time
import unicodedata
import urllib.parse
-from typing import Any
+from typing import Any, Dict, Iterable, List, Optional
import attr
-import bcrypt
+import bcrypt # type: ignore[import]
import pymacaroons
from twisted.internet import defer
@@ -45,7 +45,7 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
-from synapse.types import UserID
+from synapse.types import Requester, UserID
from synapse.util.caches.expiringcache import ExpiringCache
from ._base import BaseHandler
@@ -63,11 +63,11 @@ class AuthHandler(BaseHandler):
"""
super(AuthHandler, self).__init__(hs)
- self.checkers = {} # type: dict[str, UserInteractiveAuthChecker]
+ self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
if inst.is_enabled():
- self.checkers[inst.AUTH_TYPE] = inst
+ self.checkers[inst.AUTH_TYPE] = inst # type: ignore
self.bcrypt_rounds = hs.config.bcrypt_rounds
@@ -124,7 +124,9 @@ class AuthHandler(BaseHandler):
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
@defer.inlineCallbacks
- def validate_user_via_ui_auth(self, requester, request_body, clientip):
+ def validate_user_via_ui_auth(
+ self, requester: Requester, request_body: Dict[str, Any], clientip: str
+ ):
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -133,11 +135,11 @@ class AuthHandler(BaseHandler):
that it isn't stolen by re-authenticating them.
Args:
- requester (Requester): The user, as given by the access token
+ requester: The user, as given by the access token
- request_body (dict): The body of the request sent by the client
+ request_body: The body of the request sent by the client
- clientip (str): The IP address of the client.
+ clientip: The IP address of the client.
Returns:
defer.Deferred[dict]: the parameters for this request (which may
@@ -208,7 +210,9 @@ class AuthHandler(BaseHandler):
return self.checkers.keys()
@defer.inlineCallbacks
- def check_auth(self, flows, clientdict, clientip):
+ def check_auth(
+ self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str
+ ):
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the User-Interactive Auth flow.
@@ -223,14 +227,14 @@ class AuthHandler(BaseHandler):
decorator.
Args:
- flows (list): A list of login flows. Each flow is an ordered list of
- strings representing auth-types. At least one full
- flow must be completed in order for auth to be successful.
+ flows: A list of login flows. Each flow is an ordered list of
+ strings representing auth-types. At least one full
+ flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
- clientip (str): The IP address of the client.
+ clientip: The IP address of the client.
Returns:
defer.Deferred[dict, dict, str]: a deferred tuple of
@@ -250,7 +254,7 @@ class AuthHandler(BaseHandler):
"""
authdict = None
- sid = None
+ sid = None # type: Optional[str]
if clientdict and "auth" in clientdict:
authdict = clientdict["auth"]
del clientdict["auth"]
@@ -283,9 +287,9 @@ class AuthHandler(BaseHandler):
creds = session["creds"]
# check auth type currently being presented
- errordict = {}
+ errordict = {} # type: Dict[str, Any]
if "type" in authdict:
- login_type = authdict["type"]
+ login_type = authdict["type"] # type: str
try:
result = yield self._check_auth_dict(authdict, clientip)
if result:
@@ -326,7 +330,7 @@ class AuthHandler(BaseHandler):
raise InteractiveAuthIncompleteError(ret)
@defer.inlineCallbacks
- def add_oob_auth(self, stagetype, authdict, clientip):
+ def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str):
"""
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.
@@ -348,7 +352,7 @@ class AuthHandler(BaseHandler):
return True
return False
- def get_session_id(self, clientdict):
+ def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
"""
Gets the session ID for a client given the client dictionary
@@ -356,7 +360,7 @@ class AuthHandler(BaseHandler):
clientdict: The dictionary sent by the client in the request
Returns:
- str|None: The string session ID the client sent. If the client did
+ The string session ID the client sent. If the client did
not send a session ID, returns None.
"""
sid = None
@@ -366,40 +370,42 @@ class AuthHandler(BaseHandler):
sid = authdict["session"]
return sid
- def set_session_data(self, session_id, key, value):
+ def set_session_data(self, session_id: str, key: str, value: Any) -> None:
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
the client.
Args:
- session_id (string): The ID of this session as returned from check_auth
- key (string): The key to store the data under
- value (any): The data to store
+ session_id: The ID of this session as returned from check_auth
+ key: The key to store the data under
+ value: The data to store
"""
sess = self._get_session_info(session_id)
sess.setdefault("serverdict", {})[key] = value
self._save_session(sess)
- def get_session_data(self, session_id, key, default=None):
+ def get_session_data(
+ self, session_id: str, key: str, default: Optional[Any] = None
+ ) -> Any:
"""
Retrieve data stored with set_session_data
Args:
- session_id (string): The ID of this session as returned from check_auth
- key (string): The key to store the data under
- default (any): Value to return if the key has not been set
+ session_id: The ID of this session as returned from check_auth
+ key: The key to store the data under
+ default: Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
return sess.setdefault("serverdict", {}).get(key, default)
@defer.inlineCallbacks
- def _check_auth_dict(self, authdict, clientip):
+ def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str):
"""Attempt to validate the auth dict provided by a client
Args:
- authdict (object): auth dict provided by the client
- clientip (str): IP address of the client
+ authdict: auth dict provided by the client
+ clientip: IP address of the client
Returns:
Deferred: result of the stage verification.
@@ -425,10 +431,10 @@ class AuthHandler(BaseHandler):
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
return canonical_id
- def _get_params_recaptcha(self):
+ def _get_params_recaptcha(self) -> dict:
return {"public_key": self.hs.config.recaptcha_public_key}
- def _get_params_terms(self):
+ def _get_params_terms(self) -> dict:
return {
"policies": {
"privacy_policy": {
@@ -445,7 +451,9 @@ class AuthHandler(BaseHandler):
}
}
- def _auth_dict_for_flows(self, flows, session):
+ def _auth_dict_for_flows(
+ self, flows: List[List[str]], session: Dict[str, Any]
+ ) -> Dict[str, Any]:
public_flows = []
for f in flows:
public_flows.append(f)
@@ -455,7 +463,7 @@ class AuthHandler(BaseHandler):
LoginType.TERMS: self._get_params_terms,
}
- params = {}
+ params = {} # type: Dict[str, Any]
for f in public_flows:
for stage in f:
@@ -468,7 +476,13 @@ class AuthHandler(BaseHandler):
"params": params,
}
- def _get_session_info(self, session_id):
+ def _get_session_info(self, session_id: Optional[str]) -> dict:
+ """
+ Gets or creates a session given a session ID.
+
+ The session can be used to track data across multiple requests, e.g. for
+ interactive authentication.
+ """
if session_id not in self.sessions:
session_id = None
@@ -481,7 +495,9 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
@defer.inlineCallbacks
- def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms):
+ def get_access_token_for_user_id(
+ self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
+ ):
"""
Creates a new access token for the user with the given user ID.
@@ -491,11 +507,11 @@ class AuthHandler(BaseHandler):
The device will be recorded in the table if it is not there already.
Args:
- user_id (str): canonical User ID
- device_id (str|None): the device ID to associate with the tokens.
+ user_id: canonical User ID
+ device_id: the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
- valid_until_ms (int|None): when the token is valid until. None for
+ valid_until_ms: when the token is valid until. None for
no expiry.
Returns:
The access token for the user's session.
@@ -530,13 +546,13 @@ class AuthHandler(BaseHandler):
return access_token
@defer.inlineCallbacks
- def check_user_exists(self, user_id):
+ def check_user_exists(self, user_id: str):
"""
Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches.
Args:
- (unicode|bytes) user_id: complete @user:id
+ user_id: complete @user:id
Returns:
defer.Deferred: (unicode) canonical_user_id, or None if zero or
@@ -551,7 +567,7 @@ class AuthHandler(BaseHandler):
return None
@defer.inlineCallbacks
- def _find_user_id_and_pwd_hash(self, user_id):
+ def _find_user_id_and_pwd_hash(self, user_id: str):
"""Checks to see if a user with the given id exists. Will check case
insensitively, but will return None if there are multiple inexact
matches.
@@ -581,7 +597,7 @@ class AuthHandler(BaseHandler):
)
return result
- def get_supported_login_types(self):
+ def get_supported_login_types(self) -> Iterable[str]:
"""Get a the login types supported for the /login API
By default this is just 'm.login.password' (unless password_enabled is
@@ -589,20 +605,20 @@ class AuthHandler(BaseHandler):
other login types.
Returns:
- Iterable[str]: login types
+ login types
"""
return self._supported_login_types
@defer.inlineCallbacks
- def validate_login(self, username, login_submission):
+ def validate_login(self, username: str, login_submission: Dict[str, Any]):
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate
m.login.password auth types.
Args:
- username (str): username supplied by the user
- login_submission (dict): the whole of the login submission
+ username: username supplied by the user
+ login_submission: the whole of the login submission
(including 'type' and other relevant fields)
Returns:
Deferred[str, func]: canonical user id, and optional callback
@@ -690,13 +706,13 @@ class AuthHandler(BaseHandler):
raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
- def check_password_provider_3pid(self, medium, address, password):
+ def check_password_provider_3pid(self, medium: str, address: str, password: str):
"""Check if a password provider is able to validate a thirdparty login
Args:
- medium (str): The medium of the 3pid (ex. email).
- address (str): The address of the 3pid (ex. jdoe@example.com).
- password (str): The password of the user.
+ medium: The medium of the 3pid (ex. email).
+ address: The address of the 3pid (ex. jdoe@example.com).
+ password: The password of the user.
Returns:
Deferred[(str|None, func|None)]: A tuple of `(user_id,
@@ -724,15 +740,15 @@ class AuthHandler(BaseHandler):
return None, None
@defer.inlineCallbacks
- def _check_local_password(self, user_id, password):
+ def _check_local_password(self, user_id: str, password: str):
"""Authenticate a user against the local password database.
user_id is checked case insensitively, but will return None if there are
multiple inexact matches.
Args:
- user_id (unicode): complete @user:id
- password (unicode): the provided password
+ user_id: complete @user:id
+ password: the provided password
Returns:
Deferred[unicode] the canonical_user_id, or Deferred[None] if
unknown user/bad password
@@ -755,7 +771,7 @@ class AuthHandler(BaseHandler):
return user_id
@defer.inlineCallbacks
- def validate_short_term_login_token_and_get_user_id(self, login_token):
+ def validate_short_term_login_token_and_get_user_id(self, login_token: str):
auth_api = self.hs.get_auth()
user_id = None
try:
@@ -769,11 +785,11 @@ class AuthHandler(BaseHandler):
return user_id
@defer.inlineCallbacks
- def delete_access_token(self, access_token):
+ def delete_access_token(self, access_token: str):
"""Invalidate a single access token
Args:
- access_token (str): access token to be deleted
+ access_token: access token to be deleted
Returns:
Deferred
@@ -798,15 +814,17 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def delete_access_tokens_for_user(
- self, user_id, except_token_id=None, device_id=None
+ self,
+ user_id: str,
+ except_token_id: Optional[str] = None,
+ device_id: Optional[str] = None,
):
"""Invalidate access tokens belonging to a user
Args:
- user_id (str): ID of user the tokens belong to
- except_token_id (str|None): access_token ID which should *not* be
- deleted
- device_id (str|None): ID of device the tokens are associated with.
+ user_id: ID of user the tokens belong to
+ except_token_id: access_token ID which should *not* be deleted
+ device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
@@ -830,7 +848,7 @@ class AuthHandler(BaseHandler):
)
@defer.inlineCallbacks
- def add_threepid(self, user_id, medium, address, validated_at):
+ def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int):
# check if medium has a valid value
if medium not in ["email", "msisdn"]:
raise SynapseError(
@@ -856,19 +874,20 @@ class AuthHandler(BaseHandler):
)
@defer.inlineCallbacks
- def delete_threepid(self, user_id, medium, address, id_server=None):
+ def delete_threepid(
+ self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
+ ):
"""Attempts to unbind the 3pid on the identity servers and deletes it
from the local database.
Args:
- user_id (str)
- medium (str)
- address (str)
- id_server (str|None): Use the given identity server when unbinding
+ user_id: ID of user to remove the 3pid from.
+ medium: The medium of the 3pid being removed: "email" or "msisdn".
+ address: The 3pid address to remove.
+ id_server: Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known).
-
Returns:
Deferred[bool]: Returns True if successfully unbound the 3pid on
the identity server, False if identity server doesn't support the
@@ -887,17 +906,18 @@ class AuthHandler(BaseHandler):
yield self.store.user_delete_threepid(user_id, medium, address)
return result
- def _save_session(self, session):
+ def _save_session(self, session: Dict[str, Any]) -> None:
+ """Update the last used time on the session to now and add it back to the session store."""
# TODO: Persistent storage
logger.debug("Saving session %s", session)
session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session
- def hash(self, password):
+ def hash(self, password: str):
"""Computes a secure hash of password.
Args:
- password (unicode): Password to hash.
+ password: Password to hash.
Returns:
Deferred(unicode): Hashed password.
@@ -914,12 +934,12 @@ class AuthHandler(BaseHandler):
return defer_to_thread(self.hs.get_reactor(), _do_hash)
- def validate_hash(self, password, stored_hash):
+ def validate_hash(self, password: str, stored_hash: bytes):
"""Validates that self.hash(password) == stored_hash.
Args:
- password (unicode): Password to hash.
- stored_hash (bytes): Expected hash value.
+ password: Password to hash.
+ stored_hash: Expected hash value.
Returns:
Deferred(bool): Whether self.hash(password) == stored_hash.
@@ -1007,7 +1027,9 @@ class MacaroonGenerator(object):
hs = attr.ib()
- def generate_access_token(self, user_id, extra_caveats=None):
+ def generate_access_token(
+ self, user_id: str, extra_caveats: Optional[List[str]] = None
+ ) -> str:
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
@@ -1020,16 +1042,9 @@ class MacaroonGenerator(object):
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
- def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
- """
-
- Args:
- user_id (unicode):
- duration_in_ms (int):
-
- Returns:
- unicode
- """
+ def generate_short_term_login_token(
+ self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+ ) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
@@ -1037,12 +1052,12 @@ class MacaroonGenerator(object):
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
- def generate_delete_pusher_token(self, user_id):
+ def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize()
- def _generate_base_macaroon(self, user_id):
+ def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
diff --git a/tox.ini b/tox.ini
index 7622aa19f1..8b4c37c2ee 100644
--- a/tox.ini
+++ b/tox.ini
@@ -185,6 +185,7 @@ commands = mypy \
synapse/federation/federation_client.py \
synapse/federation/sender \
synapse/federation/transport \
+ synapse/handlers/auth.py \
synapse/handlers/directory.py \
synapse/handlers/presence.py \
synapse/handlers/sync.py \
|