diff options
author | Andrew Morgan <andrew@amorgan.xyz> | 2020-03-24 14:36:35 +0000 |
---|---|---|
committer | Andrew Morgan <andrew@amorgan.xyz> | 2020-03-24 14:36:35 +0000 |
commit | 6095a49d4a8eb4e6485a233316ca6758e8568d53 (patch) | |
tree | c27ed54539bbd1f872c9fe1b0352aa469fa1495f /synapse/handlers | |
parent | Merge pull request #7058 from matrix-org/babolivier/saml_error_html (diff) | |
parent | Revert "Add options to disable setting profile info for prevent changes. (#70... (diff) | |
download | synapse-6095a49d4a8eb4e6485a233316ca6758e8568d53.tar.xz |
Merge commit '6a3504636' into dinsic-release-v1.12.x
* commit '6a3504636': (29 commits) Revert "Add options to disable setting profile info for prevent changes. (#7053)" Populate the room version from state events (#7070) Fix buggy condition in account validity handler (#7074) Use innerText instead of innerHTML Add type annotations and comments to auth handler (#7063) Lint Put the file in the templates directory Update wording and config Changelog Move the default SAML2 error HTML to a dedicated file Refactor a bit Also don't fail on aliases events in this case Lint Changelog Also don't filter out events sent by ignored users when checking state visibility Fix condition Don't filter out dummy events when we're checking the visibility of state Update sample_config.yaml Update synapse/config/registration.py lint, fix tests ...
Diffstat (limited to 'synapse/handlers')
-rw-r--r-- | synapse/handlers/account_validity.py | 6 | ||||
-rw-r--r-- | synapse/handlers/auth.py | 193 | ||||
-rw-r--r-- | synapse/handlers/message.py | 2 |
3 files changed, 110 insertions, 91 deletions
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 6c46c995d2..f4bf92f775 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -48,7 +48,11 @@ class AccountValidityHandler(object): self._show_users_in_user_directory = self.hs.config.show_users_in_user_directory self.profile_handler = self.hs.get_profile_handler() - if self._account_validity.renew_by_email_enabled and load_jinja2_templates: + if ( + self._account_validity.enabled + and self._account_validity.renew_by_email_enabled + and load_jinja2_templates + ): # Don't do email-specific configuration if renewal by email is disabled. try: app_name = self.hs.config.email_app_name diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7ca90f91c4..7860f9625e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,10 +18,10 @@ import logging import time import unicodedata import urllib.parse -from typing import Any +from typing import Any, Dict, Iterable, List, Optional import attr -import bcrypt +import bcrypt # type: ignore[import] import pymacaroons from twisted.internet import defer @@ -45,7 +45,7 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.module_api import ModuleApi from synapse.push.mailer import load_jinja2_templates -from synapse.types import UserID +from synapse.types import Requester, UserID from synapse.util.caches.expiringcache import ExpiringCache from ._base import BaseHandler @@ -63,11 +63,11 @@ class AuthHandler(BaseHandler): """ super(AuthHandler, self).__init__(hs) - self.checkers = {} # type: dict[str, UserInteractiveAuthChecker] + self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: inst = auth_checker_class(hs) if inst.is_enabled(): - self.checkers[inst.AUTH_TYPE] = inst + self.checkers[inst.AUTH_TYPE] = inst # type: ignore self.bcrypt_rounds = hs.config.bcrypt_rounds @@ -124,7 +124,9 @@ class AuthHandler(BaseHandler): self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) @defer.inlineCallbacks - def validate_user_via_ui_auth(self, requester, request_body, clientip): + def validate_user_via_ui_auth( + self, requester: Requester, request_body: Dict[str, Any], clientip: str + ): """ Checks that the user is who they claim to be, via a UI auth. @@ -133,11 +135,11 @@ class AuthHandler(BaseHandler): that it isn't stolen by re-authenticating them. Args: - requester (Requester): The user, as given by the access token + requester: The user, as given by the access token - request_body (dict): The body of the request sent by the client + request_body: The body of the request sent by the client - clientip (str): The IP address of the client. + clientip: The IP address of the client. Returns: defer.Deferred[dict]: the parameters for this request (which may @@ -208,7 +210,9 @@ class AuthHandler(BaseHandler): return self.checkers.keys() @defer.inlineCallbacks - def check_auth(self, flows, clientdict, clientip): + def check_auth( + self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str + ): """ Takes a dictionary sent by the client in the login / registration protocol and handles the User-Interactive Auth flow. @@ -223,14 +227,14 @@ class AuthHandler(BaseHandler): decorator. Args: - flows (list): A list of login flows. Each flow is an ordered list of - strings representing auth-types. At least one full - flow must be completed in order for auth to be successful. + flows: A list of login flows. Each flow is an ordered list of + strings representing auth-types. At least one full + flow must be completed in order for auth to be successful. clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. - clientip (str): The IP address of the client. + clientip: The IP address of the client. Returns: defer.Deferred[dict, dict, str]: a deferred tuple of @@ -250,7 +254,7 @@ class AuthHandler(BaseHandler): """ authdict = None - sid = None + sid = None # type: Optional[str] if clientdict and "auth" in clientdict: authdict = clientdict["auth"] del clientdict["auth"] @@ -283,9 +287,9 @@ class AuthHandler(BaseHandler): creds = session["creds"] # check auth type currently being presented - errordict = {} + errordict = {} # type: Dict[str, Any] if "type" in authdict: - login_type = authdict["type"] + login_type = authdict["type"] # type: str try: result = yield self._check_auth_dict(authdict, clientip) if result: @@ -326,7 +330,7 @@ class AuthHandler(BaseHandler): raise InteractiveAuthIncompleteError(ret) @defer.inlineCallbacks - def add_oob_auth(self, stagetype, authdict, clientip): + def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str): """ Adds the result of out-of-band authentication into an existing auth session. Currently used for adding the result of fallback auth. @@ -348,7 +352,7 @@ class AuthHandler(BaseHandler): return True return False - def get_session_id(self, clientdict): + def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]: """ Gets the session ID for a client given the client dictionary @@ -356,7 +360,7 @@ class AuthHandler(BaseHandler): clientdict: The dictionary sent by the client in the request Returns: - str|None: The string session ID the client sent. If the client did + The string session ID the client sent. If the client did not send a session ID, returns None. """ sid = None @@ -366,40 +370,42 @@ class AuthHandler(BaseHandler): sid = authdict["session"] return sid - def set_session_data(self, session_id, key, value): + def set_session_data(self, session_id: str, key: str, value: Any) -> None: """ Store a key-value pair into the sessions data associated with this request. This data is stored server-side and cannot be modified by the client. Args: - session_id (string): The ID of this session as returned from check_auth - key (string): The key to store the data under - value (any): The data to store + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + value: The data to store """ sess = self._get_session_info(session_id) sess.setdefault("serverdict", {})[key] = value self._save_session(sess) - def get_session_data(self, session_id, key, default=None): + def get_session_data( + self, session_id: str, key: str, default: Optional[Any] = None + ) -> Any: """ Retrieve data stored with set_session_data Args: - session_id (string): The ID of this session as returned from check_auth - key (string): The key to store the data under - default (any): Value to return if the key has not been set + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + default: Value to return if the key has not been set """ sess = self._get_session_info(session_id) return sess.setdefault("serverdict", {}).get(key, default) @defer.inlineCallbacks - def _check_auth_dict(self, authdict, clientip): + def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str): """Attempt to validate the auth dict provided by a client Args: - authdict (object): auth dict provided by the client - clientip (str): IP address of the client + authdict: auth dict provided by the client + clientip: IP address of the client Returns: Deferred: result of the stage verification. @@ -425,10 +431,10 @@ class AuthHandler(BaseHandler): (canonical_id, callback) = yield self.validate_login(user_id, authdict) return canonical_id - def _get_params_recaptcha(self): + def _get_params_recaptcha(self) -> dict: return {"public_key": self.hs.config.recaptcha_public_key} - def _get_params_terms(self): + def _get_params_terms(self) -> dict: return { "policies": { "privacy_policy": { @@ -445,7 +451,9 @@ class AuthHandler(BaseHandler): } } - def _auth_dict_for_flows(self, flows, session): + def _auth_dict_for_flows( + self, flows: List[List[str]], session: Dict[str, Any] + ) -> Dict[str, Any]: public_flows = [] for f in flows: public_flows.append(f) @@ -455,7 +463,7 @@ class AuthHandler(BaseHandler): LoginType.TERMS: self._get_params_terms, } - params = {} + params = {} # type: Dict[str, Any] for f in public_flows: for stage in f: @@ -468,7 +476,13 @@ class AuthHandler(BaseHandler): "params": params, } - def _get_session_info(self, session_id): + def _get_session_info(self, session_id: Optional[str]) -> dict: + """ + Gets or creates a session given a session ID. + + The session can be used to track data across multiple requests, e.g. for + interactive authentication. + """ if session_id not in self.sessions: session_id = None @@ -481,7 +495,9 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] @defer.inlineCallbacks - def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms): + def get_access_token_for_user_id( + self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] + ): """ Creates a new access token for the user with the given user ID. @@ -491,11 +507,11 @@ class AuthHandler(BaseHandler): The device will be recorded in the table if it is not there already. Args: - user_id (str): canonical User ID - device_id (str|None): the device ID to associate with the tokens. + user_id: canonical User ID + device_id: the device ID to associate with the tokens. None to leave the tokens unassociated with a device (deprecated: we should always have a device ID) - valid_until_ms (int|None): when the token is valid until. None for + valid_until_ms: when the token is valid until. None for no expiry. Returns: The access token for the user's session. @@ -530,13 +546,13 @@ class AuthHandler(BaseHandler): return access_token @defer.inlineCallbacks - def check_user_exists(self, user_id): + def check_user_exists(self, user_id: str): """ Checks to see if a user with the given id exists. Will check case insensitively, but return None if there are multiple inexact matches. Args: - (unicode|bytes) user_id: complete @user:id + user_id: complete @user:id Returns: defer.Deferred: (unicode) canonical_user_id, or None if zero or @@ -551,7 +567,7 @@ class AuthHandler(BaseHandler): return None @defer.inlineCallbacks - def _find_user_id_and_pwd_hash(self, user_id): + def _find_user_id_and_pwd_hash(self, user_id: str): """Checks to see if a user with the given id exists. Will check case insensitively, but will return None if there are multiple inexact matches. @@ -581,7 +597,7 @@ class AuthHandler(BaseHandler): ) return result - def get_supported_login_types(self): + def get_supported_login_types(self) -> Iterable[str]: """Get a the login types supported for the /login API By default this is just 'm.login.password' (unless password_enabled is @@ -589,20 +605,20 @@ class AuthHandler(BaseHandler): other login types. Returns: - Iterable[str]: login types + login types """ return self._supported_login_types @defer.inlineCallbacks - def validate_login(self, username, login_submission): + def validate_login(self, username: str, login_submission: Dict[str, Any]): """Authenticates the user for the /login API Also used by the user-interactive auth flow to validate m.login.password auth types. Args: - username (str): username supplied by the user - login_submission (dict): the whole of the login submission + username: username supplied by the user + login_submission: the whole of the login submission (including 'type' and other relevant fields) Returns: Deferred[str, func]: canonical user id, and optional callback @@ -690,13 +706,13 @@ class AuthHandler(BaseHandler): raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) @defer.inlineCallbacks - def check_password_provider_3pid(self, medium, address, password): + def check_password_provider_3pid(self, medium: str, address: str, password: str): """Check if a password provider is able to validate a thirdparty login Args: - medium (str): The medium of the 3pid (ex. email). - address (str): The address of the 3pid (ex. jdoe@example.com). - password (str): The password of the user. + medium: The medium of the 3pid (ex. email). + address: The address of the 3pid (ex. jdoe@example.com). + password: The password of the user. Returns: Deferred[(str|None, func|None)]: A tuple of `(user_id, @@ -724,15 +740,15 @@ class AuthHandler(BaseHandler): return None, None @defer.inlineCallbacks - def _check_local_password(self, user_id, password): + def _check_local_password(self, user_id: str, password: str): """Authenticate a user against the local password database. user_id is checked case insensitively, but will return None if there are multiple inexact matches. Args: - user_id (unicode): complete @user:id - password (unicode): the provided password + user_id: complete @user:id + password: the provided password Returns: Deferred[unicode] the canonical_user_id, or Deferred[None] if unknown user/bad password @@ -755,7 +771,7 @@ class AuthHandler(BaseHandler): return user_id @defer.inlineCallbacks - def validate_short_term_login_token_and_get_user_id(self, login_token): + def validate_short_term_login_token_and_get_user_id(self, login_token: str): auth_api = self.hs.get_auth() user_id = None try: @@ -769,11 +785,11 @@ class AuthHandler(BaseHandler): return user_id @defer.inlineCallbacks - def delete_access_token(self, access_token): + def delete_access_token(self, access_token: str): """Invalidate a single access token Args: - access_token (str): access token to be deleted + access_token: access token to be deleted Returns: Deferred @@ -798,15 +814,17 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def delete_access_tokens_for_user( - self, user_id, except_token_id=None, device_id=None + self, + user_id: str, + except_token_id: Optional[str] = None, + device_id: Optional[str] = None, ): """Invalidate access tokens belonging to a user Args: - user_id (str): ID of user the tokens belong to - except_token_id (str|None): access_token ID which should *not* be - deleted - device_id (str|None): ID of device the tokens are associated with. + user_id: ID of user the tokens belong to + except_token_id: access_token ID which should *not* be deleted + device_id: ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted Returns: @@ -830,7 +848,7 @@ class AuthHandler(BaseHandler): ) @defer.inlineCallbacks - def add_threepid(self, user_id, medium, address, validated_at): + def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int): # check if medium has a valid value if medium not in ["email", "msisdn"]: raise SynapseError( @@ -856,19 +874,20 @@ class AuthHandler(BaseHandler): ) @defer.inlineCallbacks - def delete_threepid(self, user_id, medium, address, id_server=None): + def delete_threepid( + self, user_id: str, medium: str, address: str, id_server: Optional[str] = None + ): """Attempts to unbind the 3pid on the identity servers and deletes it from the local database. Args: - user_id (str) - medium (str) - address (str) - id_server (str|None): Use the given identity server when unbinding + user_id: ID of user to remove the 3pid from. + medium: The medium of the 3pid being removed: "email" or "msisdn". + address: The 3pid address to remove. + id_server: Use the given identity server when unbinding any threepids. If None then will attempt to unbind using the identity server specified when binding (if known). - Returns: Deferred[bool]: Returns True if successfully unbound the 3pid on the identity server, False if identity server doesn't support the @@ -887,17 +906,18 @@ class AuthHandler(BaseHandler): yield self.store.user_delete_threepid(user_id, medium, address) return result - def _save_session(self, session): + def _save_session(self, session: Dict[str, Any]) -> None: + """Update the last used time on the session to now and add it back to the session store.""" # TODO: Persistent storage logger.debug("Saving session %s", session) session["last_used"] = self.hs.get_clock().time_msec() self.sessions[session["id"]] = session - def hash(self, password): + def hash(self, password: str): """Computes a secure hash of password. Args: - password (unicode): Password to hash. + password: Password to hash. Returns: Deferred(unicode): Hashed password. @@ -914,12 +934,12 @@ class AuthHandler(BaseHandler): return defer_to_thread(self.hs.get_reactor(), _do_hash) - def validate_hash(self, password, stored_hash): + def validate_hash(self, password: str, stored_hash: bytes): """Validates that self.hash(password) == stored_hash. Args: - password (unicode): Password to hash. - stored_hash (bytes): Expected hash value. + password: Password to hash. + stored_hash: Expected hash value. Returns: Deferred(bool): Whether self.hash(password) == stored_hash. @@ -1007,7 +1027,9 @@ class MacaroonGenerator(object): hs = attr.ib() - def generate_access_token(self, user_id, extra_caveats=None): + def generate_access_token( + self, user_id: str, extra_caveats: Optional[List[str]] = None + ) -> str: extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") @@ -1020,16 +1042,9 @@ class MacaroonGenerator(object): macaroon.add_first_party_caveat(caveat) return macaroon.serialize() - def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): - """ - - Args: - user_id (unicode): - duration_in_ms (int): - - Returns: - unicode - """ + def generate_short_term_login_token( + self, user_id: str, duration_in_ms: int = (2 * 60 * 1000) + ) -> str: macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = login") now = self.hs.get_clock().time_msec() @@ -1037,12 +1052,12 @@ class MacaroonGenerator(object): macaroon.add_first_party_caveat("time < %d" % (expiry,)) return macaroon.serialize() - def generate_delete_pusher_token(self, user_id): + def generate_delete_pusher_token(self, user_id: str) -> str: macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = delete_pusher") return macaroon.serialize() - def _generate_base_macaroon(self, user_id): + def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon: macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0c84c6cec4..b743fc2dcc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -160,7 +160,7 @@ class MessageHandler(object): raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = yield filter_events_for_client( - self.storage, user_id, last_events, apply_retention_policies=False + self.storage, user_id, last_events, filter_send_to_client=False ) event = last_events[0] |