diff options
Diffstat (limited to 'synapse/handlers')
-rw-r--r-- | synapse/handlers/account_validity.py | 6 | ||||
-rw-r--r-- | synapse/handlers/auth.py | 265 | ||||
-rw-r--r-- | synapse/handlers/device.py | 14 | ||||
-rw-r--r-- | synapse/handlers/directory.py | 121 | ||||
-rw-r--r-- | synapse/handlers/e2e_keys.py | 166 | ||||
-rw-r--r-- | synapse/handlers/e2e_room_keys.py | 7 | ||||
-rw-r--r-- | synapse/handlers/message.py | 49 | ||||
-rw-r--r-- | synapse/handlers/room.py | 16 | ||||
-rw-r--r-- | synapse/handlers/saml_handler.py | 25 | ||||
-rw-r--r-- | synapse/handlers/set_password.py | 41 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 7 |
11 files changed, 530 insertions, 187 deletions
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 829f52eca1..590135d19c 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -44,7 +44,11 @@ class AccountValidityHandler(object): self._account_validity = self.hs.config.account_validity - if self._account_validity.renew_by_email_enabled and load_jinja2_templates: + if ( + self._account_validity.enabled + and self._account_validity.renew_by_email_enabled + and load_jinja2_templates + ): # Don't do email-specific configuration if renewal by email is disabled. try: app_name = self.hs.config.email_app_name diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 48a88d3c2a..7860f9625e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -17,9 +17,11 @@ import logging import time import unicodedata +import urllib.parse +from typing import Any, Dict, Iterable, List, Optional import attr -import bcrypt +import bcrypt # type: ignore[import] import pymacaroons from twisted.internet import defer @@ -38,9 +40,12 @@ from synapse.api.errors import ( from synapse.api.ratelimiting import Ratelimiter from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker +from synapse.http.server import finish_request +from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.module_api import ModuleApi -from synapse.types import UserID +from synapse.push.mailer import load_jinja2_templates +from synapse.types import Requester, UserID from synapse.util.caches.expiringcache import ExpiringCache from ._base import BaseHandler @@ -58,11 +63,11 @@ class AuthHandler(BaseHandler): """ super(AuthHandler, self).__init__(hs) - self.checkers = {} # type: dict[str, UserInteractiveAuthChecker] + self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: inst = auth_checker_class(hs) if inst.is_enabled(): - self.checkers[inst.AUTH_TYPE] = inst + self.checkers[inst.AUTH_TYPE] = inst # type: ignore self.bcrypt_rounds = hs.config.bcrypt_rounds @@ -108,8 +113,20 @@ class AuthHandler(BaseHandler): self._clock = self.hs.get_clock() + # Load the SSO redirect confirmation page HTML template + self._sso_redirect_confirm_template = load_jinja2_templates( + hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"], + )[0] + + self._server_name = hs.config.server_name + + # cast to tuple for use with str.startswith + self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) + @defer.inlineCallbacks - def validate_user_via_ui_auth(self, requester, request_body, clientip): + def validate_user_via_ui_auth( + self, requester: Requester, request_body: Dict[str, Any], clientip: str + ): """ Checks that the user is who they claim to be, via a UI auth. @@ -118,11 +135,11 @@ class AuthHandler(BaseHandler): that it isn't stolen by re-authenticating them. Args: - requester (Requester): The user, as given by the access token + requester: The user, as given by the access token - request_body (dict): The body of the request sent by the client + request_body: The body of the request sent by the client - clientip (str): The IP address of the client. + clientip: The IP address of the client. Returns: defer.Deferred[dict]: the parameters for this request (which may @@ -193,7 +210,9 @@ class AuthHandler(BaseHandler): return self.checkers.keys() @defer.inlineCallbacks - def check_auth(self, flows, clientdict, clientip): + def check_auth( + self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str + ): """ Takes a dictionary sent by the client in the login / registration protocol and handles the User-Interactive Auth flow. @@ -208,14 +227,14 @@ class AuthHandler(BaseHandler): decorator. Args: - flows (list): A list of login flows. Each flow is an ordered list of - strings representing auth-types. At least one full - flow must be completed in order for auth to be successful. + flows: A list of login flows. Each flow is an ordered list of + strings representing auth-types. At least one full + flow must be completed in order for auth to be successful. clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. - clientip (str): The IP address of the client. + clientip: The IP address of the client. Returns: defer.Deferred[dict, dict, str]: a deferred tuple of @@ -235,7 +254,7 @@ class AuthHandler(BaseHandler): """ authdict = None - sid = None + sid = None # type: Optional[str] if clientdict and "auth" in clientdict: authdict = clientdict["auth"] del clientdict["auth"] @@ -268,9 +287,9 @@ class AuthHandler(BaseHandler): creds = session["creds"] # check auth type currently being presented - errordict = {} + errordict = {} # type: Dict[str, Any] if "type" in authdict: - login_type = authdict["type"] + login_type = authdict["type"] # type: str try: result = yield self._check_auth_dict(authdict, clientip) if result: @@ -311,7 +330,7 @@ class AuthHandler(BaseHandler): raise InteractiveAuthIncompleteError(ret) @defer.inlineCallbacks - def add_oob_auth(self, stagetype, authdict, clientip): + def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str): """ Adds the result of out-of-band authentication into an existing auth session. Currently used for adding the result of fallback auth. @@ -333,7 +352,7 @@ class AuthHandler(BaseHandler): return True return False - def get_session_id(self, clientdict): + def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]: """ Gets the session ID for a client given the client dictionary @@ -341,7 +360,7 @@ class AuthHandler(BaseHandler): clientdict: The dictionary sent by the client in the request Returns: - str|None: The string session ID the client sent. If the client did + The string session ID the client sent. If the client did not send a session ID, returns None. """ sid = None @@ -351,40 +370,42 @@ class AuthHandler(BaseHandler): sid = authdict["session"] return sid - def set_session_data(self, session_id, key, value): + def set_session_data(self, session_id: str, key: str, value: Any) -> None: """ Store a key-value pair into the sessions data associated with this request. This data is stored server-side and cannot be modified by the client. Args: - session_id (string): The ID of this session as returned from check_auth - key (string): The key to store the data under - value (any): The data to store + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + value: The data to store """ sess = self._get_session_info(session_id) sess.setdefault("serverdict", {})[key] = value self._save_session(sess) - def get_session_data(self, session_id, key, default=None): + def get_session_data( + self, session_id: str, key: str, default: Optional[Any] = None + ) -> Any: """ Retrieve data stored with set_session_data Args: - session_id (string): The ID of this session as returned from check_auth - key (string): The key to store the data under - default (any): Value to return if the key has not been set + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + default: Value to return if the key has not been set """ sess = self._get_session_info(session_id) return sess.setdefault("serverdict", {}).get(key, default) @defer.inlineCallbacks - def _check_auth_dict(self, authdict, clientip): + def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str): """Attempt to validate the auth dict provided by a client Args: - authdict (object): auth dict provided by the client - clientip (str): IP address of the client + authdict: auth dict provided by the client + clientip: IP address of the client Returns: Deferred: result of the stage verification. @@ -410,10 +431,10 @@ class AuthHandler(BaseHandler): (canonical_id, callback) = yield self.validate_login(user_id, authdict) return canonical_id - def _get_params_recaptcha(self): + def _get_params_recaptcha(self) -> dict: return {"public_key": self.hs.config.recaptcha_public_key} - def _get_params_terms(self): + def _get_params_terms(self) -> dict: return { "policies": { "privacy_policy": { @@ -430,7 +451,9 @@ class AuthHandler(BaseHandler): } } - def _auth_dict_for_flows(self, flows, session): + def _auth_dict_for_flows( + self, flows: List[List[str]], session: Dict[str, Any] + ) -> Dict[str, Any]: public_flows = [] for f in flows: public_flows.append(f) @@ -440,7 +463,7 @@ class AuthHandler(BaseHandler): LoginType.TERMS: self._get_params_terms, } - params = {} + params = {} # type: Dict[str, Any] for f in public_flows: for stage in f: @@ -453,7 +476,13 @@ class AuthHandler(BaseHandler): "params": params, } - def _get_session_info(self, session_id): + def _get_session_info(self, session_id: Optional[str]) -> dict: + """ + Gets or creates a session given a session ID. + + The session can be used to track data across multiple requests, e.g. for + interactive authentication. + """ if session_id not in self.sessions: session_id = None @@ -466,7 +495,9 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] @defer.inlineCallbacks - def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms): + def get_access_token_for_user_id( + self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] + ): """ Creates a new access token for the user with the given user ID. @@ -476,11 +507,11 @@ class AuthHandler(BaseHandler): The device will be recorded in the table if it is not there already. Args: - user_id (str): canonical User ID - device_id (str|None): the device ID to associate with the tokens. + user_id: canonical User ID + device_id: the device ID to associate with the tokens. None to leave the tokens unassociated with a device (deprecated: we should always have a device ID) - valid_until_ms (int|None): when the token is valid until. None for + valid_until_ms: when the token is valid until. None for no expiry. Returns: The access token for the user's session. @@ -515,13 +546,13 @@ class AuthHandler(BaseHandler): return access_token @defer.inlineCallbacks - def check_user_exists(self, user_id): + def check_user_exists(self, user_id: str): """ Checks to see if a user with the given id exists. Will check case insensitively, but return None if there are multiple inexact matches. Args: - (unicode|bytes) user_id: complete @user:id + user_id: complete @user:id Returns: defer.Deferred: (unicode) canonical_user_id, or None if zero or @@ -536,7 +567,7 @@ class AuthHandler(BaseHandler): return None @defer.inlineCallbacks - def _find_user_id_and_pwd_hash(self, user_id): + def _find_user_id_and_pwd_hash(self, user_id: str): """Checks to see if a user with the given id exists. Will check case insensitively, but will return None if there are multiple inexact matches. @@ -566,7 +597,7 @@ class AuthHandler(BaseHandler): ) return result - def get_supported_login_types(self): + def get_supported_login_types(self) -> Iterable[str]: """Get a the login types supported for the /login API By default this is just 'm.login.password' (unless password_enabled is @@ -574,20 +605,20 @@ class AuthHandler(BaseHandler): other login types. Returns: - Iterable[str]: login types + login types """ return self._supported_login_types @defer.inlineCallbacks - def validate_login(self, username, login_submission): + def validate_login(self, username: str, login_submission: Dict[str, Any]): """Authenticates the user for the /login API Also used by the user-interactive auth flow to validate m.login.password auth types. Args: - username (str): username supplied by the user - login_submission (dict): the whole of the login submission + username: username supplied by the user + login_submission: the whole of the login submission (including 'type' and other relevant fields) Returns: Deferred[str, func]: canonical user id, and optional callback @@ -675,13 +706,13 @@ class AuthHandler(BaseHandler): raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) @defer.inlineCallbacks - def check_password_provider_3pid(self, medium, address, password): + def check_password_provider_3pid(self, medium: str, address: str, password: str): """Check if a password provider is able to validate a thirdparty login Args: - medium (str): The medium of the 3pid (ex. email). - address (str): The address of the 3pid (ex. jdoe@example.com). - password (str): The password of the user. + medium: The medium of the 3pid (ex. email). + address: The address of the 3pid (ex. jdoe@example.com). + password: The password of the user. Returns: Deferred[(str|None, func|None)]: A tuple of `(user_id, @@ -709,15 +740,15 @@ class AuthHandler(BaseHandler): return None, None @defer.inlineCallbacks - def _check_local_password(self, user_id, password): + def _check_local_password(self, user_id: str, password: str): """Authenticate a user against the local password database. user_id is checked case insensitively, but will return None if there are multiple inexact matches. Args: - user_id (unicode): complete @user:id - password (unicode): the provided password + user_id: complete @user:id + password: the provided password Returns: Deferred[unicode] the canonical_user_id, or Deferred[None] if unknown user/bad password @@ -740,7 +771,7 @@ class AuthHandler(BaseHandler): return user_id @defer.inlineCallbacks - def validate_short_term_login_token_and_get_user_id(self, login_token): + def validate_short_term_login_token_and_get_user_id(self, login_token: str): auth_api = self.hs.get_auth() user_id = None try: @@ -754,11 +785,11 @@ class AuthHandler(BaseHandler): return user_id @defer.inlineCallbacks - def delete_access_token(self, access_token): + def delete_access_token(self, access_token: str): """Invalidate a single access token Args: - access_token (str): access token to be deleted + access_token: access token to be deleted Returns: Deferred @@ -783,15 +814,17 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def delete_access_tokens_for_user( - self, user_id, except_token_id=None, device_id=None + self, + user_id: str, + except_token_id: Optional[str] = None, + device_id: Optional[str] = None, ): """Invalidate access tokens belonging to a user Args: - user_id (str): ID of user the tokens belong to - except_token_id (str|None): access_token ID which should *not* be - deleted - device_id (str|None): ID of device the tokens are associated with. + user_id: ID of user the tokens belong to + except_token_id: access_token ID which should *not* be deleted + device_id: ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted Returns: @@ -815,7 +848,7 @@ class AuthHandler(BaseHandler): ) @defer.inlineCallbacks - def add_threepid(self, user_id, medium, address, validated_at): + def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int): # check if medium has a valid value if medium not in ["email", "msisdn"]: raise SynapseError( @@ -841,19 +874,20 @@ class AuthHandler(BaseHandler): ) @defer.inlineCallbacks - def delete_threepid(self, user_id, medium, address, id_server=None): + def delete_threepid( + self, user_id: str, medium: str, address: str, id_server: Optional[str] = None + ): """Attempts to unbind the 3pid on the identity servers and deletes it from the local database. Args: - user_id (str) - medium (str) - address (str) - id_server (str|None): Use the given identity server when unbinding + user_id: ID of user to remove the 3pid from. + medium: The medium of the 3pid being removed: "email" or "msisdn". + address: The 3pid address to remove. + id_server: Use the given identity server when unbinding any threepids. If None then will attempt to unbind using the identity server specified when binding (if known). - Returns: Deferred[bool]: Returns True if successfully unbound the 3pid on the identity server, False if identity server doesn't support the @@ -872,17 +906,18 @@ class AuthHandler(BaseHandler): yield self.store.user_delete_threepid(user_id, medium, address) return result - def _save_session(self, session): + def _save_session(self, session: Dict[str, Any]) -> None: + """Update the last used time on the session to now and add it back to the session store.""" # TODO: Persistent storage logger.debug("Saving session %s", session) session["last_used"] = self.hs.get_clock().time_msec() self.sessions[session["id"]] = session - def hash(self, password): + def hash(self, password: str): """Computes a secure hash of password. Args: - password (unicode): Password to hash. + password: Password to hash. Returns: Deferred(unicode): Hashed password. @@ -899,12 +934,12 @@ class AuthHandler(BaseHandler): return defer_to_thread(self.hs.get_reactor(), _do_hash) - def validate_hash(self, password, stored_hash): + def validate_hash(self, password: str, stored_hash: bytes): """Validates that self.hash(password) == stored_hash. Args: - password (unicode): Password to hash. - stored_hash (bytes): Expected hash value. + password: Password to hash. + stored_hash: Expected hash value. Returns: Deferred(bool): Whether self.hash(password) == stored_hash. @@ -927,13 +962,74 @@ class AuthHandler(BaseHandler): else: return defer.succeed(False) + def complete_sso_login( + self, + registered_user_id: str, + request: SynapseRequest, + client_redirect_url: str, + ): + """Having figured out a mxid for this user, complete the HTTP request + + Args: + registered_user_id: The registered user ID to complete SSO login for. + request: The request to complete. + client_redirect_url: The URL to which to redirect the user at the end of the + process. + """ + # Create a login token + login_token = self.macaroon_gen.generate_short_term_login_token( + registered_user_id + ) + + # Append the login token to the original redirect URL (i.e. with its query + # parameters kept intact) to build the URL to which the template needs to + # redirect the users once they have clicked on the confirmation link. + redirect_url = self.add_query_param_to_url( + client_redirect_url, "loginToken", login_token + ) + + # if the client is whitelisted, we can redirect straight to it + if client_redirect_url.startswith(self._whitelisted_sso_clients): + request.redirect(redirect_url) + finish_request(request) + return + + # Otherwise, serve the redirect confirmation page. + + # Remove the query parameters from the redirect URL to get a shorter version of + # it. This is only to display a human-readable URL in the template, but not the + # URL we redirect users to. + redirect_url_no_params = client_redirect_url.split("?")[0] + + html = self._sso_redirect_confirm_template.render( + display_url=redirect_url_no_params, + redirect_url=redirect_url, + server_name=self._server_name, + ).encode("utf-8") + + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html),)) + request.write(html) + finish_request(request) + + @staticmethod + def add_query_param_to_url(url: str, param_name: str, param: Any): + url_parts = list(urllib.parse.urlparse(url)) + query = dict(urllib.parse.parse_qsl(url_parts[4])) + query.update({param_name: param}) + url_parts[4] = urllib.parse.urlencode(query) + return urllib.parse.urlunparse(url_parts) + @attr.s class MacaroonGenerator(object): hs = attr.ib() - def generate_access_token(self, user_id, extra_caveats=None): + def generate_access_token( + self, user_id: str, extra_caveats: Optional[List[str]] = None + ) -> str: extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") @@ -946,16 +1042,9 @@ class MacaroonGenerator(object): macaroon.add_first_party_caveat(caveat) return macaroon.serialize() - def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): - """ - - Args: - user_id (unicode): - duration_in_ms (int): - - Returns: - unicode - """ + def generate_short_term_login_token( + self, user_id: str, duration_in_ms: int = (2 * 60 * 1000) + ) -> str: macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = login") now = self.hs.get_clock().time_msec() @@ -963,12 +1052,12 @@ class MacaroonGenerator(object): macaroon.add_first_party_caveat("time < %d" % (expiry,)) return macaroon.serialize() - def generate_delete_pusher_token(self, user_id): + def generate_delete_pusher_token(self, user_id: str) -> str: macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = delete_pusher") return macaroon.serialize() - def _generate_base_macaroon(self, user_id): + def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon: macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index a514c30714..993499f446 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -125,8 +125,14 @@ class DeviceWorkerHandler(BaseHandler): users_who_share_room = yield self.store.get_users_who_share_room_with_user( user_id ) + + tracked_users = set(users_who_share_room) + + # Always tell the user about their own devices + tracked_users.add(user_id) + changed = yield self.store.get_users_whose_devices_changed( - from_token.device_list_key, users_who_share_room + from_token.device_list_key, tracked_users ) # Then work out if any users have since joined @@ -456,7 +462,11 @@ class DeviceHandler(DeviceWorkerHandler): room_ids = yield self.store.get_rooms_for_user(user_id) - yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids) + # specify the user ID too since the user should always get their own device list + # updates, even if they aren't in any rooms. + yield self.notifier.on_new_event( + "device_list_key", position, users=[user_id], rooms=room_ids + ) if hosts: logger.info( diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 0b23ca919a..1d842c369b 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -13,11 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import collections import logging import string -from typing import List +from typing import Iterable, List, Optional from twisted.internet import defer @@ -30,6 +28,7 @@ from synapse.api.errors import ( StoreError, SynapseError, ) +from synapse.appservice import ApplicationService from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id from ._base import BaseHandler @@ -57,7 +56,13 @@ class DirectoryHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() @defer.inlineCallbacks - def _create_association(self, room_alias, room_id, servers=None, creator=None): + def _create_association( + self, + room_alias: RoomAlias, + room_id: str, + servers: Optional[Iterable[str]] = None, + creator: Optional[str] = None, + ): # general association creation for both human users and app services for wchar in string.whitespace: @@ -83,17 +88,21 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def create_association( - self, requester, room_alias, room_id, servers=None, check_membership=True, + self, + requester: Requester, + room_alias: RoomAlias, + room_id: str, + servers: Optional[List[str]] = None, + check_membership: bool = True, ): """Attempt to create a new alias Args: - requester (Requester) - room_alias (RoomAlias) - room_id (str) - servers (list[str]|None): List of servers that others servers - should try and join via - check_membership (bool): Whether to check if the user is in the room + requester + room_alias + room_id + servers: Iterable of servers that others servers should try and join via + check_membership: Whether to check if the user is in the room before the alias can be set (if the server's config requires it). Returns: @@ -147,15 +156,15 @@ class DirectoryHandler(BaseHandler): yield self._create_association(room_alias, room_id, servers, creator=user_id) @defer.inlineCallbacks - def delete_association(self, requester, room_alias): + def delete_association(self, requester: Requester, room_alias: RoomAlias): """Remove an alias from the directory (this is only meant for human users; AS users should call delete_appservice_association) Args: - requester (Requester): - room_alias (RoomAlias): + requester + room_alias Returns: Deferred[unicode]: room id that the alias used to point to @@ -191,16 +200,16 @@ class DirectoryHandler(BaseHandler): room_id = yield self._delete_association(room_alias) try: - yield self._update_canonical_alias( - requester, requester.user.to_string(), room_id, room_alias - ) + yield self._update_canonical_alias(requester, user_id, room_id, room_alias) except AuthError as e: logger.info("Failed to update alias events: %s", e) return room_id @defer.inlineCallbacks - def delete_appservice_association(self, service, room_alias): + def delete_appservice_association( + self, service: ApplicationService, room_alias: RoomAlias + ): if not service.is_interested_in_alias(room_alias.to_string()): raise SynapseError( 400, @@ -210,7 +219,7 @@ class DirectoryHandler(BaseHandler): yield self._delete_association(room_alias) @defer.inlineCallbacks - def _delete_association(self, room_alias): + def _delete_association(self, room_alias: RoomAlias): if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") @@ -219,7 +228,7 @@ class DirectoryHandler(BaseHandler): return room_id @defer.inlineCallbacks - def get_association(self, room_alias): + def get_association(self, room_alias: RoomAlias): room_id = None if self.hs.is_mine(room_alias): result = yield self.get_association_from_room_alias(room_alias) @@ -284,7 +293,9 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def _update_canonical_alias(self, requester, user_id, room_id, room_alias): + def _update_canonical_alias( + self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias + ): """ Send an updated canonical alias event if the removed alias was set as the canonical alias or listed in the alt_aliases field. @@ -307,15 +318,17 @@ class DirectoryHandler(BaseHandler): send_update = True content.pop("alias", "") - # Filter alt_aliases for the removed alias. - alt_aliases = content.pop("alt_aliases", None) - # If the aliases are not a list (or not found) do not attempt to modify - # the list. - if isinstance(alt_aliases, collections.Sequence): + # Filter the alt_aliases property for the removed alias. Note that the + # value is not modified if alt_aliases is of an unexpected form. + alt_aliases = content.get("alt_aliases") + if isinstance(alt_aliases, (list, tuple)) and alias_str in alt_aliases: send_update = True alt_aliases = [alias for alias in alt_aliases if alias != alias_str] + if alt_aliases: content["alt_aliases"] = alt_aliases + else: + del content["alt_aliases"] if send_update: yield self.event_creation_handler.create_and_send_nonmember_event( @@ -331,7 +344,7 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def get_association_from_room_alias(self, room_alias): + def get_association_from_room_alias(self, room_alias: RoomAlias): result = yield self.store.get_association_from_room_alias(room_alias) if not result: # Query AS to see if it exists @@ -339,7 +352,7 @@ class DirectoryHandler(BaseHandler): result = yield as_handler.query_room_alias_exists(room_alias) return result - def can_modify_alias(self, alias, user_id=None): + def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None): # Any application service "interested" in an alias they are regexing on # can modify the alias. # Users can only modify the alias if ALL the interested services have @@ -360,22 +373,42 @@ class DirectoryHandler(BaseHandler): return defer.succeed(True) @defer.inlineCallbacks - def _user_can_delete_alias(self, alias, user_id): + def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): + """Determine whether a user can delete an alias. + + One of the following must be true: + + 1. The user created the alias. + 2. The user is a server administrator. + 3. The user has a power-level sufficient to send a canonical alias event + for the current room. + + """ creator = yield self.store.get_room_alias_creator(alias.to_string()) if creator is not None and creator == user_id: return True - is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) - return is_admin + # Resolve the alias to the corresponding room. + room_mapping = yield self.get_association(alias) + room_id = room_mapping["room_id"] + if not room_id: + return False + + res = yield self.auth.check_can_change_room_list( + room_id, UserID.from_string(user_id) + ) + return res @defer.inlineCallbacks - def edit_published_room_list(self, requester, room_id, visibility): + def edit_published_room_list( + self, requester: Requester, room_id: str, visibility: str + ): """Edit the entry of the room in the published room list. requester - room_id (str) - visibility (str): "public" or "private" + room_id + visibility: "public" or "private" """ user_id = requester.user.to_string() @@ -400,7 +433,15 @@ class DirectoryHandler(BaseHandler): if room is None: raise SynapseError(400, "Unknown room") - yield self.auth.check_can_change_room_list(room_id, requester.user) + can_change_room_list = yield self.auth.check_can_change_room_list( + room_id, requester.user + ) + if not can_change_room_list: + raise AuthError( + 403, + "This server requires you to be a moderator in the room to" + " edit its room list entry", + ) making_public = visibility == "public" if making_public: @@ -421,16 +462,16 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def edit_published_appservice_room_list( - self, appservice_id, network_id, room_id, visibility + self, appservice_id: str, network_id: str, room_id: str, visibility: str ): """Add or remove a room from the appservice/network specific public room list. Args: - appservice_id (str): ID of the appservice that owns the list - network_id (str): The ID of the network the list is associated with - room_id (str) - visibility (str): either "public" or "private" + appservice_id: ID of the appservice that owns the list + network_id: The ID of the network the list is associated with + room_id + visibility: either "public" or "private" """ if visibility not in ["public", "private"]: raise SynapseError(400, "Invalid visibility setting") diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 95a9d71f41..8f1bc0323c 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -54,19 +54,23 @@ class E2eKeysHandler(object): self._edu_updater = SigningKeyEduUpdater(hs, self) + federation_registry = hs.get_federation_registry() + self._is_master = hs.config.worker_app is None if not self._is_master: self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client( hs ) + else: + # Only register this edu handler on master as it requires writing + # device updates to the db + # + # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + federation_registry.register_edu_handler( + "org.matrix.signing_key_update", + self._edu_updater.incoming_signing_key_update, + ) - federation_registry = hs.get_federation_registry() - - # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec - federation_registry.register_edu_handler( - "org.matrix.signing_key_update", - self._edu_updater.incoming_signing_key_update, - ) # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the # "query handler" interface. @@ -170,8 +174,8 @@ class E2eKeysHandler(object): """This is called when we are querying the device list of a user on a remote homeserver and their device list is not in the device list cache. If we share a room with this user and we're not querying for - specific user we will update the cache - with their device list.""" + specific user we will update the cache with their device list. + """ destination_query = remote_queries_not_in_cache[destination] @@ -957,13 +961,19 @@ class E2eKeysHandler(object): return signature_list, failures @defer.inlineCallbacks - def _get_e2e_cross_signing_verify_key(self, user_id, key_type, from_user_id=None): - """Fetch the cross-signing public key from storage and interpret it. + def _get_e2e_cross_signing_verify_key( + self, user_id: str, key_type: str, from_user_id: str = None + ): + """Fetch locally or remotely query for a cross-signing public key. + + First, attempt to fetch the cross-signing public key from storage. + If that fails, query the keys from the homeserver they belong to + and update our local copy. Args: - user_id (str): the user whose key should be fetched - key_type (str): the type of key to fetch - from_user_id (str): the user that we are fetching the keys for. + user_id: the user whose key should be fetched + key_type: the type of key to fetch + from_user_id: the user that we are fetching the keys for. This affects what signatures are fetched. Returns: @@ -972,16 +982,140 @@ class E2eKeysHandler(object): Raises: NotFoundError: if the key is not found + SynapseError: if `user_id` is invalid """ + user = UserID.from_string(user_id) key = yield self.store.get_e2e_cross_signing_key( user_id, key_type, from_user_id ) + + if key: + # We found a copy of this key in our database. Decode and return it + key_id, verify_key = get_verify_key_from_cross_signing_key(key) + return key, key_id, verify_key + + # If we couldn't find the key locally, and we're looking for keys of + # another user then attempt to fetch the missing key from the remote + # user's server. + # + # We may run into this in possible edge cases where a user tries to + # cross-sign a remote user, but does not share any rooms with them yet. + # Thus, we would not have their key list yet. We instead fetch the key, + # store it and notify clients of new, associated device IDs. + if self.is_mine(user) or key_type not in ["master", "self_signing"]: + # Note that master and self_signing keys are the only cross-signing keys we + # can request over federation + raise NotFoundError("No %s key found for %s" % (key_type, user_id)) + + ( + key, + key_id, + verify_key, + ) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type) + if key is None: - logger.debug("no %s key found for %s", key_type, user_id) raise NotFoundError("No %s key found for %s" % (key_type, user_id)) - key_id, verify_key = get_verify_key_from_cross_signing_key(key) + return key, key_id, verify_key + @defer.inlineCallbacks + def _retrieve_cross_signing_keys_for_remote_user( + self, user: UserID, desired_key_type: str, + ): + """Queries cross-signing keys for a remote user and saves them to the database + + Only the key specified by `key_type` will be returned, while all retrieved keys + will be saved regardless + + Args: + user: The user to query remote keys for + desired_key_type: The type of key to receive. One of "master", "self_signing" + + Returns: + Deferred[Tuple[Optional[Dict], Optional[str], Optional[VerifyKey]]]: A tuple + of the retrieved key content, the key's ID and the matching VerifyKey. + If the key cannot be retrieved, all values in the tuple will instead be None. + """ + try: + remote_result = yield self.federation.query_user_devices( + user.domain, user.to_string() + ) + except Exception as e: + logger.warning( + "Unable to query %s for cross-signing keys of user %s: %s %s", + user.domain, + user.to_string(), + type(e), + e, + ) + return None, None, None + + # Process each of the retrieved cross-signing keys + desired_key = None + desired_key_id = None + desired_verify_key = None + retrieved_device_ids = [] + for key_type in ["master", "self_signing"]: + key_content = remote_result.get(key_type + "_key") + if not key_content: + continue + + # Ensure these keys belong to the correct user + if "user_id" not in key_content: + logger.warning( + "Invalid %s key retrieved, missing user_id field: %s", + key_type, + key_content, + ) + continue + if user.to_string() != key_content["user_id"]: + logger.warning( + "Found %s key of user %s when querying for keys of user %s", + key_type, + key_content["user_id"], + user.to_string(), + ) + continue + + # Validate the key contents + try: + # verify_key is a VerifyKey from signedjson, which uses + # .version to denote the portion of the key ID after the + # algorithm and colon, which is the device ID + key_id, verify_key = get_verify_key_from_cross_signing_key(key_content) + except ValueError as e: + logger.warning( + "Invalid %s key retrieved: %s - %s %s", + key_type, + key_content, + type(e), + e, + ) + continue + + # Note down the device ID attached to this key + retrieved_device_ids.append(verify_key.version) + + # If this is the desired key type, save it and its ID/VerifyKey + if key_type == desired_key_type: + desired_key = key_content + desired_verify_key = verify_key + desired_key_id = key_id + + # At the same time, store this key in the db for subsequent queries + yield self.store.set_e2e_cross_signing_key( + user.to_string(), key_type, key_content + ) + + # Notify clients that new devices for this user have been discovered + if retrieved_device_ids: + # XXX is this necessary? + yield self.device_handler.notify_device_update( + user.to_string(), retrieved_device_ids + ) + + return desired_key, desired_key_id, desired_verify_key + def _check_cross_signing_key(key, user_id, key_type, signing_key=None): """Check a cross-signing key uploaded by a user. Performs some basic sanity diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index f1b4424a02..9abaf13b8f 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -207,6 +207,13 @@ class E2eRoomKeysHandler(object): changed = False # if anything has changed, we need to update the etag for room_id, room in iteritems(room_keys["rooms"]): for session_id, room_key in iteritems(room["sessions"]): + if not isinstance(room_key["is_verified"], bool): + msg = ( + "is_verified must be a boolean in keys for session %s in" + "room %s" % (session_id, room_id) + ) + raise SynapseError(400, msg, Codes.INVALID_PARAM) + log_kv( { "message": "Trying to upload room key", diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a0103addd3..b743fc2dcc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -160,7 +160,7 @@ class MessageHandler(object): raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = yield filter_events_for_client( - self.storage, user_id, last_events, apply_retention_policies=False + self.storage, user_id, last_events, filter_send_to_client=False ) event = last_events[0] @@ -888,19 +888,60 @@ class EventCreationHandler(object): yield self.base_handler.maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: - # Check the alias is acually valid (at this time at least) + # Validate a newly added alias or newly added alt_aliases. + + original_alias = None + original_alt_aliases = set() + + original_event_id = event.unsigned.get("replaces_state") + if original_event_id: + original_event = yield self.store.get_event(original_event_id) + + if original_event: + original_alias = original_event.content.get("alias", None) + original_alt_aliases = original_event.content.get("alt_aliases", []) + + # Check the alias is currently valid (if it has changed). room_alias_str = event.content.get("alias", None) - if room_alias_str: + directory_handler = self.hs.get_handlers().directory_handler + if room_alias_str and room_alias_str != original_alias: room_alias = RoomAlias.from_string(room_alias_str) - directory_handler = self.hs.get_handlers().directory_handler mapping = yield directory_handler.get_association(room_alias) if mapping["room_id"] != event.room_id: raise SynapseError( 400, "Room alias %s does not point to the room" % (room_alias_str,), + Codes.BAD_ALIAS, ) + # Check that alt_aliases is the proper form. + alt_aliases = event.content.get("alt_aliases", []) + if not isinstance(alt_aliases, (list, tuple)): + raise SynapseError( + 400, "The alt_aliases property must be a list.", Codes.INVALID_PARAM + ) + + # If the old version of alt_aliases is of an unknown form, + # completely replace it. + if not isinstance(original_alt_aliases, (list, tuple)): + original_alt_aliases = [] + + # Check that each alias is currently valid. + new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) + if new_alt_aliases: + for alias_str in new_alt_aliases: + room_alias = RoomAlias.from_string(alias_str) + mapping = yield directory_handler.get_association(room_alias) + + if mapping["room_id"] != event.room_id: + raise SynapseError( + 400, + "Room alias %s does not point to the room" + % (room_alias_str,), + Codes.BAD_ALIAS, + ) + federation_handler = self.hs.get_handlers().federation_handler if event.type == EventTypes.Member: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 8ee870f0bb..f580ab2e9f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -292,16 +292,6 @@ class RoomCreationHandler(BaseHandler): except AuthError as e: logger.warning("Unable to update PLs in old room: %s", e) - new_pl_content = copy_power_levels_contents(old_room_pl_state.content) - - # pre-msc2260 rooms may not have the right setting for aliases. If no other - # value is set, set it now. - events_default = new_pl_content.get("events_default", 0) - new_pl_content.setdefault("events", {}).setdefault( - EventTypes.Aliases, events_default - ) - - logger.debug("Setting correct PLs in new room to %s", new_pl_content) yield self.event_creation_handler.create_and_send_nonmember_event( requester, { @@ -309,7 +299,7 @@ class RoomCreationHandler(BaseHandler): "state_key": "", "room_id": new_room_id, "sender": requester.user.to_string(), - "content": new_pl_content, + "content": old_room_pl_state.content, }, ratelimit=False, ) @@ -814,10 +804,6 @@ class RoomCreationHandler(BaseHandler): EventTypes.RoomHistoryVisibility: 100, EventTypes.CanonicalAlias: 50, EventTypes.RoomAvatar: 50, - # MSC2260: Allow everybody to send alias events by default - # This will be reudundant on pre-MSC2260 rooms, since the - # aliases event is special-cased. - EventTypes.Aliases: 0, EventTypes.Tombstone: 100, EventTypes.ServerACL: 100, }, diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 7f411b53b9..72c109981b 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -23,9 +23,9 @@ from saml2.client import Saml2Client from synapse.api.errors import SynapseError from synapse.config import ConfigError +from synapse.http.server import finish_request from synapse.http.servlet import parse_string from synapse.module_api import ModuleApi -from synapse.rest.client.v1.login import SSOAuthHandler from synapse.types import ( UserID, map_username_to_mxid_localpart, @@ -48,7 +48,7 @@ class Saml2SessionData: class SamlHandler: def __init__(self, hs): self._saml_client = Saml2Client(hs.config.saml2_sp_config) - self._sso_auth_handler = SSOAuthHandler(hs) + self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() self._clock = hs.get_clock() @@ -74,6 +74,8 @@ class SamlHandler: # a lock on the mappings self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) + self._error_html_content = hs.config.saml2_error_html_content + def handle_redirect_request(self, client_redirect_url): """Handle an incoming request to /login/sso/redirect @@ -115,8 +117,23 @@ class SamlHandler: # the dict. self.expire_sessions() - user_id = await self._map_saml_response_to_user(resp_bytes, relay_state) - self._sso_auth_handler.complete_sso_login(user_id, request, relay_state) + try: + user_id = await self._map_saml_response_to_user(resp_bytes, relay_state) + except Exception as e: + # If decoding the response or mapping it to a user failed, then log the + # error and tell the user that something went wrong. + logger.error(e) + + request.setResponseCode(400) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader( + b"Content-Length", b"%d" % (len(self._error_html_content),) + ) + request.write(self._error_html_content.encode("utf8")) + finish_request(request) + return + + self._auth_handler.complete_sso_login(user_id, request, relay_state) async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url): try: diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index d90c9e0108..12657ca698 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -13,10 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from twisted.internet import defer from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.types import Requester from ._base import BaseHandler @@ -32,14 +34,17 @@ class SetPasswordHandler(BaseHandler): self._device_handler = hs.get_device_handler() @defer.inlineCallbacks - def set_password(self, user_id, newpassword, requester=None): + def set_password( + self, + user_id: str, + new_password: str, + logout_devices: bool, + requester: Optional[Requester] = None, + ): if not self.hs.config.password_localdb_enabled: raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) - password_hash = yield self._auth_handler.hash(newpassword) - - except_device_id = requester.device_id if requester else None - except_access_token_id = requester.access_token_id if requester else None + password_hash = yield self._auth_handler.hash(new_password) try: yield self.store.user_set_password_hash(user_id, password_hash) @@ -48,14 +53,18 @@ class SetPasswordHandler(BaseHandler): raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) raise e - # we want to log out all of the user's other sessions. First delete - # all his other devices. - yield self._device_handler.delete_all_devices_for_user( - user_id, except_device_id=except_device_id - ) - - # and now delete any access tokens which weren't associated with - # devices (or were associated with this device). - yield self._auth_handler.delete_access_tokens_for_user( - user_id, except_token_id=except_access_token_id - ) + # Optionally, log out all of the user's other sessions. + if logout_devices: + except_device_id = requester.device_id if requester else None + except_access_token_id = requester.access_token_id if requester else None + + # First delete all of their other devices. + yield self._device_handler.delete_all_devices_for_user( + user_id, except_device_id=except_device_id + ) + + # and now delete any access tokens which weren't associated with + # devices (or were associated with this device). + yield self._auth_handler.delete_access_tokens_for_user( + user_id, except_token_id=except_access_token_id + ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 669dbc8a48..cfd5dfc9e5 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1143,9 +1143,14 @@ class SyncHandler(object): user_id ) + tracked_users = set(users_who_share_room) + + # Always tell the user about their own devices + tracked_users.add(user_id) + # Step 1a, check for changes in devices of users we share a room with users_that_have_changed = await self.store.get_users_whose_devices_changed( - since_token.device_list_key, users_who_share_room + since_token.device_list_key, tracked_users ) # Step 1b, check for newly joined rooms |