diff options
Diffstat (limited to 'synapse/handlers')
-rw-r--r-- | synapse/handlers/auth.py | 451 | ||||
-rw-r--r-- | synapse/handlers/cas_handler.py | 221 | ||||
-rw-r--r-- | synapse/handlers/device.py | 12 | ||||
-rw-r--r-- | synapse/handlers/directory.py | 6 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 25 | ||||
-rw-r--r-- | synapse/handlers/identity.py | 2 | ||||
-rw-r--r-- | synapse/handlers/message.py | 57 | ||||
-rw-r--r-- | synapse/handlers/password_policy.py | 93 | ||||
-rw-r--r-- | synapse/handlers/presence.py | 4 | ||||
-rw-r--r-- | synapse/handlers/profile.py | 16 | ||||
-rw-r--r-- | synapse/handlers/register.py | 28 | ||||
-rw-r--r-- | synapse/handlers/room.py | 8 | ||||
-rw-r--r-- | synapse/handlers/room_list.py | 23 | ||||
-rw-r--r-- | synapse/handlers/room_member.py | 3 | ||||
-rw-r--r-- | synapse/handlers/saml_handler.py | 55 | ||||
-rw-r--r-- | synapse/handlers/set_password.py | 15 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 6 | ||||
-rw-r--r-- | synapse/handlers/typing.py | 11 |
18 files changed, 800 insertions, 236 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7860f9625e..dbe165ce1e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,14 +18,12 @@ import logging import time import unicodedata import urllib.parse -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import attr import bcrypt # type: ignore[import] import pymacaroons -from twisted.internet import defer - import synapse.util.stringutils as stringutils from synapse.api.constants import LoginType from synapse.api.errors import ( @@ -91,6 +89,7 @@ class AuthHandler(BaseHandler): self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.password_enabled + self._sso_enabled = hs.config.saml2_enabled or hs.config.cas_enabled # we keep this as a list despite the O(N^2) implication so that we can # keep PASSWORD first and avoid confusing clients which pick the first @@ -106,6 +105,13 @@ class AuthHandler(BaseHandler): if t not in login_types: login_types.append(t) self._supported_login_types = login_types + # Login types and UI Auth types have a heavy overlap, but are not + # necessarily identical. Login types have SSO (and other login types) + # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET. + ui_auth_types = login_types.copy() + if self._sso_enabled: + ui_auth_types.append(LoginType.SSO) + self._supported_ui_auth_types = ui_auth_types # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. @@ -113,20 +119,43 @@ class AuthHandler(BaseHandler): self._clock = self.hs.get_clock() - # Load the SSO redirect confirmation page HTML template + # Load the SSO HTML templates. + + # The following template is shown to the user during a client login via SSO, + # after the SSO completes and before redirecting them back to their client. + # It notifies the user they are about to give access to their matrix account + # to the client. self._sso_redirect_confirm_template = load_jinja2_templates( hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"], )[0] + # The following template is shown during user interactive authentication + # in the fallback auth scenario. It notifies the user that they are + # authenticating for an operation to occur on their account. + self._sso_auth_confirm_template = load_jinja2_templates( + hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"], + )[0] + # The following template is shown after a successful user interactive + # authentication session. It tells the user they can close the window. + self._sso_auth_success_template = hs.config.sso_auth_success_template + # The following template is shown during the SSO authentication process if + # the account is deactivated. + self._sso_account_deactivated_template = ( + hs.config.sso_account_deactivated_template + ) self._server_name = hs.config.server_name # cast to tuple for use with str.startswith self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) - @defer.inlineCallbacks - def validate_user_via_ui_auth( - self, requester: Requester, request_body: Dict[str, Any], clientip: str - ): + async def validate_user_via_ui_auth( + self, + requester: Requester, + request: SynapseRequest, + request_body: Dict[str, Any], + clientip: str, + description: str, + ) -> dict: """ Checks that the user is who they claim to be, via a UI auth. @@ -137,12 +166,17 @@ class AuthHandler(BaseHandler): Args: requester: The user, as given by the access token + request: The request sent by the client. + request_body: The body of the request sent by the client clientip: The IP address of the client. + description: A human readable string to be displayed to the user that + describes the operation happening on their account. + Returns: - defer.Deferred[dict]: the parameters for this request (which may + The parameters for this request (which may have been given only in a previous call). Raises: @@ -169,10 +203,12 @@ class AuthHandler(BaseHandler): ) # build a list of supported flows - flows = [[login_type] for login_type in self._supported_login_types] + flows = [[login_type] for login_type in self._supported_ui_auth_types] try: - result, params, _ = yield self.check_auth(flows, request_body, clientip) + result, params, _ = await self.check_auth( + flows, request, request_body, clientip, description + ) except LoginError: # Update the ratelimite to say we failed (`can_do_action` doesn't raise). self._failed_uia_attempts_ratelimiter.can_do_action( @@ -185,7 +221,7 @@ class AuthHandler(BaseHandler): raise # find the completed login type - for login_type in self._supported_login_types: + for login_type in self._supported_ui_auth_types: if login_type not in result: continue @@ -209,18 +245,18 @@ class AuthHandler(BaseHandler): """ return self.checkers.keys() - @defer.inlineCallbacks - def check_auth( - self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str - ): + async def check_auth( + self, + flows: List[List[str]], + request: SynapseRequest, + clientdict: Dict[str, Any], + clientip: str, + description: str, + ) -> Tuple[dict, dict, str]: """ Takes a dictionary sent by the client in the login / registration protocol and handles the User-Interactive Auth flow. - As a side effect, this function fills in the 'creds' key on the user's - session with a map, which maps each auth-type (str) to the relevant - identity authenticated by that auth-type (mostly str, but for captcha, bool). - If no auth flows have been completed successfully, raises an InteractiveAuthIncompleteError. To handle this, you can use synapse.rest.client.v2_alpha._base.interactive_auth_handler as a @@ -231,14 +267,18 @@ class AuthHandler(BaseHandler): strings representing auth-types. At least one full flow must be completed in order for auth to be successful. + request: The request sent by the client. + clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. clientip: The IP address of the client. + description: A human readable string to be displayed to the user that + describes the operation happening on their account. + Returns: - defer.Deferred[dict, dict, str]: a deferred tuple of - (creds, params, session_id). + A tuple of (creds, params, session_id). 'creds' contains the authenticated credentials of each stage. @@ -260,30 +300,47 @@ class AuthHandler(BaseHandler): del clientdict["auth"] if "session" in authdict: sid = authdict["session"] - session = self._get_session_info(sid) - - if len(clientdict) > 0: - # This was designed to allow the client to omit the parameters - # and just supply the session in subsequent calls so it split - # auth between devices by just sharing the session, (eg. so you - # could continue registration from your phone having clicked the - # email auth link on there). It's probably too open to abuse - # because it lets unauthenticated clients store arbitrary objects - # on a homeserver. - # Revisit: Assumimg the REST APIs do sensible validation, the data - # isn't arbintrary. - session["clientdict"] = clientdict - self._save_session(session) - elif "clientdict" in session: - clientdict = session["clientdict"] + + # If there's no session ID, create a new session. + if not sid: + session = self._create_session( + clientdict, (request.uri, request.method, clientdict), description + ) + session_id = session["id"] + + else: + session = self._get_session_info(sid) + session_id = sid + + if not clientdict: + # This was designed to allow the client to omit the parameters + # and just supply the session in subsequent calls so it split + # auth between devices by just sharing the session, (eg. so you + # could continue registration from your phone having clicked the + # email auth link on there). It's probably too open to abuse + # because it lets unauthenticated clients store arbitrary objects + # on a homeserver. + # Revisit: Assuming the REST APIs do sensible validation, the data + # isn't arbitrary. + clientdict = session["clientdict"] + + # Ensure that the queried operation does not vary between stages of + # the UI authentication session. This is done by generating a stable + # comparator based on the URI, method, and body (minus the auth dict) + # and storing it during the initial query. Subsequent queries ensure + # that this comparator has not changed. + comparator = (request.uri, request.method, clientdict) + if session["ui_auth"] != comparator: + raise SynapseError( + 403, + "Requested operation has changed during the UI authentication session.", + ) if not authdict: raise InteractiveAuthIncompleteError( - self._auth_dict_for_flows(flows, session) + self._auth_dict_for_flows(flows, session_id) ) - if "creds" not in session: - session["creds"] = {} creds = session["creds"] # check auth type currently being presented @@ -291,7 +348,7 @@ class AuthHandler(BaseHandler): if "type" in authdict: login_type = authdict["type"] # type: str try: - result = yield self._check_auth_dict(authdict, clientip) + result = await self._check_auth_dict(authdict, clientip) if result: creds[login_type] = result self._save_session(session) @@ -322,15 +379,17 @@ class AuthHandler(BaseHandler): creds, list(clientdict), ) - return creds, clientdict, session["id"] - ret = self._auth_dict_for_flows(flows, session) + return creds, clientdict, session_id + + ret = self._auth_dict_for_flows(flows, session_id) ret["completed"] = list(creds) ret.update(errordict) raise InteractiveAuthIncompleteError(ret) - @defer.inlineCallbacks - def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str): + async def add_oob_auth( + self, stagetype: str, authdict: Dict[str, Any], clientip: str + ) -> bool: """ Adds the result of out-of-band authentication into an existing auth session. Currently used for adding the result of fallback auth. @@ -341,11 +400,9 @@ class AuthHandler(BaseHandler): raise LoginError(400, "", Codes.MISSING_PARAM) sess = self._get_session_info(authdict["session"]) - if "creds" not in sess: - sess["creds"] = {} creds = sess["creds"] - result = yield self.checkers[stagetype].check_auth(authdict, clientip) + result = await self.checkers[stagetype].check_auth(authdict, clientip) if result: creds[stagetype] = result self._save_session(sess) @@ -382,7 +439,7 @@ class AuthHandler(BaseHandler): value: The data to store """ sess = self._get_session_info(session_id) - sess.setdefault("serverdict", {})[key] = value + sess["serverdict"][key] = value self._save_session(sess) def get_session_data( @@ -397,10 +454,11 @@ class AuthHandler(BaseHandler): default: Value to return if the key has not been set """ sess = self._get_session_info(session_id) - return sess.setdefault("serverdict", {}).get(key, default) + return sess["serverdict"].get(key, default) - @defer.inlineCallbacks - def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str): + async def _check_auth_dict( + self, authdict: Dict[str, Any], clientip: str + ) -> Union[Dict[str, Any], str]: """Attempt to validate the auth dict provided by a client Args: @@ -408,7 +466,7 @@ class AuthHandler(BaseHandler): clientip: IP address of the client Returns: - Deferred: result of the stage verification. + Result of the stage verification. Raises: StoreError if there was a problem accessing the database @@ -418,7 +476,7 @@ class AuthHandler(BaseHandler): login_type = authdict["type"] checker = self.checkers.get(login_type) if checker is not None: - res = yield checker.check_auth(authdict, clientip=clientip) + res = await checker.check_auth(authdict, clientip=clientip) return res # build a v1-login-style dict out of the authdict and fall back to the @@ -428,7 +486,7 @@ class AuthHandler(BaseHandler): if user_id is None: raise SynapseError(400, "", Codes.MISSING_PARAM) - (canonical_id, callback) = yield self.validate_login(user_id, authdict) + (canonical_id, callback) = await self.validate_login(user_id, authdict) return canonical_id def _get_params_recaptcha(self) -> dict: @@ -452,7 +510,7 @@ class AuthHandler(BaseHandler): } def _auth_dict_for_flows( - self, flows: List[List[str]], session: Dict[str, Any] + self, flows: List[List[str]], session_id: str, ) -> Dict[str, Any]: public_flows = [] for f in flows: @@ -471,31 +529,73 @@ class AuthHandler(BaseHandler): params[stage] = get_params[stage]() return { - "session": session["id"], + "session": session_id, "flows": [{"stages": f} for f in public_flows], "params": params, } - def _get_session_info(self, session_id: Optional[str]) -> dict: + def _create_session( + self, + clientdict: Dict[str, Any], + ui_auth: Tuple[bytes, bytes, Dict[str, Any]], + description: str, + ) -> dict: """ - Gets or creates a session given a session ID. + Creates a new user interactive authentication session. The session can be used to track data across multiple requests, e.g. for interactive authentication. - """ - if session_id not in self.sessions: - session_id = None - if not session_id: - # create a new session - while session_id is None or session_id in self.sessions: - session_id = stringutils.random_string(24) - self.sessions[session_id] = {"id": session_id} + Each session has the following keys: + + id: + A unique identifier for this session. Passed back to the client + and returned for each stage. + clientdict: + The dictionary from the client root level, not the 'auth' key. + ui_auth: + A tuple which is checked at each stage of the authentication to + ensure that the asked for operation has not changed. + creds: + A map, which maps each auth-type (str) to the relevant identity + authenticated by that auth-type (mostly str, but for captcha, bool). + serverdict: + A map of data that is stored server-side and cannot be modified + by the client. + description: + A string description of the operation that the current + authentication is authorising. + Returns: + The newly created session. + """ + session_id = None + while session_id is None or session_id in self.sessions: + session_id = stringutils.random_string(24) + + self.sessions[session_id] = { + "id": session_id, + "clientdict": clientdict, + "ui_auth": ui_auth, + "creds": {}, + "serverdict": {}, + "description": description, + } return self.sessions[session_id] - @defer.inlineCallbacks - def get_access_token_for_user_id( + def _get_session_info(self, session_id: str) -> dict: + """ + Gets a session given a session ID. + + The session can be used to track data across multiple requests, e.g. for + interactive authentication. + """ + try: + return self.sessions[session_id] + except KeyError: + raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) + + async def get_access_token_for_user_id( self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] ): """ @@ -525,10 +625,10 @@ class AuthHandler(BaseHandler): ) logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry) - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) access_token = self.macaroon_gen.generate_access_token(user_id) - yield self.store.add_access_token_to_user( + await self.store.add_access_token_to_user( user_id, access_token, device_id, valid_until_ms ) @@ -538,15 +638,14 @@ class AuthHandler(BaseHandler): # device, so we double-check it here. if device_id is not None: try: - yield self.store.get_device(user_id, device_id) + await self.store.get_device(user_id, device_id) except StoreError: - yield self.store.delete_access_token(access_token) + await self.store.delete_access_token(access_token) raise StoreError(400, "Login raced against device deletion") return access_token - @defer.inlineCallbacks - def check_user_exists(self, user_id: str): + async def check_user_exists(self, user_id: str) -> Optional[str]: """ Checks to see if a user with the given id exists. Will check case insensitively, but return None if there are multiple inexact matches. @@ -555,28 +654,25 @@ class AuthHandler(BaseHandler): user_id: complete @user:id Returns: - defer.Deferred: (unicode) canonical_user_id, or None if zero or - multiple matches - - Raises: - UserDeactivatedError if a user is found but is deactivated. + The canonical_user_id, or None if zero or multiple matches """ - res = yield self._find_user_id_and_pwd_hash(user_id) + res = await self._find_user_id_and_pwd_hash(user_id) if res is not None: return res[0] return None - @defer.inlineCallbacks - def _find_user_id_and_pwd_hash(self, user_id: str): + async def _find_user_id_and_pwd_hash( + self, user_id: str + ) -> Optional[Tuple[str, str]]: """Checks to see if a user with the given id exists. Will check case insensitively, but will return None if there are multiple inexact matches. Returns: - tuple: A 2-tuple of `(canonical_user_id, password_hash)` - None: if there is not exactly one match + A 2-tuple of `(canonical_user_id, password_hash)` or `None` + if there is not exactly one match """ - user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) + user_infos = await self.store.get_users_by_id_case_insensitive(user_id) result = None if not user_infos: @@ -609,8 +705,9 @@ class AuthHandler(BaseHandler): """ return self._supported_login_types - @defer.inlineCallbacks - def validate_login(self, username: str, login_submission: Dict[str, Any]): + async def validate_login( + self, username: str, login_submission: Dict[str, Any] + ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: """Authenticates the user for the /login API Also used by the user-interactive auth flow to validate @@ -621,7 +718,7 @@ class AuthHandler(BaseHandler): login_submission: the whole of the login submission (including 'type' and other relevant fields) Returns: - Deferred[str, func]: canonical user id, and optional callback + A tuple of the canonical user id, and optional callback to be called once the access token and device id are issued Raises: StoreError if there was a problem accessing the database @@ -650,7 +747,7 @@ class AuthHandler(BaseHandler): for provider in self.password_providers: if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: known_login_type = True - is_valid = yield provider.check_password(qualified_user_id, password) + is_valid = await provider.check_password(qualified_user_id, password) if is_valid: return qualified_user_id, None @@ -682,7 +779,7 @@ class AuthHandler(BaseHandler): % (login_type, missing_fields), ) - result = yield provider.check_auth(username, login_type, login_dict) + result = await provider.check_auth(username, login_type, login_dict) if result: if isinstance(result, str): result = (result, None) @@ -691,8 +788,8 @@ class AuthHandler(BaseHandler): if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: known_login_type = True - canonical_user_id = yield self._check_local_password( - qualified_user_id, password + canonical_user_id = await self._check_local_password( + qualified_user_id, password # type: ignore ) if canonical_user_id: @@ -705,8 +802,9 @@ class AuthHandler(BaseHandler): # login, it turns all LoginErrors into a 401 anyway. raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) - @defer.inlineCallbacks - def check_password_provider_3pid(self, medium: str, address: str, password: str): + async def check_password_provider_3pid( + self, medium: str, address: str, password: str + ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]: """Check if a password provider is able to validate a thirdparty login Args: @@ -715,9 +813,8 @@ class AuthHandler(BaseHandler): password: The password of the user. Returns: - Deferred[(str|None, func|None)]: A tuple of `(user_id, - callback)`. If authentication is successful, `user_id` is a `str` - containing the authenticated, canonical user ID. `callback` is + A tuple of `(user_id, callback)`. If authentication is successful, + `user_id`is the authenticated, canonical user ID. `callback` is then either a function to be later run after the server has completed login/registration, or `None`. If authentication was unsuccessful, `user_id` and `callback` are both `None`. @@ -729,7 +826,7 @@ class AuthHandler(BaseHandler): # success, to a str (which is the user_id) or a tuple of # (user_id, callback_func), where callback_func should be run # after we've finished everything else - result = yield provider.check_3pid_auth(medium, address, password) + result = await provider.check_3pid_auth(medium, address, password) if result: # Check if the return value is a str or a tuple if isinstance(result, str): @@ -739,8 +836,7 @@ class AuthHandler(BaseHandler): return None, None - @defer.inlineCallbacks - def _check_local_password(self, user_id: str, password: str): + async def _check_local_password(self, user_id: str, password: str) -> Optional[str]: """Authenticate a user against the local password database. user_id is checked case insensitively, but will return None if there are @@ -750,28 +846,26 @@ class AuthHandler(BaseHandler): user_id: complete @user:id password: the provided password Returns: - Deferred[unicode] the canonical_user_id, or Deferred[None] if - unknown user/bad password + The canonical_user_id, or None if unknown user/bad password """ - lookupres = yield self._find_user_id_and_pwd_hash(user_id) + lookupres = await self._find_user_id_and_pwd_hash(user_id) if not lookupres: return None (user_id, password_hash) = lookupres # If the password hash is None, the account has likely been deactivated if not password_hash: - deactivated = yield self.store.get_user_deactivated_status(user_id) + deactivated = await self.store.get_user_deactivated_status(user_id) if deactivated: raise UserDeactivatedError("This account has been deactivated") - result = yield self.validate_hash(password, password_hash) + result = await self.validate_hash(password, password_hash) if not result: logger.warning("Failed password login for user %s", user_id) return None return user_id - @defer.inlineCallbacks - def validate_short_term_login_token_and_get_user_id(self, login_token: str): + async def validate_short_term_login_token_and_get_user_id(self, login_token: str): auth_api = self.hs.get_auth() user_id = None try: @@ -781,26 +875,23 @@ class AuthHandler(BaseHandler): except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) return user_id - @defer.inlineCallbacks - def delete_access_token(self, access_token: str): + async def delete_access_token(self, access_token: str): """Invalidate a single access token Args: access_token: access token to be deleted - Returns: - Deferred """ - user_info = yield self.auth.get_user_by_access_token(access_token) - yield self.store.delete_access_token(access_token) + user_info = await self.auth.get_user_by_access_token(access_token) + await self.store.delete_access_token(access_token) # see if any of our auth providers want to know about this for provider in self.password_providers: if hasattr(provider, "on_logged_out"): - yield provider.on_logged_out( + await provider.on_logged_out( user_id=str(user_info["user"]), device_id=user_info["device_id"], access_token=access_token, @@ -808,12 +899,11 @@ class AuthHandler(BaseHandler): # delete pushers associated with this access token if user_info["token_id"] is not None: - yield self.hs.get_pusherpool().remove_pushers_by_access_token( + await self.hs.get_pusherpool().remove_pushers_by_access_token( str(user_info["user"]), (user_info["token_id"],) ) - @defer.inlineCallbacks - def delete_access_tokens_for_user( + async def delete_access_tokens_for_user( self, user_id: str, except_token_id: Optional[str] = None, @@ -827,10 +917,8 @@ class AuthHandler(BaseHandler): device_id: ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted - Returns: - Deferred """ - tokens_and_devices = yield self.store.user_delete_access_tokens( + tokens_and_devices = await self.store.user_delete_access_tokens( user_id, except_token_id=except_token_id, device_id=device_id ) @@ -838,17 +926,18 @@ class AuthHandler(BaseHandler): for provider in self.password_providers: if hasattr(provider, "on_logged_out"): for token, token_id, device_id in tokens_and_devices: - yield provider.on_logged_out( + await provider.on_logged_out( user_id=user_id, device_id=device_id, access_token=token ) # delete pushers associated with the access tokens - yield self.hs.get_pusherpool().remove_pushers_by_access_token( + await self.hs.get_pusherpool().remove_pushers_by_access_token( user_id, (token_id for _, token_id, _ in tokens_and_devices) ) - @defer.inlineCallbacks - def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int): + async def add_threepid( + self, user_id: str, medium: str, address: str, validated_at: int + ): # check if medium has a valid value if medium not in ["email", "msisdn"]: raise SynapseError( @@ -869,14 +958,13 @@ class AuthHandler(BaseHandler): if medium == "email": address = address.lower() - yield self.store.user_add_threepid( + await self.store.user_add_threepid( user_id, medium, address, validated_at, self.hs.get_clock().time_msec() ) - @defer.inlineCallbacks - def delete_threepid( + async def delete_threepid( self, user_id: str, medium: str, address: str, id_server: Optional[str] = None - ): + ) -> bool: """Attempts to unbind the 3pid on the identity servers and deletes it from the local database. @@ -889,7 +977,7 @@ class AuthHandler(BaseHandler): identity server specified when binding (if known). Returns: - Deferred[bool]: Returns True if successfully unbound the 3pid on + Returns True if successfully unbound the 3pid on the identity server, False if identity server doesn't support the unbind API. """ @@ -899,11 +987,11 @@ class AuthHandler(BaseHandler): address = address.lower() identity_handler = self.hs.get_handlers().identity_handler - result = yield identity_handler.try_unbind_threepid( + result = await identity_handler.try_unbind_threepid( user_id, {"medium": medium, "address": address, "id_server": id_server} ) - yield self.store.user_delete_threepid(user_id, medium, address) + await self.store.user_delete_threepid(user_id, medium, address) return result def _save_session(self, session: Dict[str, Any]) -> None: @@ -913,14 +1001,14 @@ class AuthHandler(BaseHandler): session["last_used"] = self.hs.get_clock().time_msec() self.sessions[session["id"]] = session - def hash(self, password: str): + async def hash(self, password: str) -> str: """Computes a secure hash of password. Args: password: Password to hash. Returns: - Deferred(unicode): Hashed password. + Hashed password. """ def _do_hash(): @@ -932,9 +1020,11 @@ class AuthHandler(BaseHandler): bcrypt.gensalt(self.bcrypt_rounds), ).decode("ascii") - return defer_to_thread(self.hs.get_reactor(), _do_hash) + return await defer_to_thread(self.hs.get_reactor(), _do_hash) - def validate_hash(self, password: str, stored_hash: bytes): + async def validate_hash( + self, password: str, stored_hash: Union[bytes, str] + ) -> bool: """Validates that self.hash(password) == stored_hash. Args: @@ -942,7 +1032,7 @@ class AuthHandler(BaseHandler): stored_hash: Expected hash value. Returns: - Deferred(bool): Whether self.hash(password) == stored_hash. + Whether self.hash(password) == stored_hash. """ def _do_validate_hash(): @@ -958,11 +1048,56 @@ class AuthHandler(BaseHandler): if not isinstance(stored_hash, bytes): stored_hash = stored_hash.encode("ascii") - return defer_to_thread(self.hs.get_reactor(), _do_validate_hash) + return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash) else: - return defer.succeed(False) + return False - def complete_sso_login( + def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: + """ + Get the HTML for the SSO redirect confirmation page. + + Args: + redirect_url: The URL to redirect to the SSO provider. + session_id: The user interactive authentication session ID. + + Returns: + The HTML to render. + """ + session = self._get_session_info(session_id) + return self._sso_auth_confirm_template.render( + description=session["description"], redirect_url=redirect_url, + ) + + def complete_sso_ui_auth( + self, registered_user_id: str, session_id: str, request: SynapseRequest, + ): + """Having figured out a mxid for this user, complete the HTTP request + + Args: + registered_user_id: The registered user ID to complete SSO login for. + request: The request to complete. + client_redirect_url: The URL to which to redirect the user at the end of the + process. + """ + # Mark the stage of the authentication as successful. + sess = self._get_session_info(session_id) + creds = sess["creds"] + + # Save the user who authenticated with SSO, this will be used to ensure + # that the account be modified is also the person who logged in. + creds[LoginType.SSO] = registered_user_id + self._save_session(sess) + + # Render the HTML and return. + html_bytes = self._sso_auth_success_template.encode("utf-8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + + async def complete_sso_login( self, registered_user_id: str, request: SynapseRequest, @@ -976,6 +1111,32 @@ class AuthHandler(BaseHandler): client_redirect_url: The URL to which to redirect the user at the end of the process. """ + # If the account has been deactivated, do not proceed with the login + # flow. + deactivated = await self.store.get_user_deactivated_status(registered_user_id) + if deactivated: + html_bytes = self._sso_account_deactivated_template.encode("utf-8") + + request.setResponseCode(403) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + request.write(html_bytes) + finish_request(request) + return + + self._complete_sso_login(registered_user_id, request, client_redirect_url) + + def _complete_sso_login( + self, + registered_user_id: str, + request: SynapseRequest, + client_redirect_url: str, + ): + """ + The synchronous portion of complete_sso_login. + + This exists purely for backwards compatibility of synapse.module_api.ModuleApi. + """ # Create a login token login_token = self.macaroon_gen.generate_short_term_login_token( registered_user_id @@ -1001,7 +1162,7 @@ class AuthHandler(BaseHandler): # URL we redirect users to. redirect_url_no_params = client_redirect_url.split("?")[0] - html = self._sso_redirect_confirm_template.render( + html_bytes = self._sso_redirect_confirm_template.render( display_url=redirect_url_no_params, redirect_url=redirect_url, server_name=self._server_name, @@ -1009,8 +1170,8 @@ class AuthHandler(BaseHandler): request.setResponseCode(200) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html),)) - request.write(html) + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + request.write(html_bytes) finish_request(request) @staticmethod diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py new file mode 100644 index 0000000000..5cb3f9d133 --- /dev/null +++ b/synapse/handlers/cas_handler.py @@ -0,0 +1,221 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import xml.etree.ElementTree as ET +from typing import Dict, Optional, Tuple + +from six.moves import urllib + +from twisted.web.client import PartialDownloadError + +from synapse.api.errors import Codes, LoginError +from synapse.http.site import SynapseRequest +from synapse.types import UserID, map_username_to_mxid_localpart + +logger = logging.getLogger(__name__) + + +class CasHandler: + """ + Utility class for to handle the response from a CAS SSO service. + + Args: + hs (synapse.server.HomeServer) + """ + + def __init__(self, hs): + self._hostname = hs.hostname + self._auth_handler = hs.get_auth_handler() + self._registration_handler = hs.get_registration_handler() + + self._cas_server_url = hs.config.cas_server_url + self._cas_service_url = hs.config.cas_service_url + self._cas_displayname_attribute = hs.config.cas_displayname_attribute + self._cas_required_attributes = hs.config.cas_required_attributes + + self._http_client = hs.get_proxied_http_client() + + def _build_service_param(self, args: Dict[str, str]) -> str: + """ + Generates a value to use as the "service" parameter when redirecting or + querying the CAS service. + + Args: + args: Additional arguments to include in the final redirect URL. + + Returns: + The URL to use as a "service" parameter. + """ + return "%s%s?%s" % ( + self._cas_service_url, + "/_matrix/client/r0/login/cas/ticket", + urllib.parse.urlencode(args), + ) + + async def _validate_ticket( + self, ticket: str, service_args: Dict[str, str] + ) -> Tuple[str, Optional[str]]: + """ + Validate a CAS ticket with the server, parse the response, and return the user and display name. + + Args: + ticket: The CAS ticket from the client. + service_args: Additional arguments to include in the service URL. + Should be the same as those passed to `get_redirect_url`. + """ + uri = self._cas_server_url + "/proxyValidate" + args = { + "ticket": ticket, + "service": self._build_service_param(service_args), + } + try: + body = await self._http_client.get_raw(uri, args) + except PartialDownloadError as pde: + # Twisted raises this error if the connection is closed, + # even if that's being used old-http style to signal end-of-data + body = pde.response + + user, attributes = self._parse_cas_response(body) + displayname = attributes.pop(self._cas_displayname_attribute, None) + + for required_attribute, required_value in self._cas_required_attributes.items(): + # If required attribute was not in CAS Response - Forbidden + if required_attribute not in attributes: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + # Also need to check value + if required_value is not None: + actual_value = attributes[required_attribute] + # If required attribute value does not match expected - Forbidden + if required_value != actual_value: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + return user, displayname + + def _parse_cas_response( + self, cas_response_body: str + ) -> Tuple[str, Dict[str, Optional[str]]]: + """ + Retrieve the user and other parameters from the CAS response. + + Args: + cas_response_body: The response from the CAS query. + + Returns: + A tuple of the user and a mapping of other attributes. + """ + user = None + attributes = {} + try: + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise Exception("root of CAS response is not serviceResponse") + success = root[0].tag.endswith("authenticationSuccess") + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + if child.tag.endswith("attributes"): + for attribute in child: + # ElementTree library expands the namespace in + # attribute tags to the full URL of the namespace. + # We don't care about namespace here and it will always + # be encased in curly braces, so we remove them. + tag = attribute.tag + if "}" in tag: + tag = tag.split("}")[1] + attributes[tag] = attribute.text + if user is None: + raise Exception("CAS response does not contain user") + except Exception: + logger.exception("Error parsing CAS response") + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + if not success: + raise LoginError( + 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED + ) + return user, attributes + + def get_redirect_url(self, service_args: Dict[str, str]) -> str: + """ + Generates a URL for the CAS server where the client should be redirected. + + Args: + service_args: Additional arguments to include in the final redirect URL. + + Returns: + The URL to redirect the client to. + """ + args = urllib.parse.urlencode( + {"service": self._build_service_param(service_args)} + ) + + return "%s/login?%s" % (self._cas_server_url, args) + + async def handle_ticket( + self, + request: SynapseRequest, + ticket: str, + client_redirect_url: Optional[str], + session: Optional[str], + ) -> None: + """ + Called once the user has successfully authenticated with the SSO. + Validates a CAS ticket sent by the client and completes the auth process. + + If the user interactive authentication session is provided, marks the + UI Auth session as complete, then returns an HTML page notifying the + user they are done. + + Otherwise, this registers the user if necessary, and then returns a + redirect (with a login token) to the client. + + Args: + request: the incoming request from the browser. We'll + respond to it with a redirect or an HTML page. + + ticket: The CAS ticket provided by the client. + + client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given. + This should be the same as the redirectUrl from the original `/login/sso/redirect` request. + + session: The session parameter from the `/cas/ticket` HTTP request, if given. + This should be the UI Auth session id. + """ + args = {} + if client_redirect_url: + args["redirectUrl"] = client_redirect_url + if session: + args["session"] = session + username, user_display_name = await self._validate_ticket(ticket, args) + + localpart = map_username_to_mxid_localpart(username) + user_id = UserID(localpart, self._hostname).to_string() + registered_user_id = await self._auth_handler.check_user_exists(user_id) + + if session: + self._auth_handler.complete_sso_ui_auth( + registered_user_id, session, request, + ) + + else: + if not registered_user_id: + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, default_display_name=user_display_name + ) + + await self._auth_handler.complete_sso_login( + registered_user_id, request, client_redirect_url + ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 993499f446..9bd941b5a0 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -338,8 +338,10 @@ class DeviceHandler(DeviceWorkerHandler): else: raise - yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id + yield defer.ensureDeferred( + self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id + ) ) yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) @@ -391,8 +393,10 @@ class DeviceHandler(DeviceWorkerHandler): # Delete access tokens and e2e keys for each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: - yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id + yield defer.ensureDeferred( + self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id + ) ) yield self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 1d842c369b..53e5f585d9 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -127,7 +127,11 @@ class DirectoryHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) else: - if self.require_membership and check_membership: + # Server admins are not subject to the same constraints as normal + # users when creating an alias (e.g. being in the room). + is_admin = yield self.auth.is_server_admin(requester.user) + + if (self.require_membership and check_membership) and not is_admin: rooms_for_user = yield self.store.get_rooms_for_user(user_id) if room_id not in rooms_for_user: raise AuthError( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 38ab6a8fc3..c7aa7acf3b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -49,6 +49,7 @@ from synapse.event_auth import auth_types_for_event from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator +from synapse.handlers._base import BaseHandler from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, @@ -69,10 +70,9 @@ from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import shortstr from synapse.visibility import filter_events_for_server -from ._base import BaseHandler - logger = logging.getLogger(__name__) @@ -93,27 +93,6 @@ class _NewEventInfo: auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None) -def shortstr(iterable, maxitems=5): - """If iterable has maxitems or fewer, return the stringification of a list - containing those items. - - Otherwise, return the stringification of a a list with the first maxitems items, - followed by "...". - - Args: - iterable (Iterable): iterable to truncate - maxitems (int): number of items to return before truncating - - Returns: - unicode - """ - - items = list(itertools.islice(iterable, maxitems + 1)) - if len(items) <= maxitems: - return str(items) - return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" - - class FederationHandler(BaseHandler): """Handles events that originated from federation. Responsible for: diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 23f07832e7..0f0e632b62 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -18,7 +18,7 @@ """Utilities for interacting with Identity Servers""" import logging -import urllib +import urllib.parse from canonicaljson import json from signedjson.key import decode_verify_key_bytes diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index b743fc2dcc..522271eed1 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -852,6 +852,38 @@ class EventCreationHandler(object): ) @defer.inlineCallbacks + def _validate_canonical_alias( + self, directory_handler, room_alias_str, expected_room_id + ): + """ + Ensure that the given room alias points to the expected room ID. + + Args: + directory_handler: The directory handler object. + room_alias_str: The room alias to check. + expected_room_id: The room ID that the alias should point to. + """ + room_alias = RoomAlias.from_string(room_alias_str) + try: + mapping = yield directory_handler.get_association(room_alias) + except SynapseError as e: + # Turn M_NOT_FOUND errors into M_BAD_ALIAS errors. + if e.errcode == Codes.NOT_FOUND: + raise SynapseError( + 400, + "Room alias %s does not point to the room" % (room_alias_str,), + Codes.BAD_ALIAS, + ) + raise + + if mapping["room_id"] != expected_room_id: + raise SynapseError( + 400, + "Room alias %s does not point to the room" % (room_alias_str,), + Codes.BAD_ALIAS, + ) + + @defer.inlineCallbacks def persist_and_notify_client_event( self, requester, event, context, ratelimit=True, extra_users=[] ): @@ -905,15 +937,9 @@ class EventCreationHandler(object): room_alias_str = event.content.get("alias", None) directory_handler = self.hs.get_handlers().directory_handler if room_alias_str and room_alias_str != original_alias: - room_alias = RoomAlias.from_string(room_alias_str) - mapping = yield directory_handler.get_association(room_alias) - - if mapping["room_id"] != event.room_id: - raise SynapseError( - 400, - "Room alias %s does not point to the room" % (room_alias_str,), - Codes.BAD_ALIAS, - ) + yield self._validate_canonical_alias( + directory_handler, room_alias_str, event.room_id + ) # Check that alt_aliases is the proper form. alt_aliases = event.content.get("alt_aliases", []) @@ -931,16 +957,9 @@ class EventCreationHandler(object): new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) if new_alt_aliases: for alias_str in new_alt_aliases: - room_alias = RoomAlias.from_string(alias_str) - mapping = yield directory_handler.get_association(room_alias) - - if mapping["room_id"] != event.room_id: - raise SynapseError( - 400, - "Room alias %s does not point to the room" - % (room_alias_str,), - Codes.BAD_ALIAS, - ) + yield self._validate_canonical_alias( + directory_handler, alias_str, event.room_id + ) federation_handler = self.hs.get_handlers().federation_handler diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py new file mode 100644 index 0000000000..d06b110269 --- /dev/null +++ b/synapse/handlers/password_policy.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from synapse.api.errors import Codes, PasswordRefusedError + +logger = logging.getLogger(__name__) + + +class PasswordPolicyHandler(object): + def __init__(self, hs): + self.policy = hs.config.password_policy + self.enabled = hs.config.password_policy_enabled + + # Regexps for the spec'd policy parameters. + self.regexp_digit = re.compile("[0-9]") + self.regexp_symbol = re.compile("[^a-zA-Z0-9]") + self.regexp_uppercase = re.compile("[A-Z]") + self.regexp_lowercase = re.compile("[a-z]") + + def validate_password(self, password): + """Checks whether a given password complies with the server's policy. + + Args: + password (str): The password to check against the server's policy. + + Raises: + PasswordRefusedError: The password doesn't comply with the server's policy. + """ + + if not self.enabled: + return + + minimum_accepted_length = self.policy.get("minimum_length", 0) + if len(password) < minimum_accepted_length: + raise PasswordRefusedError( + msg=( + "The password must be at least %d characters long" + % minimum_accepted_length + ), + errcode=Codes.PASSWORD_TOO_SHORT, + ) + + if ( + self.policy.get("require_digit", False) + and self.regexp_digit.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one digit", + errcode=Codes.PASSWORD_NO_DIGIT, + ) + + if ( + self.policy.get("require_symbol", False) + and self.regexp_symbol.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one symbol", + errcode=Codes.PASSWORD_NO_SYMBOL, + ) + + if ( + self.policy.get("require_uppercase", False) + and self.regexp_uppercase.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one uppercase letter", + errcode=Codes.PASSWORD_NO_UPPERCASE, + ) + + if ( + self.policy.get("require_lowercase", False) + and self.regexp_lowercase.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one lowercase letter", + errcode=Codes.PASSWORD_NO_LOWERCASE, + ) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 5526015ddb..6912165622 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -747,7 +747,7 @@ class PresenceHandler(object): return False - async def get_all_presence_updates(self, last_id, current_id): + async def get_all_presence_updates(self, last_id, current_id, limit): """ Gets a list of presence update rows from between the given stream ids. Each row has: @@ -762,7 +762,7 @@ class PresenceHandler(object): """ # TODO(markjh): replicate the unpersisted changes. # This could use the in-memory stores for recent changes. - rows = await self.store.get_all_presence_updates(last_id, current_id) + rows = await self.store.get_all_presence_updates(last_id, current_id, limit) return rows def notify_new_event(self): diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 50ce0c585b..6aa1c0f5e0 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -157,6 +157,15 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") + if not by_admin and not self.hs.config.enable_set_displayname: + profile = yield self.store.get_profileinfo(target_user.localpart) + if profile.display_name: + raise SynapseError( + 400, + "Changing display name is disabled on this server", + Codes.FORBIDDEN, + ) + if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) @@ -218,6 +227,13 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") + if not by_admin and not self.hs.config.enable_set_avatar_url: + profile = yield self.store.get_profileinfo(target_user.localpart) + if profile.avatar_url: + raise SynapseError( + 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN + ) + if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 7ffc194f0c..3a65b46ecd 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -166,7 +166,9 @@ class RegistrationHandler(BaseHandler): yield self.auth.check_auth_blocking(threepid=threepid) password_hash = None if password: - password_hash = yield self._auth_handler.hash(password) + password_hash = yield defer.ensureDeferred( + self._auth_handler.hash(password) + ) if localpart is not None: yield self.check_username(localpart, guest_access_token=guest_access_token) @@ -540,8 +542,10 @@ class RegistrationHandler(BaseHandler): user_id, ["guest = true"] ) else: - access_token = yield self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, valid_until_ms=valid_until_ms + access_token = yield defer.ensureDeferred( + self._auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, valid_until_ms=valid_until_ms + ) ) return (device_id, access_token) @@ -617,8 +621,13 @@ class RegistrationHandler(BaseHandler): logger.info("Can't add incomplete 3pid") return - yield self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"] + yield defer.ensureDeferred( + self._auth_handler.add_threepid( + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], + ) ) # And we add an email pusher for them by default, but only @@ -670,6 +679,11 @@ class RegistrationHandler(BaseHandler): return None raise - yield self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"] + yield defer.ensureDeferred( + self._auth_handler.add_threepid( + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], + ) ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f580ab2e9f..3d10e4b2d9 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -645,6 +645,13 @@ class RoomCreationHandler(BaseHandler): check_membership=False, ) + if is_public: + if not self.config.is_publishing_room_allowed(user_id, room_id, room_alias): + # Lets just return a generic message, as there may be all sorts of + # reasons why we said no. TODO: Allow configurable error messages + # per alias creation rule? + raise SynapseError(403, "Not allowed to publish room") + preset_config = config.get( "preset", RoomCreationPreset.PRIVATE_CHAT @@ -806,6 +813,7 @@ class RoomCreationHandler(BaseHandler): EventTypes.RoomAvatar: 50, EventTypes.Tombstone: 100, EventTypes.ServerACL: 100, + EventTypes.RoomEncryption: 100, }, "events_default": 0, "state_default": 50, diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 0b7d3da680..59c9906b31 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from typing import Any, Dict, Optional from six import iteritems @@ -105,22 +106,22 @@ class RoomListHandler(BaseHandler): @defer.inlineCallbacks def _get_public_room_list( self, - limit=None, - since_token=None, - search_filter=None, - network_tuple=EMPTY_THIRD_PARTY_ID, - from_federation=False, - ): + limit: Optional[int] = None, + since_token: Optional[str] = None, + search_filter: Optional[Dict] = None, + network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID, + from_federation: bool = False, + ) -> Dict[str, Any]: """Generate a public room list. Args: - limit (int|None): Maximum amount of rooms to return. - since_token (str|None) - search_filter (dict|None): Dictionary to filter rooms by. - network_tuple (ThirdPartyInstanceID): Which public list to use. + limit: Maximum amount of rooms to return. + since_token: + search_filter: Dictionary to filter rooms by. + network_tuple: Which public list to use. This can be (None, None) to indicate the main list, or a particular appservice and network id to use an appservice specific one. Setting to None returns all public rooms across all lists. - from_federation (bool): Whether this request originated from a + from_federation: Whether this request originated from a federating server or a client. Used for room filtering. """ diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 4260426369..c3ee8db4f0 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -519,6 +519,9 @@ class RoomMemberHandler(object): yield self.store.set_room_is_public(old_room_id, False) yield self.store.set_room_is_public(room_id, True) + # Transfer alias mappings in the room directory + yield self.store.update_aliases_for_room(old_room_id, room_id) + # Check if any groups we own contain the predecessor room local_group_ids = yield self.store.get_local_groups_for_room(old_room_id) for group_id in local_group_ids: diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 72c109981b..7c9454b504 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import re -from typing import Tuple +from typing import Optional, Tuple import attr import saml2 @@ -26,6 +26,7 @@ from synapse.config import ConfigError from synapse.http.server import finish_request from synapse.http.servlet import parse_string from synapse.module_api import ModuleApi +from synapse.module_api.errors import RedirectException from synapse.types import ( UserID, map_username_to_mxid_localpart, @@ -43,11 +44,15 @@ class Saml2SessionData: # time the session was created, in milliseconds creation_time = attr.ib() + # The user interactive authentication session ID associated with this SAML + # session (or None if this SAML session is for an initial login). + ui_auth_session_id = attr.ib(type=Optional[str], default=None) class SamlHandler: def __init__(self, hs): self._saml_client = Saml2Client(hs.config.saml2_sp_config) + self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() @@ -76,12 +81,14 @@ class SamlHandler: self._error_html_content = hs.config.saml2_error_html_content - def handle_redirect_request(self, client_redirect_url): + def handle_redirect_request(self, client_redirect_url, ui_auth_session_id=None): """Handle an incoming request to /login/sso/redirect Args: client_redirect_url (bytes): the URL that we should redirect the client to when everything is done + ui_auth_session_id (Optional[str]): The session ID of the ongoing UI Auth (or + None if this is a login). Returns: bytes: URL to redirect to @@ -91,7 +98,9 @@ class SamlHandler: ) now = self._clock.time_msec() - self._outstanding_requests_dict[reqid] = Saml2SessionData(creation_time=now) + self._outstanding_requests_dict[reqid] = Saml2SessionData( + creation_time=now, ui_auth_session_id=ui_auth_session_id, + ) for key, value in info["headers"]: if key == "Location": @@ -118,7 +127,12 @@ class SamlHandler: self.expire_sessions() try: - user_id = await self._map_saml_response_to_user(resp_bytes, relay_state) + user_id, current_session = await self._map_saml_response_to_user( + resp_bytes, relay_state + ) + except RedirectException: + # Raise the exception as per the wishes of the SAML module response + raise except Exception as e: # If decoding the response or mapping it to a user failed, then log the # error and tell the user that something went wrong. @@ -133,9 +147,28 @@ class SamlHandler: finish_request(request) return - self._auth_handler.complete_sso_login(user_id, request, relay_state) + # Complete the interactive auth session or the login. + if current_session and current_session.ui_auth_session_id: + self._auth_handler.complete_sso_ui_auth( + user_id, current_session.ui_auth_session_id, request + ) + + else: + await self._auth_handler.complete_sso_login(user_id, request, relay_state) + + async def _map_saml_response_to_user( + self, resp_bytes: str, client_redirect_url: str + ) -> Tuple[str, Optional[Saml2SessionData]]: + """ + Given a sample response, retrieve the cached session and user for it. - async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url): + Args: + resp_bytes: The SAML response. + client_redirect_url: The redirect URL passed in by the client. + + Returns: + Tuple of the user ID and SAML session associated with this response. + """ try: saml2_auth = self._saml_client.parse_authn_request_response( resp_bytes, @@ -163,7 +196,9 @@ class SamlHandler: logger.info("SAML2 mapped attributes: %s", saml2_auth.ava) - self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) + current_session = self._outstanding_requests_dict.pop( + saml2_auth.in_response_to, None + ) remote_user_id = self._user_mapping_provider.get_remote_user_id( saml2_auth, client_redirect_url @@ -184,7 +219,7 @@ class SamlHandler: ) if registered_user_id is not None: logger.info("Found existing mapping %s", registered_user_id) - return registered_user_id + return registered_user_id, current_session # backwards-compatibility hack: see if there is an existing user with a # suitable mapping from the uid @@ -209,7 +244,7 @@ class SamlHandler: await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) - return registered_user_id + return registered_user_id, current_session # Map saml response to user attributes using the configured mapping provider for i in range(1000): @@ -256,7 +291,7 @@ class SamlHandler: await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) - return registered_user_id + return registered_user_id, current_session def expire_sessions(self): expire_before = self._clock.time_msec() - self._saml2_session_lifetime diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 12657ca698..63d8f9aa0d 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -15,8 +15,6 @@ import logging from typing import Optional -from twisted.internet import defer - from synapse.api.errors import Codes, StoreError, SynapseError from synapse.types import Requester @@ -32,9 +30,9 @@ class SetPasswordHandler(BaseHandler): super(SetPasswordHandler, self).__init__(hs) self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() + self._password_policy_handler = hs.get_password_policy_handler() - @defer.inlineCallbacks - def set_password( + async def set_password( self, user_id: str, new_password: str, @@ -44,10 +42,11 @@ class SetPasswordHandler(BaseHandler): if not self.hs.config.password_localdb_enabled: raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) - password_hash = yield self._auth_handler.hash(new_password) + self._password_policy_handler.validate_password(new_password) + password_hash = await self._auth_handler.hash(new_password) try: - yield self.store.user_set_password_hash(user_id, password_hash) + await self.store.user_set_password_hash(user_id, password_hash) except StoreError as e: if e.code == 404: raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) @@ -59,12 +58,12 @@ class SetPasswordHandler(BaseHandler): except_access_token_id = requester.access_token_id if requester else None # First delete all of their other devices. - yield self._device_handler.delete_all_devices_for_user( + await self._device_handler.delete_all_devices_for_user( user_id, except_device_id=except_device_id ) # and now delete any access tokens which weren't associated with # devices (or were associated with this device). - yield self._auth_handler.delete_access_tokens_for_user( + await self._auth_handler.delete_access_tokens_for_user( user_id, except_token_id=except_access_token_id ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index cfd5dfc9e5..4f76b7a743 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -26,7 +26,7 @@ from prometheus_client import Counter from synapse.api.constants import EventTypes, Membership from synapse.api.filtering import FilterCollection from synapse.events import EventBase -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter @@ -301,7 +301,7 @@ class SyncHandler(object): else: sync_type = "incremental_sync" - context = LoggingContext.current_context() + context = current_context() if context: context.tag = sync_type @@ -1639,7 +1639,7 @@ class SyncHandler(object): ) # We loop through all room ids, even if there are no new events, in case - # there are non room events taht we need to notify about. + # there are non room events that we need to notify about. for room_id in sync_result_builder.joined_room_ids: room_entry = room_to_events.get(room_id, None) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 391bceb0c4..c7bc14c623 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from typing import List from twisted.internet import defer @@ -257,7 +258,13 @@ class TypingHandler(object): "typing_key", self._latest_room_serial, rooms=[member.room_id] ) - async def get_all_typing_updates(self, last_id, current_id): + async def get_all_typing_updates( + self, last_id: int, current_id: int, limit: int + ) -> List[dict]: + """Get up to `limit` typing updates between the given tokens, earliest + updates first. + """ + if last_id == current_id: return [] @@ -275,7 +282,7 @@ class TypingHandler(object): typing = self._room_typing[room_id] rows.append((serial, room_id, list(typing))) rows.sort() - return rows + return rows[:limit] def get_current_token(self): return self._latest_room_serial |